recurrent_ref.py source code

python

Project: ngraph   Author: NervanaSystems
```python
import numpy as np


# Constructor of a reference encoder-decoder RNN class in recurrent_ref.py;
# the enclosing class statement is not shown in this excerpt.
def __init__(self, in_size, hidden_size, encoder_activation='tanh',
             decoder_activation='tanh', decoder_return_sequence=True):
    # Only tanh and identity nonlinearities are accepted for either RNN.
    assert encoder_activation in ('tanh', 'identity'), "invalid encoder_activation"
    self.encoder_activation = encoder_activation
    assert decoder_activation in ('tanh', 'identity'), "invalid decoder_activation"
    self.decoder_activation = decoder_activation

    self.hidden_size = hidden_size
    self.in_size = in_size
    # encoder parameters (zero-initialized)
    self.Wxh_enc = np.zeros((hidden_size, in_size))      # input to hidden
    self.Whh_enc = np.zeros((hidden_size, hidden_size))  # hidden to hidden
    self.bh_enc = np.zeros((hidden_size, 1))             # hidden bias
    # decoder parameters (zero-initialized)
    self.Wxh_dec = np.zeros((hidden_size, in_size))      # input to hidden
    self.Whh_dec = np.zeros((hidden_size, hidden_size))  # hidden to hidden
    self.bh_dec = np.zeros((hidden_size, 1))             # hidden bias
    # True: the decoder returns outputs for every time step, not only the last
    self.decoder_return_sequence = decoder_return_sequence
```
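The constructor only allocates parameters. Below is a minimal sketch, not the ngraph implementation, of how an encoder-decoder reference RNN could consume them. It assumes column-vector inputs of shape (in_size, 1), a zero initial encoder state, a decoder that starts from the final encoder state, and no output projection (the constructor defines none). `make_params`, `_activate` and `seq2seq_forward` are hypothetical names; `make_params` only mirrors the attributes set above because the enclosing class is not shown.

```python
import numpy as np
from types import SimpleNamespace


def make_params(in_size, hidden_size, encoder_activation='tanh',
                decoder_activation='tanh', decoder_return_sequence=True):
    # Hypothetical stand-in for the class above: same attribute names and shapes.
    return SimpleNamespace(
        in_size=in_size, hidden_size=hidden_size,
        encoder_activation=encoder_activation,
        decoder_activation=decoder_activation,
        decoder_return_sequence=decoder_return_sequence,
        Wxh_enc=np.zeros((hidden_size, in_size)),
        Whh_enc=np.zeros((hidden_size, hidden_size)),
        bh_enc=np.zeros((hidden_size, 1)),
        Wxh_dec=np.zeros((hidden_size, in_size)),
        Whh_dec=np.zeros((hidden_size, hidden_size)),
        bh_dec=np.zeros((hidden_size, 1)),
    )


def _activate(name, z):
    # Only the two activations the constructor accepts.
    if name == 'tanh':
        return np.tanh(z)
    if name == 'identity':
        return z
    raise ValueError("unsupported activation: %r" % name)


def seq2seq_forward(p, enc_inputs, dec_inputs):
    """Hypothetical forward pass over lists of (in_size, 1) column vectors."""
    h = np.zeros((p.hidden_size, 1))  # assumed zero initial encoder state
    for x in enc_inputs:  # encoder: h_t = f(Wxh x_t + Whh h_{t-1} + b)
        h = _activate(p.encoder_activation,
                      p.Wxh_enc @ x + p.Whh_enc @ h + p.bh_enc)
    outputs = []
    for x in dec_inputs:  # decoder continues from the final encoder state
        h = _activate(p.decoder_activation,
                      p.Wxh_dec @ x + p.Whh_dec @ h + p.bh_dec)
        outputs.append(h)
    if not outputs:
        raise ValueError("need at least one decoder step")
    # All decoder hidden states, or only the last one.
    return outputs if p.decoder_return_sequence else outputs[-1]
```

As a quick sanity check, setting `Wxh_enc`, `Whh_enc` and `Whh_dec` to identity matrices with identity activations makes the encoder sum its inputs and the decoder carry that sum unchanged through every step.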