layers.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:BiMPM_keras 作者: ijinmao 项目源码 文件源码
def __init__(self, rnn_dim, rnn_unit='gru', input_shape=(0,),
                 dropout=0.0, highway=False, return_sequences=False,
                 dense_dim=0):
        """Build the bidirectional recurrent stack stored in ``self.model``.

        Args:
            rnn_dim: Number of units passed to the wrapped recurrent layer.
            rnn_unit: ``'gru'`` selects GRU; any other value selects LSTM.
            input_shape: Input shape given to the Bidirectional wrapper,
                which is the first layer of the Sequential model.
            dropout: Rate used for the recurrent layer's ``dropout`` and
                ``recurrent_dropout`` and for the Dropout layer that
                follows the optional dense layer.
            highway: When true, append a tanh Highway layer after the RNN.
            return_sequences: Forwarded to the recurrent layer; also
                decides whether the following layers are applied per
                time step.
            dense_dim: When greater than 0, append
                Dense(relu) -> Dropout -> BatchNormalization with this
                many dense units.
        """
        rnn = GRU if rnn_unit == 'gru' else LSTM

        def per_step(layer):
            # TimeDistributed needs a time axis, which only exists when the
            # recurrent layer returns its full sequence; otherwise the layer
            # must be applied directly to the 2-D output.
            return TimeDistributed(layer) if return_sequences else layer

        self.model = Sequential()
        self.model.add(
            Bidirectional(rnn(rnn_dim,
                              dropout=dropout,
                              recurrent_dropout=dropout,
                              return_sequences=return_sequences),
                          input_shape=input_shape))

        if highway:
            self.model.add(per_step(Highway(activation='tanh')))

        if dense_dim > 0:
            # Previously these were always wrapped in TimeDistributed, which
            # fails when return_sequences is False.
            self.model.add(per_step(Dense(dense_dim, activation='relu')))
            self.model.add(per_step(Dropout(dropout)))
            self.model.add(per_step(BatchNormalization()))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号