tbcnn.py source code

python

Project: tensorflow-tbcnn  Author: Aetf
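def tri_combined(idx, pclen, depth, max_depth):
    """TF function that builds one combined convolution weight matrix per batch item.

    Input: idx, pclen, depth, max_depth, each a 1D Tensor with one entry per batch item.
    Output: weight tensor (3D Tensor) of shape [batch, word_dim, conv_dim].
    """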
    Wconvt = param.get('Wconvt')
    Wconvl = param.get('Wconvl')
    Wconvr = param.get('Wconvr')

    dim = tf.unstack(tf.shape(Wconvt))[0]
    batch_shape = tf.shape(idx)

    # relative position of idx within 1..pclen, scaled to [0, 1]
    tmp = (idx - 1) / (pclen - 1)
    # when pclen == 1 (and idx == 1) this is 0/0 = nan; replace nan items with 0.5
    tmp = tf.where(tf.is_nan(tmp), tf.ones_like(tmp) * 0.5, tmp)

    # mixing coefficients for the top (t), right (r) and left (l) weight matrices
    t = (max_depth - depth) / max_depth
    r = (1 - t) * tmp
    l = (1 - t) * (1 - r)

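    # scale a batch of identity matrices by each coefficient: shape [batch, dim, dim];
    # the double transpose moves the batch axis last so the 1D coefficients broadcast over it
    lb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * l)
    rb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * r)
    tb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * t)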

    # flatten to [batch * dim, dim] so a single 2D matmul covers the whole batch
    lb = tf.reshape(lb, [-1, dim])
    rb = tf.reshape(rb, [-1, dim])
    tb = tf.reshape(tb, [-1, dim])

    tmp = tf.matmul(lb, Wconvl) + tf.matmul(rb, Wconvr) + tf.matmul(tb, Wconvt)

    tmp = tf.reshape(tmp, [-1, hyper.word_dim, hyper.conv_dim])
    return tmp
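
As a cross-check on the function above, the following is a minimal NumPy sketch of what it appears to compute: for each batch item, a weighted sum `t * Wconvt + l * Wconvl + r * Wconvr`. The names `tri_coefficients`, `tri_combined_np`, `w_t`, `w_l` and `w_r` are hypothetical and not part of the project; the weight matrices are passed in explicitly instead of being read from `param`. One assumed difference: the sketch uses 0.5 whenever `pclen == 1`, whereas the original only replaces NaN (the `idx == 1` case).

```python
import numpy as np


def tri_coefficients(idx, pclen, depth, max_depth):
    """Return the (t, l, r) mixing coefficients, one value per batch item."""
    idx = np.asarray(idx, dtype=np.float64)
    pclen = np.asarray(pclen, dtype=np.float64)
    depth = np.asarray(depth, dtype=np.float64)
    max_depth = np.asarray(max_depth, dtype=np.float64)

    # Relative position of idx within 1..pclen; 0.5 when pclen == 1.
    denom = pclen - 1
    safe_denom = np.where(denom == 0, 1.0, denom)  # avoid dividing by zero
    pos = np.where(denom == 0, 0.5, (idx - 1) / safe_denom)

    t = (max_depth - depth) / max_depth
    r = (1 - t) * pos
    l = (1 - t) * (1 - r)
    return t, l, r


def tri_combined_np(idx, pclen, depth, max_depth, w_t, w_l, w_r):
    """Return an array of shape [batch, word_dim, conv_dim]."""
    t, l, r = tri_coefficients(idx, pclen, depth, max_depth)
    # Broadcast each per-item scalar over its [word_dim, conv_dim] matrix.
    return (t[:, None, None] * w_t
            + l[:, None, None] * w_l
            + r[:, None, None] * w_r)
```

Broadcasting here plays the role of the scaled identity matrices and reshapes in the TensorFlow version: multiplying `l * I` by `Wconvl` is the same as scaling `Wconvl` by `l`.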