bluenet.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def run_unary_modules_sample(modules, cur, hparams, k):
  """Run modules, sampling k."""
  selection_weights = create_selection_weights(
      "selection", ("softmax_topk", k),
      shape=[len(modules)],
      inv_t=100.0 * common_layers.inverse_exp_decay(
          hparams.anneal_until, min_value=0.01))
  all_res = [
      tf.cond(
          tf.less(selection_weights.normalized[n], 1e-6),
          lambda: tf.zeros_like(cur),
          lambda i=n: modules[i](cur, hparams)) for n in xrange(len(modules))
  ]
  all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
  res = all_res * tf.reshape(selection_weights.normalized, [-1, 1, 1, 1, 1])
  return tf.reduce_sum(res, axis=0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号