rnn_test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tensorflow-deep-qa 作者: shuishen112 项目源码 文件源码
def char_rnn_model(features, target):
  """Character level recurrent neural network model to predict classes."""
  target = tf.one_hot(target, 15, 1, 0)
  byte_list = tf.one_hot(features, 256, 1, 0)
  byte_list = tf.unstack(byte_list, axis=1)

  cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
  _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32)

  logits = tf.contrib.layers.fully_connected(encoding, 15, activation_fn=None)
  loss = tf.contrib.losses.softmax_cross_entropy(logits, target)

  train_op = tf.contrib.layers.optimize_loss(
      loss,
      tf.contrib.framework.get_global_step(),
      optimizer='Adam',
      learning_rate=0.01)

  return ({
      'class': tf.argmax(logits, 1),
      'prob': tf.nn.softmax(logits)
  }, loss, train_op)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号