def lyr_lstm(
        self, name_,
        s_x_, s_cell_, s_hid_,
        idim_, hdim_,
        axis_=-1,
        lyr_linear_=None,
        op_act_=T.tanh,
        op_gate_=T.nnet.sigmoid):
    """Build the symbolic expressions for one LSTM step.

    The input and the previous hidden state are joined along ``axis_`` and
    pushed through a single linear layer producing ``hdim_ * 4`` features.
    The first ``hdim_ * 3`` features become the input/forget/output gates
    (after ``op_gate_``); the last ``hdim_`` features are the candidate
    update (after ``op_act_``).

    Args:
        name_: Name prefix; the linear layer is created as ``name_ + '_rec'``.
        s_x_: Symbolic input for this step (presumably a Theano tensor whose
            size along ``axis_`` is ``idim_`` -- confirm against callers).
        s_cell_: Symbolic previous cell state.
        s_hid_: Symbolic previous hidden state (presumably of size ``hdim_``
            along ``axis_`` -- confirm against callers).
        idim_: Input size; added to ``hdim_`` for the linear layer's input
            size.
        hdim_: Hidden size; used for the split sizes and the linear layer's
            output size (``hdim_ * 4``).
        axis_: Axis used for both the join and the splits.
        lyr_linear_: Callable invoked as
            ``lyr_linear_(name, input, in_dim, out_dim)``. Falls back to
            ``self.lyr_linear`` when ``None``.
        op_act_: Activation applied to the candidate update and to the new
            cell state.
        op_gate_: Activation applied to the three gates.

    Returns:
        Tuple ``(new_cell_state, new_hidden_state)``.
    """
    linear = self.lyr_linear if lyr_linear_ is None else lyr_linear_

    # One shared projection of [x, h] feeds all three gates and the candidate.
    s_joined = T.join(axis_, s_x_, s_hid_)
    s_proj = linear(name_ + '_rec', s_joined, idim_ + hdim_, hdim_ * 4)
    s_gates_pre, s_cand_pre = T.split(
        s_proj, [hdim_ * 3, hdim_], 2, axis=axis_)

    # The gate nonlinearity is applied to the whole gate slab before it is
    # cut into the three individual gates.
    s_gates = op_gate_(s_gates_pre)
    s_in_gate, s_forget_gate, s_out_gate = T.split(
        s_gates, [hdim_] * 3, 3, axis=axis_)

    s_cell_next = s_in_gate * op_act_(s_cand_pre) + s_forget_gate * s_cell_
    s_hid_next = op_act_(s_cell_next) * s_out_gate
    return s_cell_next, s_hid_next
评论列表
文章目录