def match(qstates, pstates, d, dropout=None):
    """Match-LSTM layer: attend over the question at each passage position.

    qstates : time-major question encodings, [qlen, batch_size, 2*d]
    pstates : time-major passage encodings,  [plen, batch_size, 2*d]
    """
    # infer question length and batch size, then passage length
    qlen, batch_size, _ = tf.unstack(tf.shape(qstates))
    plen = tf.shape(pstates)[0]
    # match-LSTM cell; num_units = 2*d to match the encoder state width
    cell = rcell('lstm', num_units=2*d, dropout=dropout)
    # TensorArrays to collect per-step states (plen+1 entries, including the
    # initial state) and per-step outputs (plen entries)
    states = tf.TensorArray(dtype=tf.float32, size=plen+1, name='states',
                            clear_after_read=False)
    outputs = tf.TensorArray(dtype=tf.float32, size=plen, name='outputs',
                             clear_after_read=False)
    # set the initial (c, h) state, stacked into one [2, batch_size, 2*d]
    # tensor, since a TensorArray stores plain tensors rather than state tuples
    init_state = cell.zero_state(batch_size, tf.float32)
    states = states.write(0, tf.stack(init_state))
    def mlstm_step(i, states, outputs):
        # rebuild the LSTM state tuple from the stacked tensor written last step
        prev_c, prev_h = tf.unstack(states.read(i))
        prev_state_tuple = tf.contrib.rnn.LSTMStateTuple(prev_c, prev_h)
        # attention-weighted representation of the question, conditioned on
        # the current passage state and the previous cell state
        ci = attention(qstates, pstates[i], prev_c, d)
        # concatenate the passage state with its attention context
        input_ = tf.concat([pstates[i], ci], axis=-1)
        output, state = cell(input_, prev_state_tuple)
        # save output and state (the state tuple stacked back into one tensor)
        states = states.write(i+1, tf.stack(state))
        outputs = outputs.write(i, output)
        return (i+1, states, outputs)
    # run mlstm_step over every passage position
    cond = lambda i, states, outputs : tf.less(i, plen)
    _, fstates, foutputs = tf.while_loop(cond, mlstm_step,
                                         [tf.constant(0), states, outputs])
    # outputs : [plen, batch_size, 2*d]; drop the initial state and project the
    # stacked (c, h) states (4*d features) down to d
    return foutputs.stack(), project_lstm_states(fstates.stack()[1:], 4*d, d)
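# ---------------------------------------------------------------------------
# The function above relies on three helpers defined elsewhere in the post:
# rcell, attention and project_lstm_states. For reference, here is a minimal
# sketch of the latter two under ASSUMED names, shapes and variable names
# (additive match-LSTM attention in the style of Wang & Jiang); it illustrates
# the expected interfaces and is not the post's actual definitions.
# ---------------------------------------------------------------------------
import tensorflow as tf  # TF 1.x

def attention(qstates, pstate, prev_c, d):
    # qstates: [qlen, batch_size, 2*d]; pstate, prev_c: [batch_size, 2*d]
    Wq = tf.get_variable('Wq', shape=[2*d, d], dtype=tf.float32)
    Wp = tf.get_variable('Wp', shape=[2*d, d], dtype=tf.float32)
    Wr = tf.get_variable('Wr', shape=[2*d, d], dtype=tf.float32)
    w = tf.get_variable('w', shape=[d, 1], dtype=tf.float32)
    # additive attention scores; the passage and recurrent terms broadcast
    # over the question length
    G = tf.tanh(tf.tensordot(qstates, Wq, axes=1)
                + tf.matmul(pstate, Wp) + tf.matmul(prev_c, Wr))  # [qlen, B, d]
    scores = tf.squeeze(tf.tensordot(G, w, axes=1), axis=-1)      # [qlen, B]
    alpha = tf.nn.softmax(scores, axis=0)                         # over qlen
    # attention-weighted sum of question states -> [batch_size, 2*d]
    return tf.reduce_sum(qstates * tf.expand_dims(alpha, -1), axis=0)

def project_lstm_states(states, din, dout):
    # states: [plen, 2, batch_size, 2*d]; concatenating (c, h) gives din = 4*d
    Wo = tf.get_variable('Wo', shape=[din, dout], dtype=tf.float32)
    ch = tf.concat(tf.unstack(states, axis=1), axis=-1)   # [plen, B, 4*d]
    return tf.tensordot(ch, Wo, axes=1)                   # [plen, B, dout]

# Hypothetical wiring: with d = 150 and BiLSTM encoders producing
# qstates [qlen, B, 300] and pstates [plen, B, 300],
#   outputs, projected = match(qstates, pstates, d=150)
# yields outputs [plen, B, 300] and projected states [plen, B, 150].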