def call(self, step_inputs, state, scope=None, initialization='gaussian'):
  """
  Make one step of the ISAN transition.

  Args:
    step_inputs: one-hot encoded inputs, shape bs x n
    state: previous hidden state, shape bs x d
    scope: current variable scope
    initialization: how to initialize the transition matrices:
      'orthogonal': orthogonalized Gaussian matrices; usually speeds up training
      'gaussian': Gaussian matrices sampled with a sensible scale (stddev 1 / sqrt(d))

  Returns:
    An (output, new_state) pair; for ISAN the output is the new hidden state.
  """
  d = self._num_units
  n = step_inputs.shape[1].value
  if initialization == 'orthogonal':
    # Orthogonalize a d x d Gaussian matrix for each of the n input symbols
    # and store it as one flattened row of Wx.
    wx_ndd_init = np.zeros((n, d * d), dtype=np.float32)
    for i in range(n):
      wx_ndd_init[i, :] = orth(np.random.randn(d, d)).astype(np.float32).ravel()
    wx_ndd_initializer = tf.constant_initializer(wx_ndd_init)
  elif initialization == 'gaussian':
    # A stddev of 1 / sqrt(d) keeps the scale of the hidden state roughly
    # constant from one step to the next.
    wx_ndd_initializer = tf.random_normal_initializer(stddev=1.0 / np.sqrt(d))
  else:
    raise ValueError('Unknown init type: %s' % initialization)
  # One d x d transition matrix and one length-d bias per input symbol,
  # stored as the rows of Wx and bx.
  wx_ndd = tf.get_variable('Wx', shape=[n, d * d],
                           initializer=wx_ndd_initializer)
  bx_nd = tf.get_variable('bx', shape=[n, d],
                          initializer=tf.zeros_initializer())
  # Multiplication with a 1-hot is just row selection.
  # As of Jan '17 this is faster than doing gather.
  Wx_bdd = tf.reshape(tf.matmul(step_inputs, wx_ndd), [-1, d, d])
  bx_bd = tf.reshape(tf.matmul(step_inputs, bx_nd), [-1, 1, d])
  # Reshape the state so that matmul multiplies different matrices
  # for each batch element.
  single_state = tf.reshape(state, [-1, 1, d])
  new_state = tf.reshape(tf.matmul(single_state, Wx_bdd) + bx_bd, [-1, d])
  return new_state, new_state
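
# --- Usage sketch (not from the original source) ---
# A minimal example of driving the `call` above over a sequence in TF 1.x
# graph mode. It assumes the method lives on a hypothetical `ISANCell` class
# that stores the hidden size in `self._num_units`; the names `ISANCell`,
# `vocab_size`, `num_units`, etc. are illustrative, not part of the original code.
import tensorflow as tf

batch_size, seq_len, vocab_size, num_units = 32, 20, 128, 64

cell = ISANCell(num_units)  # hypothetical class wrapping the `call` method above

token_ids = tf.placeholder(tf.int32, [batch_size, seq_len])
# One-hot encode the integer ids so each per-step input has shape bs x n.
inputs_btn = tf.one_hot(token_ids, depth=vocab_size, dtype=tf.float32)

state = tf.zeros([batch_size, num_units])
outputs = []
with tf.variable_scope('isan') as vs:
  for t in range(seq_len):
    if t > 0:
      vs.reuse_variables()  # share Wx / bx across time steps
    out, state = cell.call(inputs_btn[:, t, :], state)
    outputs.append(out)
outputs_btd = tf.stack(outputs, axis=1)  # bs x seq_len x d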