15-keras_seq2seq_mod.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:albemarle 作者: SeanTater 项目源码 文件源码
def __init__(self, output_dim, hidden_dim, output_length, depth=1, broadcast_state=True, inner_broadcast_state=True, peek=False, dropout=0.1, **kwargs):
        """Build a stacked LSTM encoder/decoder (seq2seq) model layer by layer.

        Layers are appended in this order: the encoder LSTMs (with Dropout
        between them), Dropout, the decoder, any extra decoder-side LSTMs
        (with Dropout between them), Dropout, and finally a
        ``TimeDistributed(Dense(output_dim))`` projection.

        Args:
            output_dim: Number of units of the final TimeDistributed Dense layer.
            hidden_dim: Size passed to every encoder/decoder recurrent layer
                (as ``output_dim`` for LSTMEncoder, ``hidden_dim`` for the decoder).
            output_length: Forwarded to the decoder as ``output_length``
                (presumably the number of generated timesteps -- confirm
                against LSTMDecoder).
            depth: An int (same depth on both sides) or a list/tuple whose
                first two items are the encoder depth and the decoder depth,
                i.e. the number of recurrent layers on each side.
            broadcast_state: If true, the decoder is built with
                ``state_input=True`` and the last encoder layer calls
                ``broadcast_state(decoder)`` (presumably handing its final
                state to the decoder -- confirm in LSTMEncoder).
            inner_broadcast_state: If true, every stacked recurrent layer after
                the first one on each side is built with ``state_input=True``,
                and each layer calls ``broadcast_state`` on the next one.
            peek: If true, ``LSTMDecoder2`` is used instead of ``LSTMDecoder``.
            dropout: Rate of every Dropout layer inserted between layers.
            **kwargs: The first of ``batch_input_shape``, ``input_shape`` or
                ``input_dim`` present (checked in that order) is removed and
                turned into the first layer's ``batch_input_shape``. All other
                kwargs are forwarded to every recurrent layer and to the final
                Dense layer.

        Raises:
            ValueError: If none of ``batch_input_shape``, ``input_shape`` or
                ``input_dim`` is given. Previously this surfaced as an opaque
                UnboundLocalError on ``shape``.
        """
        if not isinstance(depth, (list, tuple)):
            depth = (depth, depth)
        encoder_depth, decoder_depth = depth[0], depth[1]

        # Resolve the input shape before touching the model. Only the first
        # encoder layer receives it, so it must be known up front.
        if 'batch_input_shape' in kwargs:
            shape = kwargs.pop('batch_input_shape')
        elif 'input_shape' in kwargs:
            # input_shape omits the batch axis; prepend it as unknown.
            shape = (None,) + tuple(kwargs.pop('input_shape'))
        elif 'input_dim' in kwargs:
            # Only the feature size is known; batch and time axes stay unknown.
            shape = (None, None, kwargs.pop('input_dim'))
        else:
            raise ValueError(
                'Seq2seq requires one of batch_input_shape, input_shape or '
                'input_dim to be passed as a keyword argument.')

        super(Seq2seq, self).__init__()

        # Encoder stack: every layer except the last returns full sequences so
        # the next recurrent layer can consume them.
        encoder_lstms = []
        layer = LSTMEncoder(batch_input_shape=shape, output_dim=hidden_dim,
                            state_input=False,
                            return_sequences=encoder_depth > 1, **kwargs)
        self.add(layer)
        encoder_lstms.append(layer)
        for i in range(encoder_depth - 1):
            self.add(Dropout(dropout))
            layer = LSTMEncoder(output_dim=hidden_dim,
                                state_input=inner_broadcast_state,
                                return_sequences=i < encoder_depth - 2,
                                **kwargs)
            self.add(layer)
            encoder_lstms.append(layer)
        if inner_broadcast_state:
            # Chain each encoder layer's state into the layer stacked above it.
            for src, dst in zip(encoder_lstms, encoder_lstms[1:]):
                src.broadcast_state(dst)
        # The most recently added layer, i.e. the top of the encoder stack.
        encoder = encoder_lstms[-1]

        # Decoder stack: the (optionally peeking) decoder followed by extra
        # sequence-returning LSTMs.
        self.add(Dropout(dropout))
        decoder_type = LSTMDecoder2 if peek else LSTMDecoder
        decoder = decoder_type(hidden_dim=hidden_dim,
                               output_length=output_length,
                               state_input=broadcast_state, **kwargs)
        self.add(decoder)
        decoder_lstms = [decoder]
        for _ in range(decoder_depth - 1):
            self.add(Dropout(dropout))
            layer = LSTMEncoder(output_dim=hidden_dim,
                                state_input=inner_broadcast_state,
                                return_sequences=True, **kwargs)
            self.add(layer)
            decoder_lstms.append(layer)
        if inner_broadcast_state:
            for src, dst in zip(decoder_lstms, decoder_lstms[1:]):
                src.broadcast_state(dst)
        if broadcast_state:
            encoder.broadcast_state(decoder)

        self.add(Dropout(dropout))
        # NOTE(review): the same kwargs given to the recurrent layers are also
        # forwarded to Dense; recurrent-only kwargs would likely be rejected
        # here -- confirm intended usage.
        self.add(TimeDistributed(Dense(output_dim, **kwargs)))
        self.encoder = encoder
        self.decoder = decoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号