def param_init_lstm(config, params, prefix='lstm'):
    """
    Init the LSTM parameters and attach them to the existing ``params``.

    Adds three entries to ``params``, keyed by ``_p(prefix, name)``:

    * ``W``: four ``ortho_weight(config.dim_proj)`` blocks concatenated
      along axis 1 (input weights).
    * ``U``: four ``ortho_weight(config.dim_proj)`` blocks concatenated
      along axis 1 (state weights).
    * ``b``: a zero vector of length ``4 * config.dim_proj``.

    All three are cast to ``T_config.floatX``.

    :param config: object exposing ``dim_proj``.
    :param params: dict to which the new parameters are added (mutated
        in place).
    :param prefix: name prefix passed to ``_p`` for each parameter key.
    :returns: the same ``params`` dict.
    :see: init_params
    """
    dim = config.dim_proj

    # Each LSTM cell has 4 weight matrices for the input and 4 for the state.
    # Call ortho_weight once per block: ``[ortho_weight(dim)] * 4`` would
    # repeat a reference to one single matrix, making all four blocks
    # identical (assuming ortho_weight draws a fresh random matrix per call).
    W = np.concatenate([ortho_weight(dim) for _ in range(4)], axis=1)
    params[_p(prefix, 'W')] = W.astype(T_config.floatX)

    U = np.concatenate([ortho_weight(dim) for _ in range(4)], axis=1)
    params[_p(prefix, 'U')] = U.astype(T_config.floatX)

    b = np.zeros((4 * dim,))
    params[_p(prefix, 'b')] = b.astype(T_config.floatX)
    return params
评论列表
文章目录