tensorbox.py source code


Project: cancer. Author: yancz1989.
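```python
import tensorflow as tf

# Assumes a TF version that exposes tf.nn.rnn_cell (0.x or >= 1.2);
# the file's original import of rnn_cell is not shown in this snippet.
rnn_cell = tf.nn.rnn_cell


def build_lstm_inner(H, lstm_input):
  '''
  Build the LSTM decoder: unroll it for H['rnn_len'] steps, feeding the same
  lstm_input at every step, and return the list of per-step outputs.
  '''
  def make_cell():
    return rnn_cell.BasicLSTMCell(H['lstm_size'], forget_bias=0.0, state_is_tuple=False)

  # One cell object per layer: in TF 1.x, reusing a single cell ([cell] * n)
  # either shares one set of weights across layers or raises an error.
  if H['num_lstm_layers'] > 1:
    lstm = rnn_cell.MultiRNNCell(
        [make_cell() for _ in range(H['num_lstm_layers'])], state_is_tuple=False)
  else:
    lstm = make_cell()

  # One sequence per grid cell of every image in the batch
  batch_size = H['batch_size'] * H['grid_height'] * H['grid_width']
  # state_is_tuple=False: each layer's (c, h) is packed into one flat tensor
  state = tf.zeros([batch_size, lstm.state_size])

  outputs = []
  with tf.variable_scope('RNN', initializer=tf.random_uniform_initializer(-0.1, 0.1)):
    for time_step in range(H['rnn_len']):
      # After the first step, reuse the same LSTM weights
      if time_step > 0:
        tf.get_variable_scope().reuse_variables()
      output, state = lstm(lstm_input, state)
      outputs.append(output)
  return outputs
```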
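To make the computation concrete, here is a dependency-free Python sketch of what build_lstm_inner evaluates for a single row of its batch. Each of the batch_size rows (one per grid cell) evolves independently, so one row is enough. It assumes the standard BasicLSTMCell gate equations (pre-activations split as i, j, f, o; c' = c·σ(f + forget_bias) + σ(i)·tanh(j); h' = tanh(c')·σ(o)) and that, with state_is_tuple=False, each layer's state is the flat concatenation [c, h]. The helper names lstm_step and unroll, and all sizes in the usage part, are hypothetical and not part of TensorBox or TensorFlow.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, state, W, b, num_units, forget_bias=0.0):
    """One BasicLSTMCell-style step for one sequence, state_is_tuple=False.

    state is the flat list [c, h]; W has len(x) + num_units rows and
    4 * num_units columns; pre-activations are split as i, j, f, o (assumed).
    """
    c, h = state[:num_units], state[num_units:]
    xh = x + h  # concatenate the input and the previous output
    z = [sum(v * W[r][col] for r, v in enumerate(xh)) + b[col]
         for col in range(4 * num_units)]
    i, j, f, o = (z[k * num_units:(k + 1) * num_units] for k in range(4))
    new_c = [c[u] * sigmoid(f[u] + forget_bias) + sigmoid(i[u]) * math.tanh(j[u])
             for u in range(num_units)]
    new_h = [math.tanh(new_c[u]) * sigmoid(o[u]) for u in range(num_units)]
    return new_h, new_c + new_h  # output, packed state [c, h]


def unroll(x, layers, num_units, rnn_len, forget_bias=0.0):
    """Feed the same x at every step; each layer's (W, b) is reused at every step."""
    width = 2 * num_units  # per-layer state size when state_is_tuple=False
    state = [0.0] * (width * len(layers))  # zero initial state, like tf.zeros
    outputs = []
    for _ in range(rnn_len):
        inp, new_state = x, []
        for k, (W, b) in enumerate(layers):
            inp, s = lstm_step(inp, state[k * width:(k + 1) * width], W, b,
                               num_units, forget_bias)
            new_state += s
        state = new_state
        outputs.append(inp)  # output of the top layer at this step
    return outputs


# Hypothetical sizes: lstm_size=5, two layers, rnn_len=3, 8-dim input features
random.seed(0)
num_units, num_layers, rnn_len, in_dim = 5, 2, 3, 8
layers, d = [], in_dim
for _ in range(num_layers):
    # Weights ~ U(-0.1, 0.1) like the variable_scope initializer; zero biases
    W = [[random.uniform(-0.1, 0.1) for _ in range(4 * num_units)]
         for _ in range(d + num_units)]
    layers.append((W, [0.0] * (4 * num_units)))
    d = num_units  # deeper layers read the previous layer's output
x = [random.gauss(0.0, 1.0) for _ in range(in_dim)]
outs = unroll(x, layers, num_units, rnn_len)
print(len(outs), len(outs[0]))  # 3 5
```

The sketch mirrors the two kinds of weight handling in the TensorFlow version: reuse_variables() makes every time step share one set of weights, while MultiRNNCell gives each layer its own. Note that forget_bias=0.0 matches the snippet; BasicLSTMCell's default is 1.0.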