def softmax_with_base(shape, base_untiled, x, mask=None, name='sig'):
    """Softmax over the last axis of ``x`` with an extra "base" logit prepended.

    A scalar ``base_untiled`` is broadcast into one extra column placed in
    front of the last axis of ``x``. A softmax is then taken over the
    resulting ``shape[-1] + 1`` entries. The probability mass is split into
    the part taken by the base column and the part taken by the original
    entries of ``x``.

    Args:
        shape: Static shape of ``x`` as a sequence of ints (list or tuple).
        base_untiled: Logit for the base column. It is expanded to rank
            ``len(shape)`` and then tiled with ``len(shape)`` multiples, which
            only lines up when it is a rank-0 (scalar) tensor.
        x: Logits whose shape is ``shape``.
        mask: Optional mask broadcastable to ``x``. The value
            ``VERY_SMALL_NUMBER * (1 - mask)`` is added to the logits. The
            mask is cast to ``x``'s dtype first, so boolean masks are
            accepted.
        name: Name given to the returned ``sig`` op. Defaults to ``'sig'``.

    Returns:
        A tuple ``(sig, p)``:

        * ``sig`` has shape ``shape[:-1]``. It equals ``1 - P(base)``, the
          total probability assigned to the entries of ``x``.
        * ``p`` has shape ``shape``. It holds the softmax probability of each
          entry of ``x``, so each row sums to ``sig``, not to 1.
    """
    # Accept tuples too: the code below builds shapes by list concatenation.
    shape = list(shape)
    rank = len(shape)
    last_axis = rank - 1

    # Work on a tensor from here on; this also guarantees the masking below
    # never mutates a caller-owned array in place.
    x = tf.convert_to_tensor(x)
    if mask is not None:
        # Cast so boolean masks work and dtypes agree with x.
        # NOTE(review): presumably VERY_SMALL_NUMBER is a large negative
        # constant so masked logits vanish under the softmax -- confirm at
        # its definition.
        mask = tf.cast(mask, x.dtype)
        x = x + VERY_SMALL_NUMBER * (1.0 - mask)

    # Lift the scalar base to rank `rank` (all dims 1), then tile it to
    # shape[:-1] + [1] so it can sit in front of x's last axis.
    base = base_untiled
    for _ in range(rank):
        base = tf.expand_dims(base, -1)
    base = tf.tile(base, shape[:-1] + [1])

    # Logits with the base column first: shape[:-1] + [shape[-1] + 1].
    # NOTE: this is the pre-1.0 TensorFlow signature tf.concat(axis, values);
    # TF >= 1.0 uses tf.concat(values, axis) and tf.subtract instead of tf.sub.
    c_shape = shape[:-1] + [shape[-1] + 1]
    c = tf.concat(last_axis, [base, x])

    # Softmax is applied on a 2-D [rows, classes] view (pre-1.0
    # tf.nn.softmax only accepted 2-D logits), then restored to c_shape.
    c_flat = tf.reshape(c, [reduce(mul, shape[:-1], 1), c_shape[-1]])
    p_cat = tf.reshape(tf.nn.softmax(c_flat), c_shape)

    # Probability of the base column, with its size-1 last axis dropped.
    s_aug = tf.slice(p_cat, [0] * rank, shape[:-1] + [1])
    s = tf.squeeze(s_aug, [last_axis])
    # Honour the caller-supplied op name (it used to be hard-coded to "sig").
    sig = tf.sub(1.0, s, name=name)

    # Probabilities of the original x entries: everything after the base column.
    p = tf.slice(p_cat, [0] * last_axis + [1], shape)
    return sig, p
评论列表
文章目录