def call(self, inputs, state):
    """Run one step of the basic unitary RNN (uRNN) cell.

    The recurrent transition is the factored unitary operator
    D3 * R2 * IFFT * D2 * P * R1 * FFT * D1 applied to the complex
    hidden state; the result is summed with a complex projection of
    the input and passed through the modReLU nonlinearity.

    Args:
        inputs (Tensor - batch_sz x num_in): one batch of cell input.
        state (Tensor - batch_sz x 2*num_units): previous cell state,
            stored as a real tensor with the real half in the first
            num_units columns and the imaginary half in the rest.

    Returns:
        A tuple (output, new_state). Both are the same real tensor of
        shape [batch_sz, 2*num_units] ([Re | Im] concatenation); the
        dense layer outside the cell maps 2*num_units -> num_out.
    """
    # Input projection: one real matmul, then pack the two halves of the
    # result into a single complex tensor of width num_units.
    in_proj = tf.matmul(inputs, tf.transpose(self.w_ih))  # [batch_sz, 2*num_units]
    in_proj_c = tf.complex(in_proj[:, :self._num_units],
                           in_proj[:, self._num_units:])  # [batch_sz, num_units] complex

    # Unpack the real-packed state into its complex representation.
    hidden_c = tf.complex(state[:, :self._num_units],
                          state[:, self._num_units:])

    # Apply the unitary transition as a left-to-right pipeline of its
    # factors; the tensor stays [batch_sz, num_units] complex throughout.
    for transform in (self.D1.mul, FFT, self.R1.mul, self.P.mul,
                      self.D2.mul, IFFT, self.R2.mul, self.D3.mul):
        hidden_c = transform(hidden_c)

    # modReLU nonlinearity on the complex preactivation.
    new_state_c = modReLU(in_proj_c + hidden_c, self.b_h)  # [batch_sz, num_units] complex

    # Re-pack as a real [Re | Im] tensor so the framework can carry the
    # state between time steps; the cell output is the state itself.
    new_state = tf.concat([tf.real(new_state_c), tf.imag(new_state_c)], 1)
    return new_state, new_state
# (page-scrape artifacts, not code: "评论列表" = comment list, "文章目录" = article table of contents)