embedding.py source code

python

Project: tensorflow-tbcnn  Author: Aetf
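def linear_combine(clen, pclen, idx):
    """Build one combined weight matrix per batch entry from Wl and Wr."""
    # Shared left and right weight matrices, looked up via param.get.
    Wl = param.get('Wl')
    Wr = param.get('Wr')

    # First dimension of Wl and the dynamic shape of the batch.
    dim = tf.unstack(tf.shape(Wl))[0]
    batch_shape = tf.shape(clen)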

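    # Scaling factor f and left/right mixing coefficients l and r per entry.
    f = (clen / pclen)
    l = (pclen - idx - 1) / (pclen - 1)
    r = (idx) / (pclen - 1)
    # When pclen == 1 (and idx == 0) both quotients are 0/0 = NaN;
    # replace those entries with 0.5.
    # Note: tf.is_nan is the TF 1.x name; TF 2.x exposes it as tf.math.is_nan.
    l = tf.where(tf.is_nan(l), tf.ones_like(l) * 0.5, l)
    r = tf.where(tf.is_nan(r), tf.ones_like(r) * 0.5, r)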

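    # Scale a batch of identity matrices by l, r and f. The double transpose
    # moves the batch axis last so the multiply broadcasts per batch entry.
    lb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * l)
    rb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * r)
    fb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * f)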

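    # Flatten to 2-D so one matmul applies Wl and Wr to every batch entry.
    lb = tf.reshape(lb, [-1, hyper.word_dim])
    rb = tf.reshape(rb, [-1, hyper.word_dim])

    # Stacked row blocks of l * Wl + r * Wr, one block per batch entry.
    tmp = tf.matmul(lb, Wl) + tf.matmul(rb, Wr)

    # Restore the batch axis: one word_dim x word_dim matrix per entry.
    tmp = tf.reshape(tmp, [-1, hyper.word_dim, hyper.word_dim])

    # Apply the f scaling through the batched diagonal matrices.
    return tf.matmul(fb, tmp)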
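
Reading the function above, the scaled identity matrices seem to do nothing more than multiply Wl and Wr by per-entry scalars, so each batch entry's result appears to be f * (l * Wl + r * Wr). The following is a minimal pure-Python sketch of that arithmetic for a single entry, not the project's implementation. The names combine_coefficients and combined_weight are hypothetical, the matrices are plain nested lists, and it assumes idx == 0 whenever pclen == 1.

```python
def combine_coefficients(clen, pclen, idx):
    """Return (f, l, r) for one entry, mirroring the scalar arithmetic above."""
    f = clen / pclen
    if pclen == 1:
        # Assumption: idx == 0 here, so the original computes 0/0 (NaN)
        # for both l and r and then replaces each with 0.5.
        l = r = 0.5
    else:
        l = (pclen - idx - 1) / (pclen - 1)
        r = idx / (pclen - 1)
    return f, l, r


def combined_weight(Wl, Wr, clen, pclen, idx):
    """Return f * (l * Wl + r * Wr) for one entry as a nested list."""
    f, l, r = combine_coefficients(clen, pclen, idx)
    return [
        [f * (l * a + r * b) for a, b in zip(row_l, row_r)]
        for row_l, row_r in zip(Wl, Wr)
    ]


if __name__ == "__main__":
    # The pclen == 1 case falls back to an even mix of Wl and Wr.
    print(combine_coefficients(2, 1, 0))  # (2.0, 0.5, 0.5)
```

With pclen == 3, idx runs from 0 to 2 and the mix moves linearly from all Wl (l = 1, r = 0) to all Wr (l = 0, r = 1), which can be a convenient way to sanity-check the TensorFlow output on small inputs.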