def get_linear_controller(bias, input_keep_prob=1.0, is_train=None):
    """Return a controller function that scores memory slots with `linear`.

    The arguments are captured by the returned closure and forwarded
    unchanged to every `linear(...)` call it makes.

    Args:
        bias: Passed to `linear` as its third positional argument.
        input_keep_prob: Passed to `linear` as `input_keep_prob`.
        is_train: Passed to `linear` as `is_train`.

    Returns:
        A function `linear_controller(inputs, state, memory)`.
    """

    def linear_controller(inputs, state, memory):
        """Combine `inputs`, `state` and `memory` and feed them to `linear`.

        Args:
            inputs: Tensor that gets a new axis 1 and is repeated along it
                once per memory slot; the three tile multiples imply it is
                rank 2 (presumably [N, d] -- confirm against callers).
            state: Either a single tensor or a tuple of tensors; each one is
                expanded and tiled exactly like `inputs`.
            memory: Tensor concatenated with the tiled tensors along axis 2
                (presumably [N, M, d] -- confirm against callers).

        Returns:
            Whatever `linear(..., 1, bias, squeeze=True, ...)` returns for
            the concatenated tensor.
        """
        rank = len(memory.get_shape())
        # Dynamic size of the second-to-last axis of `memory`.
        _memory_size = tf.shape(memory)[rank - 2]

        def tile_over_memory(tensor):
            # Insert axis 1 and repeat the tensor `_memory_size` times on it.
            return tf.tile(tf.expand_dims(tensor, 1), [1, _memory_size, 1])

        # Treat a lone state tensor as a one-element tuple so both cases
        # share the same code path.
        states = state if isinstance(state, tuple) else (state,)

        tiled_inputs = tile_over_memory(inputs)
        tiled_states = [tile_over_memory(each) for each in states]

        # [N, M, d]
        # NOTE(review): `tf.concat(axis, values)` is the pre-1.0 TensorFlow
        # argument order; newer releases expect `tf.concat(values, axis)`.
        # Kept as-is to match the TensorFlow version this file targets.
        in_ = tf.concat(2, [tiled_inputs] + tiled_states + [memory])
        out = linear(in_, 1, bias, squeeze=True,
                     input_keep_prob=input_keep_prob, is_train=is_train)
        return out

    return linear_controller
评论列表
文章目录