def __ConvLSTMStep(
        name,
        seq_len,
        input_dim,
        hidden_dim,
        current_input,
        last_hidden,
        last_cell,
        dilation_depth=10,
        inp_bias_init=0.,
        forget_bias_init=3.,
        out_bias_init=0.,
        g_bias_init=0.):
    """Build one step of a convolutional LSTM fed by a dilated conv stack.

    ``current_input`` is passed through ``dilation_depth`` stacked
    ``lib.ops.WaveNetConv1d`` blocks with dilations 1, 2, 4, ...,
    ``2 ** (dilation_depth - 1)``. The result is concatenated with
    ``last_hidden`` along axis 1. A ``lib.ops.conv1d`` then projects it to
    ``4 * hidden_dim`` gate pre-activations. These drive an LSTM update with
    peephole-style cell weights (``<name>.CellWeights``).

    Args:
        name: Prefix for the ``.ConvGates`` and ``.CellWeights`` parameters.
        seq_len: Size of axis 1 of the ``CellWeights`` parameter, whose
            shape is ``(3 * hidden_dim, seq_len, 1)``.
        input_dim: Passed as the input dimension of every WaveNet block; the
            gate convolution is declared with ``2 * input_dim`` inputs.
        hidden_dim: Number of channels per gate / in the cell state.
        current_input: Symbolic input for this step (assumes channels on
            axis 1 -- TODO confirm against WaveNetConv1d).
        last_hidden: Previous hidden state H_{t-1}; concatenated on axis 1.
        last_cell: Previous cell state C_{t-1}; combined element-wise with
            the ``hidden_dim``-wide gate slices.
        dilation_depth: Number of WaveNet blocks in the input stack.
        inp_bias_init, forget_bias_init, out_bias_init, g_bias_init:
            NOTE(review): never read by this function; kept only so the
            signature stays compatible. The gate biases come from
            ``lib.ops.conv1d``'s own initialisation.

    Returns:
        Tuple ``(H_t, C_t)`` of the new hidden state and new cell state.
    """
    # X_t * (U^i, U^f, U^o, U^g): run the input through the dilated stack.
    prev_conv = current_input
    for block_idx in range(dilation_depth):
        # Dilation doubles every block: 1, 2, 4, ...
        # NOTE(review): the block name is not prefixed with `name`, so
        # different ConvLSTM layers may resolve to the same
        # "WaveNetBlock-%d" parameters if lib params are keyed by name --
        # confirm this sharing is intended.
        # The second return value of WaveNetConv1d is not used here.
        prev_conv, _ = lib.ops.WaveNetConv1d(
            "WaveNetBlock-%d" % (block_idx + 1),
            prev_conv,
            2,
            hidden_dim,
            input_dim,
            bias=True,
            batchnorm=False,
            dilation=2 ** block_idx)

    # Append H_{t-1} on axis 1 and project to the four gate pre-activations.
    gate_input = T.concatenate((prev_conv, last_hidden), axis=1)
    gates = lib.ops.conv1d(
        name + ".ConvGates", gate_input, 1, 1,
        4 * hidden_dim, 2 * input_dim, True, False)

    # Peephole weights: rows [0, 2h) scale C_{t-1} for the input/forget
    # gates; rows [2h, 3h) scale C_t for the output gate.
    W_cell = lib.param(
        name + '.CellWeights',
        lasagne.init.HeNormal().sample((3 * hidden_dim, seq_len, 1)))

    # Gate layout along axis 1 of `gates`:
    #   [0, h) input | [h, 2h) forget | [2h, 3h) candidate | [3h, 4h) output
    # C_{t-1} is duplicated so one sigmoid covers both input and forget gates.
    last_cell_stack = T.concatenate((last_cell, last_cell), axis=1)
    inp_forget = T.nnet.sigmoid(
        gates[:, :2 * hidden_dim] + W_cell[:2 * hidden_dim] * last_cell_stack)
    i_t = inp_forget[:, :hidden_dim]
    f_t = inp_forget[:, hidden_dim:]

    C_t = f_t * last_cell + i_t * T.tanh(gates[:, 2 * hidden_dim:3 * hidden_dim])
    # The output-gate peephole looks at the *new* cell state C_t.
    o_t = T.nnet.sigmoid(gates[:, 3 * hidden_dim:] + W_cell[2 * hidden_dim:] * C_t)
    H_t = o_t * T.tanh(C_t)
    return H_t, C_t
评论列表
文章目录