def tri_combined(idx, pclen, depth, max_depth):
    """Build one convolution weight matrix per node from three shared matrices.

    For every element ``i`` of the batch the result is the linear combination::

        W_i = l_i * Wconvl + r_i * Wconvr + t_i * Wconvt

    where ``Wconvt``, ``Wconvl`` and ``Wconvr`` are fetched via ``param.get``
    and the coefficients are computed from the inputs as:

        pos = (idx - 1) / (pclen - 1)      (0.5 when this is NaN, e.g. pclen == 1)
        t   = (max_depth - depth) / max_depth
        r   = (1 - t) * pos
        l   = (1 - t) * (1 - r)

    Args:
        idx: 1-D tensor (batch,). Judging by the name and the ``idx - 1``
            offset, presumably the 1-based position of the node among its
            siblings -- confirm against callers.
        pclen: 1-D tensor (batch,). Presumably the number of siblings
            (parent's child count) -- confirm against callers.
        depth: 1-D tensor (batch,).
        max_depth: 1-D tensor (batch,); must be non-zero (it is a divisor).

    Returns:
        3-D tensor of shape ``(batch, hyper.word_dim, hyper.conv_dim)``; the
        first dimension is the batch. Assumes the three weight matrices are
        ``(hyper.word_dim, hyper.conv_dim)`` -- the final reshape relies on it.
    """
    Wconvt = param.get('Wconvt')
    Wconvl = param.get('Wconvl')
    Wconvr = param.get('Wconvr')

    # Relative horizontal position. When pclen == 1 (and idx == 1) this is
    # 0/0 = NaN; such nodes are treated as sitting in the middle (0.5).
    pos = (idx - 1) / (pclen - 1)
    pos = tf.where(tf.is_nan(pos), tf.ones_like(pos) * 0.5, pos)

    t = (max_depth - depth) / max_depth
    r = (1 - t) * pos
    # NOTE(review): l uses (1 - r) rather than (1 - pos), so t + l + r is not 1
    # in general. Kept unchanged; confirm this is the intended formula.
    l = (1 - t) * (1 - r)

    def _scale(coef, weight):
        # Per-node scalar times a shared matrix: (batch, 1, 1) * (rows, cols)
        # broadcasts to (batch, rows, cols). This equals matmul(coef_i * I, W)
        # without materializing a batch of identity matrices.
        return tf.reshape(coef, [-1, 1, 1]) * weight

    weights = _scale(l, Wconvl) + _scale(r, Wconvr) + _scale(t, Wconvt)
    return tf.reshape(weights, [-1, hyper.word_dim, hyper.conv_dim])
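# NOTE(review): stray page text, apparently scraped residue (not code):
# 评论列表 ("Comment list"), 文章目录 ("Table of contents")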