cnn.py source code


Project: TensorArtist  Author: vacancy
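```python
# Requires TensorFlow 1.x (tf.contrib) and TensorArtist's operator module imported as O;
# __default_nonlin__ and __default_dtype__ are module-level defaults defined elsewhere.
def ntn(name, lhs, rhs, nr_output_channels,
        use_bias=True, nonlin=__default_nonlin__,
        W=None, b=None, param_dtype=__default_dtype__):
    """Neural Tensor Network (bilinear) layer.

    Computes out[i, k] = nonlin(sum_{a, c} lhs[i, a] * W[a, k, c] * rhs[i, c] + b[k]).
    """
    # Flatten both inputs to (batch, features).
    lhs, rhs = map(O.flatten2, [lhs, rhs])

    # The feature dimensions must be known statically to build W.
    assert lhs.static_shape[1] is not None and rhs.static_shape[1] is not None
    W_shape = (lhs.static_shape[1], nr_output_channels, rhs.static_shape[1])
    b_shape = (nr_output_channels, )

    if W is None:
        W = tf.contrib.layers.xavier_initializer()
    W = O.ensure_variable('W', W, shape=W_shape, dtype=param_dtype)
    if use_bias:
        if b is None:
            b = tf.constant_initializer()  # zero initialization by default
        b = O.ensure_variable('b', b, shape=b_shape, dtype=param_dtype)

    # Bilinear product: one (lhs_dim x rhs_dim) slice of W per output channel.
    out = tf.einsum('ia,abc,ic->ib', lhs.tft, W.tft, rhs.tft)
    if use_bias:
        # Broadcast the (nr_output_channels,) bias over the batch axis.
        out = tf.identity(out + b.add_axis(0), name='bias')

    out = nonlin(out, name='nonlin')
    return tf.identity(out, name='out')
```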
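For readers without TensorFlow 1.x and TensorArtist installed, the forward pass of `ntn` can be checked with a minimal NumPy sketch. It assumes that `O.flatten2` collapses every non-batch axis into one (a reshape to `(N, -1)`), and it uses tanh as a stand-in default nonlinearity because the value of TensorArtist's `__default_nonlin__` is not shown here. The helpers `flatten2` and `ntn_forward` are hypothetical names for illustration only; the sketch reproduces the einsum, bias, and nonlinearity steps, not variable creation or Xavier initialization.

```python
import numpy as np


def flatten2(x):
    """Hypothetical stand-in for O.flatten2: keep the batch axis, flatten the rest."""
    x = np.asarray(x)
    return x.reshape(x.shape[0], -1)


def ntn_forward(lhs, rhs, W, b=None, nonlin=np.tanh):
    """NumPy sketch of the ntn forward pass (assumption: tanh as default nonlin).

    lhs: (N, ...) flattened to (N, A); rhs: (N, ...) flattened to (N, C).
    W:   (A, K, C) bilinear tensor, one A x C slice per output channel k.
    b:   optional (K,) bias.
    Returns an (N, K) array.
    """
    lhs, rhs = flatten2(lhs), flatten2(rhs)
    W = np.asarray(W)
    if W.ndim != 3 or W.shape[0] != lhs.shape[1] or W.shape[2] != rhs.shape[1]:
        raise ValueError(f"W must have shape (A, K, C), got {W.shape}")
    # out[i, k] = sum_{a, c} lhs[i, a] * W[a, k, c] * rhs[i, c]
    out = np.einsum('ia,abc,ic->ib', lhs, W, rhs)
    if b is not None:
        # Broadcast the (K,) bias over the batch axis, like b.add_axis(0).
        out = out + np.asarray(b)[np.newaxis, :]
    return nonlin(out)


# Usage: batch of 4, lhs features 3, rhs features 5, 2 output channels.
rng = np.random.default_rng(0)
lhs = rng.normal(size=(4, 3))
rhs = rng.normal(size=(4, 5))
W = rng.normal(size=(3, 2, 5))
y = ntn_forward(lhs, rhs, W, b=np.zeros(2))
print(y.shape)  # (4, 2)
```

Each output channel k is a bilinear form `lhs[i] @ W[:, k, :] @ rhs[i]`, which is why `W` is laid out as `(lhs_dim, nr_output_channels, rhs_dim)` and the einsum subscript `'ia,abc,ic->ib'` contracts over both feature axes while keeping the batch and channel axes.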