import tensorflow as tf
from tensorflow.contrib import rnn


def uni_net_dynamic(cell, inputs, proj_dim=None, init_state=None, scope='uni_net_d0'):
    # transpose batch-major inputs [batch, time, depth]
    #  to time-major [time, batch, depth]
    inputs_tm = tf.transpose(inputs, [1, 0, 2], name='inputs_tm')
    # infer timesteps and batch_size from the (possibly dynamic) shape
    timesteps, batch_size, _ = tf.unstack(tf.shape(inputs_tm))
    # fall back to the cell's zero state when no initial state is provided;
    #  an explicit `is None` test is required because a Tensor or
    #  LSTMStateTuple cannot be evaluated as a Python boolean
    if init_state is None:
        init_state = cell.zero_state(batch_size, tf.float32)
    # `states` holds the initial state plus one state per step (timesteps+1);
    #  clear_after_read=False so entries read inside the loop can still be
    #  stacked afterwards
    states = tf.TensorArray(dtype=tf.float32, size=timesteps + 1, name='states',
                            clear_after_read=False)
    outputs = tf.TensorArray(dtype=tf.float32, size=timesteps, name='outputs',
                             clear_after_read=False)

    def step(i, states, outputs):
        # run one step of the recurrence
        # read the previous state from the TensorArray (states)
        state_prev = states.read(i)
        if is_lstm(cell):
            # previous state <tensor> -> <LSTMStateTuple>
            c, h = tf.unstack(state_prev)
            state_prev = rnn.LSTMStateTuple(c, h)
        output, state = cell(inputs_tm[i], state_prev)
        if is_lstm(cell):
            # <LSTMStateTuple> -> stacked <tensor>, so it fits in the TensorArray
            state = tf.stack(state)
        # add state, output to the TensorArrays
        states = states.write(i + 1, state)
        outputs = outputs.write(i, output)
        i = tf.add(i, 1)
        return i, states, outputs

    with tf.variable_scope(scope):
        # write the initial state (stacked, if it is an LSTMStateTuple)
        states = states.write(0, tf.stack(init_state) if is_lstm(cell) else init_state)
        i = tf.constant(0)
        # stopping condition: loop until every timestep has been consumed
        c = lambda x, y, z: tf.less(x, timesteps)
        # body
        b = lambda x, y, z: step(x, y, z)
        # execution
        _, fstates, foutputs = tf.while_loop(c, b, [i, states, outputs])

        # if LSTM, project the stacked (c, h) states from d1 = 2*state_size
        #  down to d2 (proj_dim if given, else back to state_size)
        if is_lstm(cell):
            d1 = 2 * cell.state_size.c
            d2 = proj_dim if proj_dim else d1 // 2
            return foutputs.stack(), project_lstm_states(fstates.stack()[1:], d1, d2)
        # drop the initial state at index 0; return time-major outputs and states
        return foutputs.stack(), fstates.stack()[1:]
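
# ---------------------------------------------------------------------------
# The function above relies on two helpers, `is_lstm` and `project_lstm_states`,
#  defined elsewhere in the project. What follows is a minimal sketch of what
#  they might look like, NOT the author's actual definitions: `is_lstm` is
#  assumed to test the cell's state type, and `project_lstm_states` is assumed
#  to apply a learned linear projection to the stacked (c, h) pairs.

def is_lstm(cell):
    # an LSTM cell's state is an LSTMStateTuple rather than a flat tensor
    return isinstance(cell.state_size, rnn.LSTMStateTuple)


def project_lstm_states(states, d1, d2):
    # states : [time, 2, batch, hidden] -- stacked (c, h) pairs per timestep
    # bring the (c, h) axis next to the features -> [time, batch, 2, hidden]
    states = tf.transpose(states, [0, 2, 1, 3])
    shape = tf.shape(states)
    # merge (c, h) into a single feature axis of size d1 = 2*hidden
    states = tf.reshape(states, tf.stack([shape[0], shape[1], d1]))
    # learned linear projection d1 -> d2, shared across time and batch
    W = tf.get_variable('state_proj_W', [d1, d2], dtype=tf.float32)
    projected = tf.matmul(tf.reshape(states, [-1, d1]), W)
    return tf.reshape(projected, tf.stack([shape[0], shape[1], d2]))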
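
# ---------------------------------------------------------------------------
# Usage sketch with made-up sizes; the placeholder name, scope name, and
#  dimensions below are illustrative assumptions, not taken from the original.
import numpy as np

if __name__ == '__main__':
    batch_size, timesteps, depth, hidden = 32, 10, 8, 64
    inputs = tf.placeholder(tf.float32, [None, None, depth], name='inputs')
    cell = rnn.BasicLSTMCell(hidden, state_is_tuple=True)
    outputs, states = uni_net_dynamic(cell, inputs, scope='encoder')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x = np.random.randn(batch_size, timesteps, depth).astype(np.float32)
        out, st = sess.run([outputs, states], feed_dict={inputs: x})
        print(out.shape)  # (10, 32, 64) -- time-major outputs
        print(st.shape)   # (10, 32, 64) -- LSTM states projected back to `hidden`

# Note: this hand-rolled tf.while_loop returns the state at *every* timestep,
#  which is the main reason to prefer it over tf.nn.dynamic_rnn, whose API
#  only returns the final state alongside the outputs.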