def linear_combine(clen, pclen, idx):
    """Blend the ``Wl`` and ``Wr`` parameter matrices per batch entry.

    Three coefficients are derived from the inputs::

        scale = clen / pclen
        left  = (pclen - idx - 1) / (pclen - 1)
        right = idx / (pclen - 1)

    ``left`` and ``right`` are replaced by 0.5 wherever they evaluate to
    NaN (the ``pclen == 1`` case, where both are 0/0).  Each coefficient is
    turned into a batch of scaled identity matrices; the left/right ones
    are multiplied with ``Wl`` / ``Wr`` and summed, and the sum is finally
    multiplied by the ``scale`` matrices.

    NOTE(review): assumes ``clen``, ``pclen`` and ``idx`` are float tensors
    of identical shape, and that ``Wl`` / ``Wr`` are square with side
    ``hyper.word_dim`` -- confirm against the callers.

    Args:
        clen: tensor used as the numerator of the overall scale; its shape
            is also used as the batch shape of the identity matrices.
        pclen: tensor used as the denominator of the scale and to
            normalise ``idx``.
        idx: tensor giving the position that is normalised by ``pclen - 1``.

    Returns:
        The result of ``tf.matmul`` between the scale matrices and the
        blended weights reshaped to
        ``[-1, hyper.word_dim, hyper.word_dim]``.
    """
    w_left = param.get('Wl')
    w_right = param.get('Wr')
    side = tf.unstack(tf.shape(w_left))[0]
    batch_dims = tf.shape(clen)

    def _nan_to_half(coef):
        # Swap NaN entries for 0.5, leaving every other entry untouched.
        is_missing = tf.is_nan(coef)
        half = tf.ones_like(coef) * 0.5
        return tf.where(is_missing, half, coef)

    def _scaled_identity(coef):
        # tf.transpose without `perm` reverses the axes, so the batch axes
        # end up last and `coef` broadcasts over them; the second transpose
        # restores the original axis order.
        eye = tf.eye(side, batch_shape=batch_dims)
        return tf.transpose(tf.transpose(eye) * coef)

    scale = clen / pclen
    left_coef = (pclen - idx - 1) / (pclen - 1)
    right_coef = idx / (pclen - 1)
    # when pclen == 1, replace nan items with 0.5
    left_coef = _nan_to_half(left_coef)
    right_coef = _nan_to_half(right_coef)

    left_diag = _scaled_identity(left_coef)
    right_diag = _scaled_identity(right_coef)
    scale_diag = _scaled_identity(scale)

    # Flatten the batched matrices into rows so that one 2-D matmul against
    # each weight matrix handles the whole batch.
    flat_shape = [-1, hyper.word_dim]
    left_rows = tf.reshape(left_diag, flat_shape)
    right_rows = tf.reshape(right_diag, flat_shape)

    blended = tf.matmul(left_rows, w_left) + tf.matmul(right_rows, w_right)
    blended = tf.reshape(blended, [-1, hyper.word_dim, hyper.word_dim])
    return tf.matmul(scale_diag, blended)
# Comment list
# Table of contents