ops.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:DaNet-Tensorflow 作者: khaotik 项目源码 文件源码
def lyr_linear(
        name, s_x, odim,
        axis=-1, bias=True, w_init=None, b_init=None):
    '''
    Like tf.xw_plus_b, but works on arbitrary shape

    Applies a learned linear map along one axis of `s_x`: the axis of
    size `idim` is contracted with a weight matrix `W` of shape
    [idim, odim], and the resulting axis of size `odim` is placed at the
    same position `axis` in the output. An optional bias `B` of shape
    [odim] is broadcast-added along that axis.

    Variables `W` (and `B` when `bias` is true) are created with
    tf.get_variable inside tf.variable_scope(name).

    Args:
        name: string, variable scope name
        s_x: tensor variable; its rank and the size of `axis` must be
            statically known
        odim: integer, size of the output axis
        axis: integer, axis of `s_x` to transform; may be negative
        bias: boolean, whether to use bias
        w_init: initializer for W
        b_init: initializer for B, defaults to zeros

    Returns:
        tensor with the shape of `s_x`, except that `axis` has size `odim`
    '''
    assert isinstance(odim, int)
    x_shape = s_x.get_shape().as_list()
    ndim = len(x_shape)
    # Validate the axis before using it as an index.
    assert -ndim <= axis < ndim
    idim = x_shape[axis]
    assert isinstance(idim, int)
    # Non-negative position of the transformed axis.
    pos = axis % ndim
    with tf.variable_scope(name):
        v_w = tf.get_variable(
            'W', [idim, odim],
            initializer=w_init,
            dtype=hparams.FLOATX)
        if ndim == 1:
            # tf.matmul needs rank >= 2: add a dummy batch axis, drop it after.
            s_y = tf.matmul(tf.expand_dims(s_x, 0), v_w)
            s_y = tf.squeeze(s_y, 0)
        elif ndim == 2:
            if pos == 1:
                s_y = tf.matmul(s_x, v_w)
            else:
                # Features live on axis 0: transpose in, multiply, then
                # transpose the *result* back so odim ends up on axis 0.
                s_y = tf.matmul(tf.transpose(s_x), v_w)
                s_y = tf.transpose(s_y)
        elif pos == ndim - 1:
            # Last-axis case: flatten every leading axis into one batch axis
            # using the dynamic shape, so unknown batch sizes are supported.
            s_batch_shp = tf.shape(s_x)[:-1]
            s_x = tf.reshape(
                s_x,
                [tf.reduce_prod(s_batch_shp, axis=None), x_shape[-1]])
            s_y = tf.matmul(s_x, v_w)
            s_y = tf.reshape(s_y, tf.concat([s_batch_shp, [odim]], axis=0))
        else:
            # tensordot appends the output axis at the end; move it back to
            # `pos` so the layout matches the other branches and the bias
            # reshape below.
            s_y = tf.tensordot(s_x, v_w, [[pos], [0]])
            perm = (
                list(range(pos)) + [ndim - 1] + list(range(pos, ndim - 1)))
            s_y = tf.transpose(s_y, perm)
        if bias:
            if b_init is None:
                b_init = tf.constant_initializer(0., dtype=hparams.FLOATX)
            v_b = tf.get_variable(
                    'B', [odim],
                    initializer=b_init,
                    dtype=hparams.FLOATX)
            # Pad with trailing singleton axes so B lines up with `pos`
            # under broadcasting.
            s_b = tf.reshape(v_b, [odim] + [1] * (ndim - pos - 1))
            s_y = s_y + s_b
    return s_y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号