def prepareDecoder(self, encInfo):
    """Prepare the decoder for a new decoding pass.

    Resets the decoder LSTM state, then builds the attention source and an
    all-zero initial vector for the decoding loop.

    Args:
        encInfo: encoder output bundle. This method reads ``attnList``,
            ``lstmVars``, ``cMBSize`` and ``encLen`` from it.

    Returns:
        tuple: ``(aList, finalHS)`` where

        * ``aList`` is ``None`` when ``self.attn_mode == 0``;
          ``encInfo.attnList`` unchanged when it is 1; and, when it is 2,
          ``self.model.attnM`` applied to ``encInfo.attnList`` reshaped to
          ``(encInfo.cMBSize * encInfo.encLen, self.hDim)``.
        * ``finalHS`` is a ``chainer.Variable`` of float32 zeros with the
          same shape as ``encInfo.lstmVars[0].data``. It is allocated by the
          array module (NumPy or CuPy) that owns that array.

    Raises:
        ValueError: if ``self.attn_mode`` is not 0, 1 or 2.
    """
    attn_mode = self.attn_mode
    # Validate before touching any model state, and use a real exception.
    # A bare ``assert`` is stripped under ``python -O``; that would leave
    # ``aList`` unbound and fail later with a confusing UnboundLocalError.
    if attn_mode not in (0, 1, 2):
        raise ValueError(
            "unsupported attn_mode %r (expected 0, 1 or 2)" % (attn_mode,))

    self.model.decLSTM.reset_state()

    if attn_mode == 0:
        # No attention source.
        aList = None
    elif attn_mode == 1:
        # Use the encoder's attention list as-is.
        aList = encInfo.attnList
    else:
        # attn_mode == 2: flatten the attention list to 2-D and apply
        # self.model.attnM to it once, here, before decoding starts.
        aList = self.model.attnM(
            chaFunc.reshape(encInfo.attnList,
                            (encInfo.cMBSize * encInfo.encLen, self.hDim)))
        # TODO: the original note on this line was mis-encoded and is
        # unrecoverable; only the word "encoder" survived.

    ref = encInfo.lstmVars[0].data
    # Pick NumPy or CuPy to match the device of the encoder's LSTM output.
    xp = cuda.get_array_module(ref)
    # Zero-filled initial vector. The original (mis-encoded) comment
    # mentioned ``input_feed``; presumably input feeding starts from
    # zeros. Confirm against the decoder loop.
    finalHS = chainer.Variable(xp.zeros(ref.shape, dtype=xp.float32))
    return aList, finalHS
############################
评论列表
文章目录