params.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:AMBR 作者: Algomorph 项目源码 文件源码
def __init__(self, hidden_unit_count=None, archive=None):
    """Initialize the LSTM parameters randomly or restore them from an archive.

    :param hidden_unit_count: number of hidden units. Required when
        ``archive`` is not given; ignored when ``archive`` is given.
    :param archive: saved values handed to ``self.load_values_from_dict``.
        When provided, no random initialization is performed.
    :raises ValueError: if neither ``archive`` nor ``hidden_unit_count``
        is supplied.
    """
    if archive is not None:
        self.load_values_from_dict(archive)
        return
    if hidden_unit_count is None:
        raise ValueError(
            "If archive is not passed in, an " + Parameters.LSTM.__name__ +
            " object needs hidden_unit_count argument to be an integer.")

    # Number of stacked blocks per matrix; the bias below uses the same
    # multiplier (presumably one block per LSTM gate).
    gate_count = 4

    def _stacked_orthogonal_blocks():
        # Side-by-side (axis=1) concatenation of independently generated
        # orthogonal blocks. Cast to floatX so the weights share the bias dtype;
        # this is a no-op if the generator already returns floatX.
        blocks = [generate_random_orthogonal_vectors(hidden_unit_count)
                  for _ in range(gate_count)]
        return np.concatenate(blocks, axis=1).astype(config.floatX)

    # Generation order (input weights, then hidden weights) matches the
    # original, so seeded runs draw the same random values per matrix.
    self.input_weights = theano.shared(_stacked_orthogonal_blocks(),
                                       self.input_weights_literal)  # formerly lstm_W
    self.hidden_weights = theano.shared(_stacked_orthogonal_blocks(),
                                        self.hidden_weights_literal)  # formerly lstm_U
    self.bias = theano.shared(np.zeros((gate_count * hidden_unit_count,)).astype(config.floatX),
                              self.bias_literal)  # formerly lstm_b
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号