def mask_finished(finished, now_, prev_):
    """Keep the previous state for finished entries, the new state otherwise.

    Every tensor in the state is blended as
    ``(1 - mask) * now + mask * prev``, where ``mask`` is ``finished`` cast
    to float32 with a new axis inserted at position 1.

    Args:
        finished: Boolean/numeric tensor flagging finished entries.
            Assumed to be 1-D ``[batch]`` so that the mask broadcasts
            against ``[batch, units]`` states -- TODO confirm with callers.
        now_: Newly computed state: a tensor, or a (possibly nested) tuple
            of tensors / namedtuple states such as ``LSTMStateTuple``.
        prev_: State from the previous step, with the same structure as
            ``now_``.

    Returns:
        A state with the same structure as the inputs. Tuple inputs yield a
        tuple; namedtuple inputs (e.g. ``LSTMStateTuple``) keep their type,
        including when the whole state is a single namedtuple.
    """
    # tf.cast is the supported replacement for the deprecated tf.to_float.
    # NOTE(review): states are assumed to be float32-compatible -- confirm.
    mask = tf.expand_dims(tf.cast(finished, tf.float32), 1)

    def _blend(new_state, old_state):
        """Blend one state element, recursing into tuple structures."""
        if not isinstance(old_state, tuple):
            return (1. - mask) * new_state + mask * old_state

        merged = [_blend(ns, s) for ns, s in zip(new_state, old_state)]
        # Namedtuple states (LSTMStateTuple is one) are rebuilt with their
        # own type so callers still see `.c` / `.h` instead of a bare tuple.
        if isinstance(new_state, tuple) and hasattr(new_state, "_fields"):
            return type(new_state)(*merged)
        return tuple(merged)

    return _blend(now_, prev_)
# (web-page residue) Comment list
# (web-page residue) Table of contents