def set_input_shape(self, input_shape):
    """Create the convolution parameters and record the layer's output shape.

    Args:
        input_shape: 4-element sequence unpacked as
            ``(batch_size, rows, cols, input_channels)``. ``rows`` and
            ``cols`` are not read directly; they only affect the result
            through the dummy forward pass below. ``batch_size`` is passed
            through unchanged into ``self.output_shape`` (it is not
            validated here, so presumably ``None`` is allowed for a
            variable batch -- confirm against callers).

    Side effects:
        Sets ``self.kernels``, a ``tf.Variable`` of shape
        ``tuple(self.kernel_shape) + (input_channels, self.output_channels)``.
        Sets ``self.b``, a float32 ``tf.Variable`` of zeros with shape
        ``(self.output_channels,)``. Sets ``self.output_shape``, a tuple
        whose first entry is ``batch_size`` and whose other entries come
        from running ``self.fprop`` on an all-zero batch of one example.

    Raises:
        ValueError: if ``input_shape`` does not have exactly four entries
            (raised by the tuple unpacking).
        AssertionError: if the assembled kernel shape is not 4-D or
            contains a non-``int`` entry.
    """
    batch_size, _, _, input_channels = input_shape

    # self.kernel_shape must supply exactly two leading extents so that,
    # with the input/output channel counts appended, the kernel is 4-D.
    kernel_shape = tuple(self.kernel_shape) + (input_channels,
                                               self.output_channels)
    assert len(kernel_shape) == 4
    assert all(isinstance(e, int) for e in kernel_shape), kernel_shape

    # Draw standard-normal weights, then rescale each output-channel filter
    # (the sum runs over every axis except the last) to roughly unit L2
    # norm. The 1e-7 keeps the square root away from zero.
    init = tf.random_normal(kernel_shape, dtype=tf.float32)
    init = init / tf.sqrt(1e-7 + tf.reduce_sum(tf.square(init),
                                               axis=(0, 1, 2)))
    self.kernels = tf.Variable(init)
    self.b = tf.Variable(np.zeros((self.output_channels,), dtype=np.float32))

    # Infer the output shape by pushing a single all-zero example through
    # fprop. This has to happen after self.kernels / self.b are created,
    # since fprop presumably reads them.
    dummy_shape = list(input_shape)
    dummy_shape[0] = 1
    dummy_output = self.fprop(tf.zeros(dummy_shape))
    output_shape = [int(e) for e in dummy_output.get_shape()]
    # Report the caller's batch dimension rather than the dummy batch of 1,
    # so layers chained after this one see the batch size that came in.
    output_shape[0] = batch_size
    self.output_shape = tuple(output_shape)
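# 评论列表 ("Comment list" -- scraped web-page residue, not code)
# 文章目录 ("Table of contents" -- scraped web-page residue, not code)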