net.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:l3 作者: jacobandreas 项目源码 文件源码
def _linear(t_in, n_out):
    v_w = tf.get_variable(
            "w",
            shape=(t_in.get_shape()[-1], n_out),
            initializer=tf.uniform_unit_scaling_initializer(
                factor=INIT_SCALE))
    v_b = tf.get_variable(
            "b",
            shape=n_out,
            initializer=tf.constant_initializer(0))
    if len(t_in.get_shape()) == 2:
        return tf.einsum("ij,jk->ik", t_in, v_w) + v_b
    elif len(t_in.get_shape()) == 3:
        return tf.einsum("ijk,kl->ijl", t_in, v_w) + v_b
    else:
        assert False
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号