model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:seqgan-text-tensorflow 作者: codekansas 项目源码 文件源码
def get_weight(self, name, shape,
                   init='glorot',
                   device='gpu',
                   weight_val=None,
                   trainable=True):
        """Creates a new weight variable and registers it on the model.

        The variable is created with `tf.get_variable` and appended to
        `self._weights`.

        Args:
            name: str, the name of the variable.
            shape: tuple of ints, the shape of the variable.
            init: str, the type of initializer to use when `weight_val` is
                not given. One of 'normal', 'uniform', 'glorot', 'eye' or
                'zero' (case-insensitive).
            device: str, 'cpu' or 'gpu' (case-insensitive). Overridden to
                CPU when `self._only_cpu` is set.
            weight_val: Numpy array to use as the initial weights. When
                provided, `init` is ignored.
            trainable: bool, whether or not this weight is trainable.

        Returns:
            a TF variable with shape `shape`.

        Raises:
            ValueError: if `init` or `device` is not a recognized option, or
                if `init` is 'eye' and `shape` is not a square 2-D shape.
        """

        if weight_val is None:
            init = init.lower()
            if init == 'normal':
                initializer = (lambda shape, dtype, partition_info:
                               tf.random_normal(shape, stddev=0.05))
            elif init == 'uniform':
                # tf.random_uniform takes minval/maxval, not stddev; use a
                # symmetric range of the same 0.05 scale as 'normal'.
                initializer = (lambda shape, dtype, partition_info:
                               tf.random_uniform(
                                   shape, minval=-0.05, maxval=0.05))
            elif init == 'glorot':
                # NOTE(review): this samples from a normal distribution using
                # sqrt(6 / sum(shape)) as the stddev; kept as-is so existing
                # models initialize the same way -- confirm this is intended.
                initializer = (lambda shape, dtype, partition_info:
                               tf.random_normal(
                                   shape, stddev=np.sqrt(6. / sum(shape))))
            elif init == 'eye':
                # tf.eye(n) always yields an (n, n) matrix, so only square
                # 2-D shapes are valid. Raise instead of assert so that the
                # check is not stripped under `python -O`.
                if len(shape) != 2 or shape[0] != shape[1]:
                    raise ValueError(
                        'Init "eye" requires a square 2-D shape, got %s'
                        % (shape,))
                initializer = (lambda shape, dtype, partition_info:
                               tf.eye(shape[0]))
            elif init == 'zero':
                initializer = (lambda shape, dtype, partition_info:
                               tf.zeros(shape))
            else:
                raise ValueError('Invalid init: "%s"' % init)
        else:
            # Use the provided values as the initial weights. Without this
            # assignment `initializer` would be undefined below.
            weight_val = weight_val.astype('float32')
            initializer = tf.constant_initializer(weight_val)

        device = device.lower()
        if device == 'gpu':
            on_gpu = True
        elif device == 'cpu':
            on_gpu = False
        else:
            raise ValueError('Invalid device: "%s"' % device)

        # A CPU-only model never places weights on the GPU, whatever the
        # caller requested.
        if self._only_cpu:
            on_gpu = False

        with tf.device('/gpu:0' if on_gpu else '/cpu:0'):
            weight = tf.get_variable(name=name,
                                     shape=shape,
                                     initializer=initializer,
                                     trainable=trainable)
        self._weights.append(weight)

        return weight
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号