def naive_decoder(cell, enc_states, targets, start_token, end_token,
                  feed_previous=True, training=True, scope='naive_decoder.0'):
    """Naive step-by-step RNN decoder built on ``tf.while_loop``.

    Args:
        cell: an RNN cell; if ``is_lstm(cell)``, its state is handled as an
            ``LSTMStateTuple`` stacked into a single tensor for storage.
        enc_states: time-major encoder states; ``enc_states[-1]`` seeds the
            decoder state and ``tf.shape(enc_states)[0]`` fixes the step count.
        targets: batch-major ground-truth targets, rank 3
            (batch, time, dim) — assumed from the ``[1, 0, 2]`` transpose;
            TODO confirm with callers.
        start_token: tensor fed as the input at step 0.
        end_token: token id compared against ``argmax`` of the last output
            to stop decoding in inference mode.
        feed_previous: if True, each step consumes the previous step's raw
            output; otherwise it is teacher-forced with ``targets``.
        training: if True, loop exactly ``timesteps`` steps; otherwise loop
            until every sequence in the batch emits ``end_token``.
        scope: variable scope name for the decoder graph.

    Returns:
        Time-major stacked decoder outputs with the seed slot dropped
        (``timesteps`` entries).
    """
    init_state = enc_states[-1]
    timesteps = tf.shape(enc_states)[0]
    # Decoding iterates over time, so flip targets to time-major:
    # (batch, time, dim) -> (time, batch, dim).
    targets_tm = tf.transpose(targets, [1, 0, 2])
    # Slot 0 holds the seed (initial state / start token); slots 1..T hold
    # per-step results — hence size timesteps + 1. clear_after_read=False
    # because the stop condition re-reads the last-written output.
    states = tf.TensorArray(dtype=tf.float32, size=timesteps + 1,
                            name='states', clear_after_read=False)
    outputs = tf.TensorArray(dtype=tf.float32, size=timesteps + 1,
                             name='outputs', clear_after_read=False)

    def step(i, states, outputs):
        """Run the cell for one timestep; write results into slot i + 1."""
        state_prev = states.read(i)
        if is_lstm(cell):
            # The TensorArray stores the LSTM state stacked as one tensor
            # (presumably (2, batch, units) — TODO confirm); rebuild the
            # LSTMStateTuple the cell expects.
            c, h = tf.unstack(state_prev)
            state_prev = rnn.LSTMStateTuple(c, h)
        # feed_previous: consume the model's own previous output;
        # otherwise teacher-force with the ground-truth target.
        input_ = outputs.read(i) if feed_previous else targets_tm[i]
        output, state = cell(input_, state_prev)
        if is_lstm(cell):
            # BUG FIX: an LSTMStateTuple cannot be written into a float32
            # TensorArray as-is — stack (c, h) back into a single tensor,
            # mirroring the tf.unstack on the read path above.
            state = tf.stack(state)
        states = states.write(i + 1, state)
        outputs = outputs.write(i + 1, output)
        return tf.add(i, 1), states, outputs

    with tf.variable_scope(scope):
        # Seed slot 0 with the encoder's final state and the start token.
        states = states.write(0, init_state)
        outputs = outputs.write(0, start_token)
        i = tf.constant(0)
        if training:
            # Training: run for exactly `timesteps` steps.
            cond = lambda x, y, z: tf.less(x, timesteps)
        else:
            # Inference: keep looping while some sequence in the batch has
            # not yet emitted end_token (argmax of the last-written output).
            # NOTE(review): no maximum_iterations bound — a batch that never
            # produces end_token loops forever; consider bounding.
            cond = lambda x, y, z: tf.reduce_all(
                tf.not_equal(tf.argmax(z.read(x), axis=-1), end_token))
        _, fstates, foutputs = tf.while_loop(cond, step, [i, states, outputs])
        # fstates is currently unused by callers of this return value.
        # Drop slot 0 (the start-token seed); result is time-major.
        return foutputs.stack()[1:]
# (stray scraped-page text, not code: "评论列表" = comment list,
#  "文章目录" = article table of contents)