def get_rnn_init_state(x, cell):
    """Build an initial state for ``cell`` that is seeded from the tensor ``x``.

    Where ``x`` ends up depends on the kind of cell:

    * ``tf.contrib.rnn.MultiRNNCell``: only the first layer is seeded. Every
      other layer keeps the zero state returned by ``cell.zero_state``. If the
      first layer's state is a tuple, it becomes ``(zeros_like(x), x)``;
      otherwise it becomes ``x``.
    * A cell whose ``state_size`` is a 2-tuple (LSTM-style): returns
      ``(zeros_like(x), x)``.
    * Any other cell (assumed GRU-like, single-tensor state): returns ``x``.

    Tuple states are rebuilt with the same tuple type as the cell's own state
    (e.g. a namedtuple such as ``LSTMStateTuple``). A plain tuple in its place
    may not match the structure the cell itself returns.

    Args:
        x: Tensor of shape [batch, dim]. ``dim`` must match the cell's state
            size. For ``MultiRNNCell``, ``x``'s dtype is also used for the
            zero states of the remaining layers.
        cell: An RNN cell instance.

    Returns:
        A state structure suitable for use as the cell's initial state.

    Raises:
        ValueError: If ``cell.state_size`` is a tuple whose length is not 2.
    """

    def _pair_like(template):
        # Rebuild (zeros, x) with the tuple type of ``template``. Namedtuples
        # (e.g. LSTMStateTuple) expose ``_make``; plain tuples fall back to
        # ``tuple``, which matches the original plain-tuple result.
        make = getattr(type(template), '_make', tuple)
        return make([tf.zeros_like(x), x])

    if isinstance(cell, tf.contrib.rnn.MultiRNNCell):
        # Prefer the static batch size. Fall back to the dynamic one when the
        # leading dimension is unknown (e.g. a placeholder of shape [None, dim]).
        batch = x.get_shape()[0]
        batch = getattr(batch, 'value', batch)  # TF1 Dimension -> int / None
        if batch is None:
            batch = tf.shape(x)[0]
        # Zero states must share x's dtype, or they will not mix with the
        # seeded first layer.
        states = list(cell.zero_state(batch, dtype=x.dtype.base_dtype))
        first = states[0]
        states[0] = _pair_like(first) if isinstance(first, tuple) else x
        return tuple(states)

    if isinstance(cell.state_size, tuple):
        # LSTM-style cell: the state is a pair; x fills the second component.
        if len(cell.state_size) != 2:
            raise ValueError(
                'Expected a 2-component tuple state_size, got %r'
                % (cell.state_size,))
        return _pair_like(cell.state_size)

    # Otherwise assume a single-tensor state (e.g. a GRU cell).
    return x
评论列表
文章目录