def _create_variables(self):
if self.input_type.dtype != 'float32':
raise TypeError('FC input dtype must be float32: %s' %
self.input_type.dtype)
if self.input_type.ndim != 1:
raise TypeError('FC input shape must be 1D: %s' %
str(self.input_type.shape))
self._bias = tf.get_variable(
'bias', self.output_type.shape, initializer=tf.constant_initializer(0))
self._weights = tf.get_variable(
'weights', [self.input_type.shape[0], self.output_type.shape[0]],
initializer=self._initializer)
if self._weight_norm:
self._scales = tf.get_variable(
'scales',
[self.output_type.shape[0]],
initializer=tf.constant_initializer(1.0))
评论列表
文章目录