def ntn(name, lhs, rhs, nr_output_channels,
        use_bias=True, nonlin=__default_nonlin__,
        W=None, b=None, param_dtype=__default_dtype__):
    """Neural-tensor-network (bilinear) layer combining ``lhs`` and ``rhs``.

    Both inputs are passed through ``O.flatten2`` and then combined with a
    trainable 3-D weight tensor ``W`` of shape
    ``(lhs_dim, nr_output_channels, rhs_dim)``::

        out[i, k] = sum_{a, c} lhs[i, a] * W[a, k, c] * rhs[i, c]

    An optional bias is added, and then ``nonlin`` is applied.

    Args:
        name: Not read in this body. It is presumably consumed by a wrapping
            decorator or by the caller for op scoping (TODO confirm).
        lhs: Left input. After ``O.flatten2`` it must be rank 2 (the einsum
            subscript ``ia`` requires this). Its second dimension must be
            statically known.
        rhs: Right input, with the same requirements as ``lhs`` (subscript
            ``ic``). Its leading dimension must match ``lhs``.
        nr_output_channels: Size of the output's second axis.
        use_bias: If true, create a bias ``b`` of shape
            ``(nr_output_channels,)`` and add it to the bilinear product.
        nonlin: Callable applied as ``nonlin(out, name='nonlin')``.
        W: Value or initializer handed to ``O.ensure_variable`` for the
            weight tensor. Defaults to
            ``tf.contrib.layers.xavier_initializer()``.
        b: Value or initializer handed to ``O.ensure_variable`` for the bias.
            Defaults to ``tf.constant_initializer()``. Unused when
            ``use_bias`` is false.
        param_dtype: dtype used when creating ``W`` and ``b``.

    Returns:
        ``tf.identity(out, name='out')``, a tensor of shape
        ``(N, nr_output_channels)``, where ``N`` is the shared leading
        dimension of the flattened inputs.

    Raises:
        ValueError: If the second dimension of the flattened ``lhs`` or
            ``rhs`` is not statically known.
    """
    lhs, rhs = map(O.flatten2, [lhs, rhs])
    lhs_dim = lhs.static_shape[1]
    rhs_dim = rhs.static_shape[1]
    # The parameter shapes depend on these dimensions, so they must be known
    # at graph-construction time. Use an explicit raise rather than `assert`:
    # `python -O` strips asserts, and a None would then fail obscurely inside
    # variable creation.
    if lhs_dim is None or rhs_dim is None:
        raise ValueError(
            'ntn: the second dimension of both inputs must be statically '
            'known after flattening; got lhs shape {!r}, rhs shape {!r}.'.format(
                lhs.static_shape, rhs.static_shape))

    W_shape = (lhs_dim, nr_output_channels, rhs_dim)
    b_shape = (nr_output_channels, )

    # Create W before b to keep the original variable-creation order.
    if W is None:
        W = tf.contrib.layers.xavier_initializer()
    W = O.ensure_variable('W', W, shape=W_shape, dtype=param_dtype)
    if use_bias:
        if b is None:
            b = tf.constant_initializer()
        b = O.ensure_variable('b', b, shape=b_shape, dtype=param_dtype)

    # Bilinear form: out[i, k] = sum_{a, c} lhs[i, a] * W[a, k, c] * rhs[i, c]
    out = tf.einsum('ia,abc,ic->ib', lhs.tft, W.tft, rhs.tft)
    if use_bias:
        # add_axis(0) presumably inserts a leading axis so the bias
        # broadcasts across rows of `out` (TODO confirm against O's API).
        out = tf.identity(out + b.add_axis(0), name='bias')
    out = nonlin(out, name='nonlin')
    return tf.identity(out, name='out')
# Comment list
# Table of contents