utils_subword_rnn.py source code

python

Project: vsmlib  Author: undertherain
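import numpy as np

def charRNN(self, context):
    """Embed each word id in `context` by running a character-level RNN over its characters.

    Input: a list of word ids. Output: one embedding per word, in the input order.
    Relies on members defined elsewhere in the class: index2charIds, reset_state(),
    __call__ (one RNN step), mid (the recurrent link whose hidden state is read) and out.
    """
    contexts2charIds = self.index2charIds[context]

    # Sort the words by character length in descending order: chainer.links.LSTM accepts
    # variable-length input only if the batch size never grows from one step to the next.
    # ref: https://docs.chainer.org/en/stable/reference/generated/chainer.links.LSTM.html?highlight=Variable-length
    context_char_length = np.array([len(t) for t in contexts2charIds])
    argsort = context_char_length.argsort()[::-1]  # word indices, longest word first
    # Inverse permutation, used at the end to restore the original word order
    argsort_reverse = np.empty(len(argsort), dtype=np.int32)
    argsort_reverse[argsort] = np.arange(len(argsort), dtype=np.int32)
    contexts2charIds = contexts2charIds[argsort]  # reuse the permutation instead of re-sorting

    # Transpose to time-major: rnn_inputs[i] holds the i-th character of every word that is
    # at least i+1 characters long, so each step's batch is a prefix of the previous one
    rnn_inputs = [[] for _ in range(len(contexts2charIds[0]))]
    for word in contexts2charIds:
        for i, char_id in enumerate(word):
            rnn_inputs[i].append(char_id)

    self.reset_state()
    for step_input in rnn_inputs:
        self(np.array(step_input, np.int32))  # advance the RNN by one character
    # self.mid.h now holds each word's hidden state after its own last character
    y = self.out(self.mid.h)
    y = y[argsort_reverse]  # restore the original order
    return y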
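The same sort, time-major transpose, shrinking-batch recurrence and unsort pattern can be sketched in plain NumPy, which makes it easy to check that the bookkeeping is right without Chainer installed. This is a minimal sketch under stated assumptions: `sort_by_length_desc`, `to_time_major`, `encode`, the toy tanh cell and the random weights are all hypothetical stand-ins for the class's recurrent link (`self.mid`) and output layer (`self.out`). The only behavior borrowed from chainer.links.LSTM (per the documentation linked in the code) is that when a step's batch is smaller than the previous one, only the first len(x) rows of the state are updated.

```python
import numpy as np


def sort_by_length_desc(seqs):
    """Return the sequences sorted longest-first, plus the permutation and its inverse."""
    lengths = np.array([len(s) for s in seqs])
    order = np.argsort(-lengths, kind="stable")  # longest first; ties keep input order
    inverse = np.empty_like(order)
    inverse[order] = np.arange(len(order))  # inverse[original_pos] = sorted_pos
    return [seqs[i] for i in order], order, inverse


def to_time_major(sorted_seqs):
    """Turn longest-first sequences into per-timestep batches that shrink over time."""
    max_len = len(sorted_seqs[0])
    return [np.array([s[t] for s in sorted_seqs if len(s) > t], dtype=np.int32)
            for t in range(max_len)]


def encode(seqs, embed, w_h, w_x):
    """Hypothetical toy tanh RNN: final hidden state of each sequence, in input order.

    At each step only the first len(x_t) rows of h are updated, so a shorter
    (later-sorted) sequence keeps the state it had after its own last element.
    """
    sorted_seqs, _, inverse = sort_by_length_desc(seqs)
    h = np.zeros((len(seqs), w_h.shape[0]))
    for x_t in to_time_major(sorted_seqs):
        n = len(x_t)
        h[:n] = np.tanh(h[:n] @ w_h + embed[x_t] @ w_x)
    return h[inverse]  # undo the sort


# Usage: three "words" of character ids with lengths 2, 4 and 1
rng = np.random.default_rng(0)
words = [[1, 2], [3, 4, 5, 6], [7]]
vocab, d_e, d_h = 10, 4, 3
embed = rng.normal(size=(vocab, d_e))
w_h = rng.normal(size=(d_h, d_h))
w_x = rng.normal(size=(d_e, d_h))
states = encode(words, embed, w_h, w_x)
print(states.shape)  # (3, 3)
```

The sketch uses a stable descending sort so tied lengths keep a deterministic order; the original's `argsort()[::-1]` reverses ties instead, which is harmless because the inverse permutation undoes whatever order was chosen.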