def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell, reshaping its input, output or state.

    Which tensor gets reshaped is selected by ``self._apply_to``:

    * ``'input'``  -- ``inputs`` is reshaped before being passed to the cell.
    * ``'output'`` -- the first value returned by the cell is reshaped.
    * ``'state'``  -- the second value returned by the cell is reshaped.

    If ``self._shape == -1`` the tensor is flattened with ``slim.flatten``;
    otherwise it is reshaped to ``[batch_size] + list(self._shape)``, where
    ``batch_size`` is the dynamic leading dimension of ``inputs``.

    Args:
      inputs: Input tensor for this step; ``tf.shape(inputs)[0]`` is used as
        the batch size for every reshape.
      state: State passed unchanged to the wrapped cell.
      scope: Accepted for RNNCell call-signature compatibility only.
        NOTE(review): it is not forwarded to ``self._cell`` -- confirm this
        is intended.

    Returns:
      For ``'input'``, whatever ``self._cell`` returns. For ``'output'`` and
      ``'state'``, the cell's ``(output, state)`` pair with the selected
      element reshaped.

    Raises:
      ValueError: If ``self._apply_to`` is not one of the values above.
    """
    batch_size = tf.shape(inputs)[0]

    def reshape(tensor):
        # Apply the configured reshape to a single tensor.
        if self._shape == -1:
            # -1 is the sentinel for "flatten" via slim.flatten.
            return slim.flatten(tensor)
        # list() lets _shape be a tuple as well as a list; `list + tuple`
        # would raise TypeError.
        return tf.reshape(tensor, [batch_size] + list(self._shape))

    if self._apply_to == 'input':
        return self._cell(reshape(inputs), state)
    elif self._apply_to == 'output':
        output, res_state = self._cell(inputs, state)
        return reshape(output), res_state
    elif self._apply_to == 'state':
        output, res_state = self._cell(inputs, state)
        return output, reshape(res_state)
    else:
        raise ValueError('Unknown apply_to: "{}"'.format(self._apply_to))
评论列表
文章目录