sequence_blocks.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:Neural-Chatbot 作者: saurabhmathur96 项目源码 文件源码
def AttentionDecoder(hidden_size, activation=None, return_sequences=True, bidirectional=False, use_gru=True):
    """Build a callable that applies an attention-wrapped recurrent decoder.

    The returned function ``_decoder(x, attention)`` wraps a GRU or LSTM layer
    in ``AttentionWrapper`` (with ``single_attention_param=True``), applies it
    to ``x`` and then applies ``activation`` to the result.

    Args:
        hidden_size: Number of units handed to the recurrent layer. In
            bidirectional mode each direction receives ``int(hidden_size / 2)``
            units.
        activation: Callable applied to the recurrent output. When ``None``,
            a single ``ELU()`` instance is created here and reused by every
            call of the returned decoder.
        return_sequences: Forwarded unchanged to the recurrent layer.
        bidirectional: When true, a forward (``go_backwards=False``) and a
            backward (``go_backwards=True``) branch are built and their
            outputs are joined with ``concatenate``.
        use_gru: Selects ``GRU`` when true, ``LSTM`` otherwise.

    Returns:
        A function ``_decoder(x, attention)`` returning the activated output.
    """
    if activation is None:
        activation = ELU()

    # Pick the recurrent layer class once so that the GRU and LSTM variants
    # share a single code path and cannot drift apart.
    rnn_cls = GRU if use_gru else LSTM

    def _attended(units, x, attention, **rnn_kwargs):
        # The recurrent layer itself stays linear; the non-linearity is the
        # separate `activation` applied after the (optional) concatenation.
        rnn = rnn_cls(units, activation='linear',
                      return_sequences=return_sequences, **rnn_kwargs)
        return AttentionWrapper(rnn, attention, single_attention_param=True)(x)

    def _decoder(x, attention):
        if bidirectional:
            # Both directions get half of hidden_size. The previous LSTM
            # variant gave the backward branch the full hidden_size, which
            # made its output wider than the GRU variant's.
            # NOTE(review): for odd hidden_size the two halves sum to
            # hidden_size - 1 units, assuming `concatenate` joins along the
            # feature axis -- confirm whether callers rely on an exact width.
            half_units = int(hidden_size / 2)
            forward = _attended(half_units, x, attention, go_backwards=False)
            backward = _attended(half_units, x, attention, go_backwards=True)
            x = concatenate([forward, backward])
        else:
            x = _attended(hidden_size, x, attention)
        return activation(x)

    return _decoder
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号