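import tensorflow as tf


def compute_budget_loss(model, loss, updated_states, cost_per_sample):
    """
    Compute the budget loss: a penalty on the number of updated states
    (i.e. the number of input samples the model actually used).
    """
    # using_skip_rnn() is defined elsewhere in the original source (not shown here)
    if using_skip_rnn(model):
        # Sum the per-step cost over the time axis (1), then average over the batch (0)
        return tf.reduce_mean(tf.reduce_sum(cost_per_sample * updated_states, 1), 0)
    else:
        # Models without skipping pay no budget penalty: zeros shaped like the loss
        return tf.zeros(loss.get_shape())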
Source file: graph_definition.py
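In the Skip RNN branch, the budget loss is the sum of `cost_per_sample` over every step where the state was updated, taken per sequence and then averaged over the batch. In symbols, with λ = `cost_per_sample`, B sequences, and u_{b,t} the update indicator, the loss is (1/B) Σ_b Σ_t λ·u_{b,t}. The sketch below reproduces that arithmetic with NumPy so it runs without TensorFlow. The helper name `budget_loss_numpy`, the `use_skip_rnn` flag (standing in for `using_skip_rnn(model)`), the `loss_shape` parameter, and the assumption that `updated_states` is a binary array of shape `[batch, time, 1]` are all hypothetical, chosen for illustration rather than taken from the original source.

```python
import numpy as np


def budget_loss_numpy(updated_states, cost_per_sample, use_skip_rnn=True, loss_shape=()):
    """Hypothetical NumPy sketch of compute_budget_loss, for illustration only.

    Assumes updated_states has shape [batch, time, 1] with binary entries
    (1 = state updated at that step, 0 = step skipped).
    """
    if not use_skip_rnn:
        # Mirrors the zero penalty returned for models without skipping
        return np.zeros(loss_shape)
    # Cost paid at each step: shape [batch, time, 1]
    costs = cost_per_sample * updated_states
    # Total cost per sequence: sum over the time axis -> shape [batch, 1]
    per_sequence = np.sum(costs, axis=1)
    # Average over the batch axis -> shape [1]
    return np.mean(per_sequence, axis=0)


# Two sequences of length 4: the first updates 3 times, the second twice
u = np.array([[1, 0, 1, 1],
              [1, 1, 0, 0]], dtype=np.float32)[..., np.newaxis]
print(budget_loss_numpy(u, cost_per_sample=0.01))
```

With λ = 0.01 the per-sequence costs are 0.03 and 0.02, so the averaged penalty is about 0.025. Raising `cost_per_sample` pushes training toward fewer state updates.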