python类int16()的实例源码

utils.py 文件源码 项目:zhusuan 作者: thu-ml 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def test_dtype_1parameter_discrete(test_class, Distribution):
    """Check dtype handling of a one-parameter discrete distribution.

    ``test_class`` supplies the assertion methods (``assertEqual``,
    ``assertRaises``) and ``Distribution`` is the class under test.

    Three groups of checks are run, in this order:

    1. ``distribution.dtype`` and the dtype of ``sample(2)`` follow the
       ``dtype`` constructor argument, defaulting to ``tf.int32`` when it
       is None.
    2. Constructing from an int32 or int64 parameter raises ``TypeError``.
    3. ``prob`` and ``log_prob`` return ``param_dtype`` both for a
       placeholder ``given`` and for a numpy ``given``.
    """
    # (expected result dtype, dtype passed to the constructor); None = default.
    sample_cases = [
        (tf.int32, None),
        (tf.int16, tf.int16),
        (tf.int32, tf.int32),
        (tf.float32, tf.float32),
        (tf.float64, tf.float64),
    ]

    def check_sample(param, expected, requested):
        dist = Distribution(param, dtype=requested)
        drawn = dist.sample(2)
        test_class.assertEqual(dist.dtype, expected)
        test_class.assertEqual(drawn.dtype, expected)

    # Outer loop over parameter values, inner loop over dtype cases.
    for param in ([1.], [[2., 3.], [4., 5.]]):
        for expected, requested in sample_cases:
            check_sample(param, expected, requested)

    def check_rejects(param_dtype):
        bad_param = tf.placeholder(param_dtype, [1])
        with test_class.assertRaises(TypeError):
            Distribution(bad_param)

    for int_dtype in (tf.int32, tf.int64):
        check_rejects(int_dtype)

    def check_prob_dtypes(param_dtype, given_dtype):
        param = tf.placeholder(param_dtype, [1])
        dist = Distribution(param, dtype=given_dtype)
        test_class.assertEqual(dist.param_dtype, param_dtype)

        # Symbolic `given`: a placeholder of unspecified shape.
        given_tensor = tf.placeholder(given_dtype, None)
        for result in (dist.prob(given_tensor), dist.log_prob(given_tensor)):
            test_class.assertEqual(result.dtype, param_dtype)

        # Concrete `given`: a numpy array with the matching numpy dtype.
        given_array = np.array([1], given_dtype.as_numpy_dtype)
        for result in (dist.prob(given_array), dist.log_prob(given_array)):
            test_class.assertEqual(result.dtype, param_dtype)

    # (parameter dtype, `given`/sample dtype) pairs.
    prob_cases = [
        (tf.float16, tf.int32),
        (tf.float32, tf.int32),
        (tf.float64, tf.int64),
        (tf.float32, tf.float32),
        (tf.float32, tf.float64),
    ]
    for param_dtype, given_dtype in prob_cases:
        check_prob_dtypes(param_dtype, given_dtype)
network_sparse.py 文件源码 项目:pruning_with_tensorflow 作者: ex4sperans 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _build_network(self,
                   inputs: tf.Tensor,
                   sparse_layers: list,
                   activation_fn: callable) -> tf.Tensor:
    """Build a fully connected network whose weights come from sparse layers.

    For each layer in ``sparse_layers``, the sparse components (``indices``,
    ``values``, ``dense_shape``) are stored as variables under the
    ``network/layer_<k>/sparse`` variable scope (``k`` starts at 1). They are
    then densified with ``tf.sparse_to_dense`` and applied as
    ``net @ weights + bias``.

    Args:
        inputs: Tensor fed into the first layer.
        sparse_layers: Sequence of layer descriptions; each must expose
            ``indices``, ``values``, ``dense_shape`` and ``bias``, which are
            used as variable initializers.
        activation_fn: Applied after every layer except the last one.

    Returns:
        The output tensor of the last layer, with no activation applied.

    Side effects:
        Sets ``self.weight_tensors`` to a new list holding the densified
        weight tensor of each layer, in layer order.
    """
    with tf.variable_scope('network'):

        net = inputs
        self.weight_tensors = []
        last_index = len(sparse_layers) - 1

        for i, layer in enumerate(sparse_layers):

            with tf.variable_scope('layer_{layer}'.format(layer=i + 1)):

                # Variables holding the sparse representation of the weights.
                with tf.variable_scope('sparse'):

                    # The variable name 'indicies' (sic) is kept as-is: it is
                    # the stored variable name, so renaming it would
                    # presumably break existing checkpoints.
                    # NOTE(review): int16 storage cannot represent index
                    # values above 32767; confirm every layer dimension fits.
                    indices = tf.get_variable(name='indicies',
                                              initializer=layer.indices,
                                              dtype=tf.int16)

                    values = tf.get_variable(name='values',
                                             initializer=layer.values,
                                             dtype=tf.float32)

                    dense_shape = tf.get_variable(name='dense_shape',
                                                  initializer=layer.dense_shape,
                                                  dtype=tf.int64)

                # sparse_to_dense takes int32/int64 indices sharing the dtype
                # of the output shape, so widen the int16 storage to int64.
                weights = tf.sparse_to_dense(tf.cast(indices, tf.int64),
                                             dense_shape,
                                             values)

                self.weight_tensors.append(weights)

                bias = tf.get_variable(name='bias',
                                       initializer=layer.bias)

                net = tf.matmul(net, weights) + bias

                # No activation after the final (output) layer.
                if i < last_index:
                    net = activation_fn(net)

        return net


问题


面经


文章

微信
公众号

扫码关注公众号