def _linear(t_in, n_out):
    """Apply a learned affine map ``t_in @ w + b`` over the last axis of ``t_in``.

    Creates (or reuses, depending on the enclosing variable scope) two
    variables via ``tf.get_variable``:

    * ``"w"`` of shape ``(t_in.shape[-1], n_out)``, initialized with
      ``tf.uniform_unit_scaling_initializer(factor=INIT_SCALE)``;
    * ``"b"`` of shape ``(n_out,)``, initialized to zeros.

    Because the variable names are fixed, callers presumably open a distinct
    ``tf.variable_scope`` for each layer. NOTE(review): confirm against callers.

    Args:
        t_in: Tensor of rank 2 or 3. Its last dimension sizes ``w``, so it
            should be statically known.
        n_out: Number of output features.

    Returns:
        Tensor with the same rank as ``t_in``, with its last dimension
        replaced by ``n_out``.

    Raises:
        ValueError: If the rank of ``t_in`` is unknown or is neither 2 nor 3.
    """
    rank = t_in.get_shape().ndims
    # Validate up front, and with a real exception rather than `assert`
    # (which is stripped under `python -O`). This way an unsupported input
    # never leaves half-created variables behind or silently returns None.
    if rank not in (2, 3):
        raise ValueError(
            "_linear expects a rank-2 or rank-3 input, got rank %s" % rank)

    v_w = tf.get_variable(
        "w",
        shape=(t_in.get_shape()[-1], n_out),
        initializer=tf.uniform_unit_scaling_initializer(factor=INIT_SCALE))
    v_b = tf.get_variable(
        "b",
        shape=(n_out,),
        initializer=tf.constant_initializer(0))

    # Contract the last axis of t_in with the first axis of w; the bias
    # broadcasts over the leading axes.
    if rank == 2:
        return tf.einsum("ij,jk->ik", t_in, v_w) + v_b
    return tf.einsum("ijk,kl->ijl", t_in, v_w) + v_b
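# NOTE(review): the following two labels are stray web-page navigation text
# captured when this code was copied, not code:
# "评论列表" ("comment list") and "文章目录" ("table of contents").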