model.py source code

python

Project: GAN    Author: ilblackdragon
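import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in 2.x
from tensorflow.contrib import layers


def autoencoder_model(feature, target, mode, params):
  """Autoencodes features with given function.

  `target` and `mode` are unused; they are kept so the function matches
  the (features, targets, mode, params) model_fn signature.
  """
  # Required: a callable returning a pair whose first element is the
  # reconstruction of its input.
  autoencoder_fn = params['autoencoder_fn']
  # Optional hooks; both default to the identity.
  feature_processor = params.get('feature_processor', lambda f: f)
  generated_postprocess = params.get('generated_postprocess', lambda f: f)

  # Process features.
  feature = feature_processor(feature)

  # Auto-encode; only the first return value is used.
  generated, _ = autoencoder_fn(feature)

  # Loss: mean squared error between the input and its reconstruction.
  loss = tf.losses.mean_squared_error(labels=feature, predictions=generated)
  # Training op: minimizes the loss with the configured optimizer (Adam by default).
  train_op = layers.optimize_loss(
      loss, tf.train.get_global_step(),
      learning_rate=params['learning_rate'],
      optimizer=params.get('optimizer', 'Adam'))

  # Post-process the reconstruction and name it 'generated' in the graph.
  prediction = generated_postprocess(generated)
  prediction = tf.identity(prediction, name='generated')
  return prediction, loss, train_op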
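The data flow of `autoencoder_model` (optional pre-processing, auto-encoding, mean squared error against the input, optional post-processing) can be checked without TensorFlow. The following is a minimal NumPy sketch of the forward pass and loss only; it omits the training op. `autoencoder_model_np` and `make_linear_autoencoder` are hypothetical helpers written for illustration, and the linear, bias-free autoencoder is an assumption, not the project's model. `np.mean` over all elements is assumed to match the TensorFlow loss with its default weight of 1.0.

```python
import numpy as np


def autoencoder_model_np(feature, params):
    """Hypothetical NumPy mirror of autoencoder_model: forward pass and loss only."""
    autoencoder_fn = params['autoencoder_fn']
    feature_processor = params.get('feature_processor', lambda f: f)
    generated_postprocess = params.get('generated_postprocess', lambda f: f)

    feature = feature_processor(feature)
    generated, _ = autoencoder_fn(feature)  # second value is ignored, as in the original
    loss = np.mean((feature - generated) ** 2)  # MSE averaged over every element
    prediction = generated_postprocess(generated)
    return prediction, loss


def make_linear_autoencoder(w_enc, w_dec):
    """Hypothetical autoencoder_fn: a linear encoder and decoder without biases."""
    def autoencoder_fn(x):
        code = x @ w_enc
        return code @ w_dec, code
    return autoencoder_fn


if __name__ == "__main__":
    features = np.array([[1.0, 2.0], [3.0, 4.0]])
    # Keep only the first coordinate: the second is reconstructed as 0.
    project = make_linear_autoencoder(np.array([[1.0], [0.0]]), np.array([[1.0, 0.0]]))
    prediction, loss = autoencoder_model_np(features, {'autoencoder_fn': project})
    print(prediction)
    print(loss)  # (2**2 + 4**2) / 4 = 5.0
```

With a projection onto the first coordinate, the reconstruction error comes only from the dropped second column. Passing `feature_processor` and `generated_postprocess` in `params` scales the loss and the prediction independently, as in the TensorFlow version.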