def get_double_linear_controller(size, bias, input_keep_prob=1.0, is_train=None):
    """Build a controller closure that scores memory slots via ``double_linear_logits``.

    The returned function captures ``size``, ``bias``, ``input_keep_prob`` and
    ``is_train`` and forwards them unchanged to ``double_linear_logits``.

    :param size: forwarded to ``double_linear_logits`` as its size argument.
    :param bias: forwarded to ``double_linear_logits`` as its bias argument.
    :param input_keep_prob: forwarded to ``double_linear_logits``.
    :param is_train: forwarded to ``double_linear_logits``.
    :return: the ``double_linear_controller(inputs, state, memory)`` closure.
    """
    def double_linear_controller(inputs, state, memory):
        """
        :param inputs: [N, i]
        :param state: [N, d], or a tuple of tensors that are each tiled the same way
        :param memory: [N, M, m]
        :return: [N, M]
        """
        # Dynamic size of the second-to-last axis of memory (M for a rank-3 memory).
        num_slots = tf.shape(memory)[len(memory.get_shape()) - 2]

        def broadcast_over_slots(tensor):
            # [N, k] -> [N, 1, k] -> [N, M, k]: one copy per memory slot.
            return tf.tile(tf.expand_dims(tensor, 1), [1, num_slots, 1])

        tiled_inputs = broadcast_over_slots(inputs)
        # A tuple state (e.g. an LSTM state tuple) contributes one tiled tensor per part.
        state_parts = state if isinstance(state, tuple) else (state,)
        tiled_states = [broadcast_over_slots(part) for part in state_parts]

        # NOTE: axis-first call form is the pre-1.0 TensorFlow signature
        # (tf.concat(concat_dim, values)); TF >= 1.0 expects tf.concat(values, axis).
        joined = tf.concat(2, [tiled_inputs] + tiled_states + [memory])
        return double_linear_logits(joined, size, bias,
                                    input_keep_prob=input_keep_prob,
                                    is_train=is_train)

    return double_linear_controller
评论列表
文章目录