def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell, then project its output with the layer in ``self._spec``.

    ``self._cell`` is called first, outside the projection variable scope.
    Its output is then passed through one projection layer, selected by
    ``self._spec['name']``:

    * ``'fc'``     -- ``slim.fully_connected`` with ``self._spec['size']``
      output units.
    * ``'t_conv'`` -- ``slim.layers.conv2d_transpose`` with
      ``self._spec['size']`` filters, ``self._spec['kernel']`` and
      ``self._spec['stride']``.
    * ``'r_conv'`` -- nearest-neighbour resize that multiplies the output's
      dimensions 1 and 2 by ``self._spec['stride'][0]`` and
      ``self._spec['stride'][1]``, followed by ``slim.layers.conv2d`` with
      ``self._spec['size']`` filters and ``self._spec['kernel']``.
      (Presumably dims 1/2 are height/width of an NHWC tensor -- confirm.)

    Every projection is built with ``activation_fn=None`` (linear output).

    Args:
        inputs: Passed unchanged to the wrapped cell.
        state: Passed unchanged to the wrapped cell.
        scope: Optional variable-scope name for the projection variables;
            falls back to ``self._name`` when falsy.

    Returns:
        A ``(projected, res_state)`` tuple: the projected output and the
        state returned by the wrapped cell.

    Raises:
        ValueError: If ``self._spec['name']`` is not ``'fc'``, ``'t_conv'``
            or ``'r_conv'``.
    """
    output, res_state = self._cell(inputs, state)
    spec = self._spec
    layer_name = spec['name']

    with tf.variable_scope(scope or self._name):
        if layer_name == 'fc':
            projected = slim.fully_connected(
                output, spec['size'], activation_fn=None)
        elif layer_name == 't_conv':
            projected = slim.layers.conv2d_transpose(
                output, spec['size'], spec['kernel'], spec['stride'],
                activation_fn=None)
        elif layer_name == 'r_conv':
            def _scaled_dim(axis, factor):
                # The static shape entry is a `Dimension` under TF1 shape
                # behaviour, or a plain int/None under TF2 shape behaviour.
                # Read it either way.
                dim = output.get_shape()[axis]
                size = getattr(dim, 'value', dim)
                if size is None:
                    # The static size is unknown. Use the runtime size instead
                    # of failing on `factor * None`.
                    size = tf.shape(output)[axis]
                return factor * size

            new_size = (_scaled_dim(1, spec['stride'][0]),
                        _scaled_dim(2, spec['stride'][1]))
            resized = tf.image.resize_images(
                output, new_size,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
            projected = slim.layers.conv2d(
                resized, spec['size'], spec['kernel'], activation_fn=None)
        else:
            raise ValueError('Unknown layer name "{}"'.format(self._spec['name']))
    return projected, res_state
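# Comment list  (leftover web-page navigation label, not code)
# Table of contents  (leftover web-page navigation label, not code)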