spatial_transformer.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tf_practice 作者: juho-lee 项目源码 文件源码
def to_loc(input, is_simple=False):
    if len(input.get_shape()) == 4:
        input = layers.flatten(input)
    num_inputs = input.get_shape()[1]
    num_outputs = 3 if is_simple else 6
    W_init = tf.constant_initializer(
            np.zeros((num_inputs, num_outputs)))
    if is_simple:
        b_init = tf.constant_initializer(np.array([1.,0.,0.]))
    else:
        b_init = tf.constant_initializer(np.array([1.,0.,0.,0.,1.,0.]))

    return layers.fully_connected(input, num_outputs,
            activation_fn=None,
            weights_initializer=W_init,
            biases_initializer=b_init)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号