def call(self, x, mask=None):
    """Build the unitary-RNN graph for ``x`` and return its linear output.

    Graph construction is delegated to ``models.complex_RNN`` with a fixed
    configuration:
      * real input and output, MSE loss;
      * no feed-forward path and no mask;
      * zero hidden bias;
      * a single layer, seed 1234.
    Afterwards the last parameter returned by ``complex_RNN`` is tagged (and,
    for ``'full_enforceComplex'``, re-parameterized) according to
    ``self.unitary_impl``.

    Args:
        x: Input tensor. Its first two axes are swapped with
            ``K.permute_dimensions(x, (1, 0, 2))`` before being handed to
            ``complex_RNN`` as ``x_spec``. This presumably converts
            (batch, time, features) to (time, batch, features) — TODO confirm.
        mask: Unused by this method.

    Returns:
        ``costs[2]`` from ``models.complex_RNN`` (requested via
        ``flag_return_lin_output=True``). For ``'full_enforceComplex'`` it is
        the same expression rebuilt with ``theano.clone`` so that it depends
        on the reduced N x 2N variable instead of the 2N x 2N one.

    Side effects:
        Appends a suffix to ``parameters[-1].name`` for the ``'full*'``
        implementations, and sets ``self.trainable_weights`` to the
        (possibly modified) parameter list.
    """
    # Map the layer option onto the Wimpl name understood by complex_RNN.
    # Every option containing 'full' builds a full matrix.
    Wimpl = self.unitary_impl
    if 'full' in Wimpl:
        Wimpl = 'full'
    elif Wimpl == 'ASB2016':
        Wimpl = 'adhoc'
    elif Wimpl == 'ASB2016_fast':
        Wimpl = 'adhoc_fast'

    x_spec = K.permute_dimensions(x, (1, 0, 2))
    _inputs, parameters, costs = models.complex_RNN(
        self.input_dim,
        self.hidden_dim,
        self.output_dim,
        input_type='real',
        out_every_t=False,
        loss_function='MSE',
        output_type='real',
        flag_feed_forward=False,
        flag_return_lin_output=True,
        x_spec=x_spec,
        flag_use_mask=False,
        hidden_bias_mean=np.float32(0.0),
        Wimpl=Wimpl,
        flag_return_hidden_states=True,
        n_layers=1,
        seed=1234,
        hidden_bias_init='zero',
    )
    lin_output = costs[2]

    # Suffix appended to the last parameter's name. Presumably the optimizer
    # keys its update rule off this name — confirm against the optimizer.
    # NOTE(review): 'full' and 'full_enforceComplex' append 'full_natGrad'
    # with no leading underscore, unlike the other entries, even though the
    # 'full' case is described as "just use lrng". Preserved as-is; verify
    # this is intended.
    name_suffixes = {
        # Just use lrng for the learning rate on this parameter.
        'full': 'full_natGrad',
        # Fixed lrng with a natural-gradient update.
        'full_natGrad': '_natGrad_unitaryAug',
        # Fixed lrng with a natural-gradient update and an RMSprop-style
        # gradient adjustment.
        'full_natGradRMS': '_natGradRMS_unitaryAug',
        'full_enforceComplex': 'full_natGrad',
    }
    suffix = name_suffixes.get(self.unitary_impl)
    if suffix is not None:
        parameters[-1].name += suffix

    if self.unitary_impl == 'full_enforceComplex':
        # Swap the 2N x 2N augmented unitary matrix for its top N x 2N half
        # [A, B]. The full matrix is then rebuilt as [[A, B], [-B, A]], which
        # ensures the complex-number constraint is satisfied.
        Waug = parameters[-1]
        # Use floor division: on Python 3, `/` on integer shapes yields a
        # float, which cannot be used as a slice index.
        n_rows = Waug.shape[1] // 2
        WReIm = K.variable(value=Waug[:n_rows, :].eval(), name=Waug.name)
        half = WReIm.shape[1] // 2
        bottom = K.concatenate((-WReIm[:, half:], WReIm[:, :half]), axis=1)
        WaugFull = K.concatenate((WReIm, bottom), axis=0)
        # Re-express the output graph in terms of the reduced variable.
        lin_output = theano.clone(lin_output, replace={Waug: WaugFull})
        parameters[-1] = WReIm

    self.trainable_weights = parameters
    return lin_output
评论列表
文章目录