15-keras_seq2seq_mod.py source file

python

Project: albemarle  Author: SeanTater
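# Method of LSTMDecoder. Assumes the enclosing module already has, for Keras 1.x:
#   from keras import backend as K
#   from keras.engine import InputSpec
def build(self, input_shape):
    input_shape = list(input_shape)
    # Insert the decoder's time axis (output_length) right after the batch axis.
    input_shape = input_shape[:1] + [self.output_length] + input_shape[1:]
    if not self.hidden_dim:
        # Default the hidden size to the feature size of the input.
        self.hidden_dim = input_shape[-1]
    output_dim = input_shape[-1]
    # Temporarily set output_dim to hidden_dim so the parent build()
    # sizes its weights by hidden_dim.
    self.output_dim = self.hidden_dim
    # Hide initial_weights so the parent build() does not load them
    # before W_y and b_y exist.
    initial_weights = self.initial_weights
    self.initial_weights = None
    super(LSTMDecoder, self).build(input_shape)
    self.output_dim = output_dim
    self.initial_weights = initial_weights
    # Projection from the hidden state to the output.
    self.W_y = self.init((self.hidden_dim, self.output_dim), name='{}_W_y'.format(self.name))
    # The shape must be a tuple: (n,), not (n).
    self.b_y = K.zeros((self.output_dim,), name='{}_b_y'.format(self.name))
    self.trainable_weights += [self.W_y, self.b_y]
    if self.initial_weights is not None:
        self.set_weights(self.initial_weights)
        del self.initial_weights
    # Drop the inserted time axis again: the layer's input spec is the original shape.
    input_shape.pop(1)
    self.input_spec = [InputSpec(shape=tuple(input_shape))]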
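
The shape bookkeeping in build() can be easier to follow outside Keras. The following is a minimal sketch using plain tuples only; the helper name decoder_shapes and its dictionary keys are hypothetical and not part of the original file. It assumes the incoming shape is (batch, features) and mirrors only the shape arithmetic, not the weight creation.

```python
def decoder_shapes(input_shape, output_length, hidden_dim=None):
    """Mirror the shape arithmetic of build() with plain tuples (no Keras needed)."""
    shape = list(input_shape)
    # Insert the decoder's time axis right after the batch axis.
    expanded = shape[:1] + [output_length] + shape[1:]
    # hidden_dim falls back to the feature size when unset, as in build().
    hidden = hidden_dim or expanded[-1]
    output_dim = expanded[-1]
    return {
        "lstm_build_shape": tuple(expanded),  # shape handed to the parent build()
        "W_y": (hidden, output_dim),          # hidden-to-output projection
        "b_y": (output_dim,),                 # output bias
        "input_spec": tuple(shape),           # time axis popped again
    }


# Hypothetical sizes chosen only for illustration.
shapes = decoder_shapes((None, 128), output_length=10, hidden_dim=256)
default_shapes = decoder_shapes((None, 128), output_length=10)
```

With these illustrative sizes the parent build() would see (None, 10, 128), while the recorded input spec stays at the original (None, 128).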