def run_unary_modules_sample(modules, cur, hparams, k):
    """Run modules, sampling k, and return their weighted sum.

    Selection weights over ``modules`` are built with
    ``create_selection_weights("selection", ("softmax_topk", k), ...)``.
    Every module whose normalized weight is below 1e-6 is skipped: its slot is
    filled with ``tf.zeros_like(cur)`` and the module is not executed. The
    remaining modules are called as ``module(cur, hparams)``. The per-module
    results are stacked along a new leading axis, scaled by their normalized
    weights and summed over that axis.

    Args:
      modules: non-empty sequence of callables invoked as
        ``module(cur, hparams)``. Presumably each must return a tensor with the
        same shape and dtype as ``cur`` so the results can be stacked with the
        ``zeros_like(cur)`` placeholders -- confirm against callers.
      cur: input tensor passed to every selected module.
      hparams: hyperparameter object. ``hparams.anneal_until`` is read to build
        the selection temperature, and the object is forwarded to each module.
      k: forwarded to ``create_selection_weights`` inside the
        ``("softmax_topk", k)`` spec (presumably the number of modules kept;
        verify in ``create_selection_weights``).

    Returns:
      A tensor equal to ``sum_i normalized_weight[i] * result[i]``, where
      ``result[i]`` is either the module output or zeros.

    Raises:
      ValueError: if ``modules`` is empty.
    """
    num_modules = len(modules)
    if num_modules == 0:
        raise ValueError("run_unary_modules_sample needs at least one module.")

    selection_weights = create_selection_weights(
        "selection", ("softmax_topk", k),
        shape=[num_modules],
        # NOTE(review): presumably an inverse temperature annealed over
        # hparams.anneal_until steps (min decay value 0.01, scaled by 100);
        # confirm in common_layers.inverse_exp_decay.
        inv_t=100.0 * common_layers.inverse_exp_decay(
            hparams.anneal_until, min_value=0.01))
    weights = selection_weights.normalized

    def _run_or_skip(index, module):
        # Skip the (possibly expensive) module when its weight is negligible;
        # it would contribute ~0 to the weighted sum anyway. `index` and
        # `module` are arguments of this helper, so each cond's lambdas capture
        # their own module (no late-binding closure problem).
        return tf.cond(
            tf.less(weights[index], 1e-6),
            lambda: tf.zeros_like(cur),
            lambda: module(cur, hparams))

    # Shape: [num_modules] + shape(cur).
    all_res = tf.stack(
        [_run_or_skip(i, module) for i, module in enumerate(modules)], axis=0)

    # Reshape the weights to [num_modules, 1, ..., 1] so they broadcast
    # against all_res for any rank of `cur` (rank 4 gives [-1, 1, 1, 1, 1]).
    ndims = cur.get_shape().ndims
    if ndims is not None:
        broadcast_shape = [-1] + [1] * ndims
    else:
        # Static rank unknown: build the broadcast shape at graph run time.
        broadcast_shape = tf.concat(
            [[-1], tf.ones([tf.rank(cur)], dtype=tf.int32)], axis=0)
    res = all_res * tf.reshape(weights, broadcast_shape)
    return tf.reduce_sum(res, axis=0)
评论列表
文章目录