def charRNN(self, context):
    """Encode each word id in ``context`` with a character-level RNN.

    The character-id sequence of every word is looked up in
    ``self.index2charIds``. The words are sorted by character length, longest
    first, and fed to ``self`` one time step at a time. At step ``t`` the batch
    contains only the words that have more than ``t`` characters. This is the
    variable-length convention described in the Chainer ``links.LSTM``
    documentation. After the last step, ``self.out`` is applied to
    ``self.mid.h`` and the rows are put back into the order of ``context``.

    Args:
        context: Word ids used to index ``self.index2charIds``. This is
            presumably a list or integer array, with ``self.index2charIds`` an
            object array of per-word character-id sequences (TODO confirm).
            It must be non-empty; an empty context raises IndexError.

    Returns:
        The result of ``self.out(self.mid.h)``, re-indexed so that row ``k``
        corresponds to ``context[k]``.

    NOTE(review): this assumes every word has at least one character. A
    zero-length entry is never fed to the RNN, so ``self.mid.h`` would
    presumably have fewer rows than ``context`` — confirm upstream.
    """
    char_ids = self.index2charIds[context]

    # Sort words by character length, longest first, so that each time step's
    # batch is a prefix of the previous step's batch.
    lengths = np.array([len(chars) for chars in char_ids])
    order = lengths.argsort()[::-1]

    # Inverse permutation: inverse[order[k]] == k. It maps the sorted rows
    # back to the original word order at the end.
    inverse = np.empty(len(order), dtype=np.int32)
    inverse[order] = np.arange(len(order), dtype=np.int32)
    sorted_char_ids = char_ids[order]

    # Transpose to time-major: steps[t] holds the t-th character of every
    # word longer than t, in sorted (longest-first) order.
    max_len = len(sorted_char_ids[0])
    steps = [
        [chars[t] for chars in sorted_char_ids if len(chars) > t]
        for t in range(max_len)
    ]

    self.reset_state()
    for step in steps:
        # Called only for its side effect: presumably it advances the
        # recurrent state of self.mid, whose ``h`` is read below. The return
        # value is not needed.
        self(np.array(step, np.int32))
    y = self.out(self.mid.h)
    return y[inverse]  # restore the original order of `context`
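# Blog-page navigation residue (not code): 评论列表 ("comment list")
# Blog-page navigation residue (not code): 文章目录 ("table of contents")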