model_utils.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:sidenet 作者: shashiongithub 项目源码 文件源码
def simple_rnn(rnn_input, initial_state=None):
  """Implements Simple RNN
  Args:
  rnn_input: List of tensors of sizes [-1, sentembed_size]
  Returns:
  encoder_outputs, encoder_state
  """     
  # Setup cell
  cell_enc = get_lstm_cell()

  # Setup RNNs
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  rnn_outputs, rnn_state = tf.nn.rnn(cell_enc, rnn_input, dtype=dtype, initial_state=initial_state)
  # print(rnn_outputs)
  # print(rnn_state)

  return rnn_outputs, rnn_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号