dbpedia.py source code

python

Project: GAN  Author: ilblackdragon
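# Assumed imports for TensorFlow 1.x; the original snippet does not show them.
import tensorflow as tf
from tensorflow.contrib import layers

# NOTE: `sequence` (a project-local helper module) and `FLAGS` (command-line
# flags, including max_doc_length) are defined elsewhere in the project.


def autoencoder_model(feature, target, mode, params):
  """Sequence autoencoder model: reconstructs the input token ids."""
  vocab_size = params.get('vocab_size')
  embed_dim = params.get('embed_dim')

  # Name the first input sequence so it can be fetched from the graph by name.
  tf.identity(feature[0], name='feature')
  embed_feature = sequence.embed_features(
      feature, vocab_size=vocab_size, embed_dim=embed_dim)
  output, _ = sequence.sequence_autoencoder_discriminator(
      embed_feature, length=FLAGS.max_doc_length, hidden_size=embed_dim)
  logits, predictions = sequence.outbed_generated(output)

  # Loss and training. TensorFlow 1.x requires keyword arguments here;
  # passing (logits, feature) positionally raises an error.
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=feature, logits=logits)
  # Sum the per-token losses over each sequence, then average over the batch.
  loss = tf.reduce_mean(tf.reduce_sum(loss, axis=1))
  train_op = layers.optimize_loss(
      loss, tf.train.get_global_step(),
      learning_rate=params['learning_rate'],
      optimizer=params.get('optimizer', 'Adam'))
  return predictions, loss, train_op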
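
The loss in the function above is a per-token sparse softmax cross-entropy, summed along each sequence and averaged over the batch. The following is a minimal pure-Python sketch of that arithmetic, not the project's code: the helper names `sparse_softmax_cross_entropy` and `sequence_loss` and the toy data are hypothetical, and it assumes logits shaped [batch, time, vocab] with integer labels shaped [batch, time].

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """Cross-entropy for one token: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def sequence_loss(batch_logits, batch_labels):
    """Sum token losses over each sequence, then average over the batch."""
    per_sequence = [
        sum(sparse_softmax_cross_entropy(token_logits, token_label)
            for token_logits, token_label in zip(seq_logits, seq_labels))
        for seq_logits, seq_labels in zip(batch_logits, batch_labels)
    ]
    return sum(per_sequence) / len(per_sequence)


# Toy batch (hypothetical): 2 sequences, 2 time steps, vocabulary of 3 tokens.
# Uniform logits mean every token is predicted with probability 1/3.
toy_logits = [
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
]
toy_labels = [[0, 2], [1, 1]]
toy_loss = sequence_loss(toy_logits, toy_labels)
```

With uniform logits each token contributes log(3), so each two-step sequence contributes 2 * log(3) and the batch mean is the same value. Because the sum runs over time steps, longer sequences yield larger losses under this scheme.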