def decode(self, cell, init_state, loop_function=None):
    """Unroll ``cell`` over ``self.decoder_inputs_emb`` and collect its outputs.

    At step ``i`` the cell is called as ``cell(inp, state)`` and must return
    an ``(output, new_state)`` pair; ``new_state`` is threaded into step
    ``i + 1``.

    Args:
        cell: callable ``(inp, state) -> (output, new_state)``; presumably a
            TF1 ``RNNCell`` -- TODO confirm against callers.
        init_state: state handed to the cell at step 0.
        loop_function: optional callable ``(prev_output, i) -> next_input``.
            When given, once a previous output exists (i.e. from step 1 on)
            the step's entry of ``self.decoder_inputs_emb`` is ignored and
            ``loop_function(previous_output, i)`` is fed instead, built inside
            a ``"loop_function"`` variable scope with ``reuse=True``. Step 0
            always uses the first decoder input.

    Returns:
        list: the cell output of every step, in order. The final state is
        not returned.
    """
    outputs = []
    prev = None
    state = init_state
    for i, inp in enumerate(self.decoder_inputs_emb):
        if loop_function is not None and prev is not None:
            # Feed the previous step's output back in instead of the given input.
            with tf.variable_scope("loop_function", reuse=True):
                inp = loop_function(prev, i)
        if i > 0:
            # Share the cell's variables across time steps. No variable scope
            # is opened in this method, so this flips reuse on the caller's
            # current scope.
            # NOTE(review): the flag is not reset before returning; callers may
            # rely on that (e.g. a second decode() call re-using the weights),
            # so confirm before wrapping this loop in its own scope.
            tf.get_variable_scope().reuse_variables()
        output, state = cell(inp, state)
        outputs.append(output)
        if loop_function is not None:
            prev = output
    return outputs
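# NOTE: two lines of web-page navigation text ("comment list",
# "table of contents") copied along with this code were turned into this comment.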