# NOTE: this method belongs to an encoder-decoder model class; the enclosing
# module is assumed to import chainer, six, and chainer.functions (aliased
# here as chaFunc).
def encodeSentenceFWD(self, train_mode, sentence, args, dropout_rate):
    if args.gpu_enc != args.gpu_dec:  # encoder and decoder sit on different GPUs
        chainer.cuda.get_device(args.gpu_enc).use()
    encLen = len(sentence)  # source sentence length
    cMBSize = len(sentence[0])  # minibatch size
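    # `sentence` is laid out time-major: sentence[t][b] is the t-th token of
    # the b-th batch element, so len(sentence) is the sentence length and
    # len(sentence[0]) is the minibatch size.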
    # compute the embeddings for the entire sentence in one batched call
    encEmbList = self.getEncoderInputEmbeddings(sentence, args)
    flag_train = (train_mode > 0)
    lstmVars = [0] * self.n_layers * 2
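    # lstmVars collects the final cell and hidden state of every layer,
    # interleaved as [c_0, h_0, c_1, h_1, ...]; it is filled in below and
    # stacked at the end to seed the decoder's LSTM state.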
    if self.flag_merge_encfwbw == 0:  # keep fw and bw separate; merge only at the end
        # with a layer index of None, encLSTM_f/encLSTM_b appear to process
        # all layers in one call and return per-layer state lists
        hyf, cyf, fwHout = self.model.encLSTM_f(
            None, None, encEmbList, flag_train, args)  # forward pass
        hyb, cyb, bkHout = self.model.encLSTM_b(
            None, None, encEmbList[::-1], flag_train, args)  # backward pass
        for z in six.moves.range(self.n_layers):
            lstmVars[2 * z] = cyf[z] + cyb[z]
            lstmVars[2 * z + 1] = hyf[z] + hyb[z]
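        # note that the two directions are merged by element-wise addition,
        # not concatenation, so the merged states keep the hidden size hDim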
    elif self.flag_merge_encfwbw == 1:  # merge fw and bw at every layer
        sp = (cMBSize, self.hDim)
        for z in six.moves.range(self.n_layers):
            if z == 0:  # the first layer reads the embeddings
                biH = encEmbList
            else:  # deeper layers read the previous layer's output
                # bkHout is in reversed order; restore the original token
                # order before adding it to fwHout
                biH = fwHout + bkHout[::-1]
            # forward LSTM for layer z
            hyf, cyf, fwHout = self.model.encLSTM_f(
                z, biH, flag_train, dropout_rate, args)
            # backward LSTM for layer z
            hyb, cyb, bkHout = self.model.encLSTM_b(
                z, biH[::-1], flag_train, dropout_rate, args)
            # keep each layer's final h and c; the ordering does not affect
            # the upper layers, so no reordering is needed here
            lstmVars[2 * z] = chaFunc.reshape(cyf + cyb, sp)
            lstmVars[2 * z + 1] = chaFunc.reshape(hyf + hyb, sp)
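            # cyf/cyb and hyf/hyb are single-layer states here, presumably
            # carrying a leading layer axis of size 1; the reshape to
            # (cMBSize, hDim) normalises them before the final stack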
    else:
        assert 0, "ERROR"
    # collect the per-token hidden states of the top layer
    if self.flag_enc_boseos == 0:  # default
        # slicing fwHout with [:, ] appears to be needed to avoid an error
        biHiddenStack = fwHout[:, ] + bkHout[::-1]
    elif self.flag_enc_boseos == 1:
        bkHout2 = bkHout[::-1]  # restore the original token order
        biHiddenStack = fwHout[1:encLen - 1, ] + bkHout2[1:encLen - 1, ]
        # drop the BOS and EOS positions
        # TODO: guard against encLen dropping to 0 here?
        encLen -= 2
    else:
        assert 0, "ERROR"
    # swap (source length, minibatch size, hidden dim)
    # to   (minibatch size, source length, hidden dim)
    biHiddenStackSW01 = chaFunc.swapaxes(biHiddenStack, 0, 1)
    # stack the final per-layer states; they become the initial states of
    # the decoder's LSTM
    lstmVars = chaFunc.stack(lstmVars)
    # bundle everything the decoder needs into an encInfoObject and return it
    retO = self.encInfoObject(biHiddenStackSW01, lstmVars, encLen, cMBSize)
    return retO
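
# A self-contained NumPy sketch of the reverse-and-merge bookkeeping used
# above: the backward LSTM consumes the reversed sequence, so its output must
# be re-reversed before the element-wise sum, and the merged stack is then
# transposed from time-major to batch-major. All names and shapes here are
# illustrative only, not part of the model code.
import numpy as np

encLen, cMBSize, hDim = 5, 2, 4
fwHout = np.random.randn(encLen, cMBSize, hDim)  # forward outputs, tokens 0..L-1
bkHout = np.random.randn(encLen, cMBSize, hDim)  # backward outputs, tokens L-1..0

# restore the backward outputs to the original token order, then merge by
# addition, which keeps the hidden size at hDim (concatenation would double it)
biHiddenStack = fwHout + bkHout[::-1]

# (encLen, cMBSize, hDim) -> (cMBSize, encLen, hDim), as chaFunc.swapaxes does
biHiddenStackSW01 = np.swapaxes(biHiddenStack, 0, 1)
assert biHiddenStackSW01.shape == (cMBSize, encLen, hDim)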