def lyr_linear(
        name, s_x, odim,
        axis=-1, bias=True, w_init=None, b_init=None):
    '''
    Affine (linear + optional bias) transform along an arbitrary axis.

    Like tf.xw_plus_b, but works on tensors of arbitrary rank: the
    dimension at `axis` is contracted against a learned [idim, odim]
    weight matrix, and the output keeps that axis in place with size odim.

    Args:
        name: string, variable scope name for the W/B variables
        s_x: tensor variable; the dimension at `axis` must be statically known
        odim: integer, output size along `axis`
        axis: integer, axis to transform (may be negative)
        bias: boolean, whether to add a bias term
        w_init: initializer for W
        b_init: initializer for B (defaults to zeros)

    Returns:
        tensor shaped like s_x except the `axis` dimension becomes odim
    '''
    assert isinstance(odim, int)
    x_shape = s_x.get_shape().as_list()
    ndim = len(x_shape)
    # validate axis BEFORE using it to index the static shape
    assert -ndim <= axis < ndim
    idim = x_shape[axis]
    # the contracted dimension must be statically known to size W
    assert isinstance(idim, int)
    with tf.variable_scope(name):
        v_w = tf.get_variable(
            'W', [idim, odim],
            initializer=w_init,
            dtype=hparams.FLOATX)
        if ndim == 1:
            # vector input: treat as a single-row matrix, then squeeze back
            s_y = tf.matmul(tf.expand_dims(s_x, 0), v_w)
            s_y = tf.squeeze(s_y, 0)
        elif ndim == 2:
            if axis % 2 == 1:
                # axis is the last one (1 or -1): plain matmul
                s_y = tf.matmul(s_x, v_w)
            else:
                # axis 0 (or -2): transpose, matmul, transpose back
                s_y = tf.matmul(tf.transpose(s_x), v_w)
                # BUGFIX: transpose the RESULT, not the input
                # (original assigned tf.transpose(s_x), discarding the matmul)
                s_y = tf.transpose(s_y)
        elif (axis + 1) % ndim == 0:
            # last axis of an N-D tensor: flatten batch dims, matmul, restore
            s_batch_shp = tf.shape(s_x)[:-1]
            s_x_flat = tf.reshape(
                s_x,
                [tf.reduce_prod(s_batch_shp, axis=None), x_shape[-1]])
            s_y = tf.matmul(s_x_flat, v_w)
            s_y = tf.reshape(s_y, tf.concat([s_batch_shp, [odim]], axis=0))
        else:
            # general case: contract `axis` against W's first axis.
            # tensordot moves the new odim axis to the END of the result;
            # BUGFIX: permute it back to position `axis` so the output keeps
            # the axis in place (as the 2-D branch does, and as the bias
            # broadcast below requires).
            s_y = tf.tensordot(s_x, v_w, [[axis], [0]])
            pos = axis % ndim
            perm = list(range(pos)) + [ndim - 1] + list(range(pos, ndim - 1))
            s_y = tf.transpose(s_y, perm)
        if bias:
            if b_init is None:
                b_init = tf.constant_initializer(0., dtype=hparams.FLOATX)
            v_b = tf.get_variable(
                'B', [odim],
                initializer=b_init,
                dtype=hparams.FLOATX)
            # reshape bias so it broadcasts against the transformed axis:
            # [odim, 1, 1, ...] with one trailing 1 per axis after `axis`
            s_b = tf.reshape(v_b, [odim] + [1] * (ndim - (axis % ndim) - 1))
            s_y = s_y + s_b
        return s_y
# (removed non-code web-scrape residue that followed the function:
#  "评论列表" / "文章目录" — comment-list / article-TOC boilerplate)