tf_models.py source code

python

Project: TemporalConvolutionalNetworks    Author: colincsl
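from keras.layers import Input, LSTM, Dense, TimeDistributed, Bidirectional
from keras.models import Model


def BidirLSTM(n_nodes, n_classes, n_feat, max_len=None,
              causal=True, loss='categorical_crossentropy', optimizer="adam",
              return_param_str=False):
    # Frame-wise classifier: one softmax prediction per timestep.
    # max_len is kept for interface compatibility; sequences have variable length.
    inputs = Input(shape=(None, n_feat))

    if causal:
        # Forward-only LSTM: the output at frame t depends only on frames <= t
        model = LSTM(n_nodes, return_sequences=True)(inputs)
    else:
        # Bidirectional LSTM: the wrapper re-aligns the backward outputs in time
        # before concatenating them with the forward outputs
        model = Bidirectional(LSTM(n_nodes, return_sequences=True),
                              merge_mode="concat")(inputs)

    # Per-frame class probabilities
    model = TimeDistributed(Dense(n_classes, activation="softmax"))(model)

    model = Model(inputs=inputs, outputs=model)
    # sample_weight_mode="temporal" accepts a (batch, timesteps) weight array,
    # e.g. to zero out the loss on padded frames
    model.compile(optimizer=optimizer, loss=loss, sample_weight_mode="temporal",
                  metrics=['accuracy'])

    if return_param_str:
        param_str = "LSTM_N{}".format(n_nodes)
        if causal:
            param_str += "_causal"
        return model, param_str
    else:
        return model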
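The non-causal branch pairs a forward LSTM with a backward one. Keras documents `go_backwards=True` as processing the input backwards and returning the reversed sequence, so the backward outputs must be flipped back in time before they are concatenated frame by frame (the `Bidirectional` wrapper does this internally). The minimal sketch below illustrates the alignment issue, assuming a toy running-sum recurrence as a stand-in for an LSTM; `running_sum`, `forward_pass`, `backward_pass_go_backwards` and `concat_per_step` are hypothetical helpers, not Keras functions.

```python
def running_sum(seq):
    """Toy stand-in for a recurrent layer (hypothetical): out[t] = sum of inputs seen so far."""
    out, total = [], 0
    for x in seq:
        total += x
        out.append(total)
    return out


def forward_pass(seq):
    # Outputs in natural time order: out[t] summarizes seq[0..t]
    return running_sum(seq)


def backward_pass_go_backwards(seq):
    # Mimics go_backwards=True: the sequence is processed from the end and the
    # outputs come back in reversed time order: out[i] summarizes seq[T-1-i..T-1]
    return running_sum(seq[::-1])


def concat_per_step(a, b):
    # Pair the two streams index by index, like a "concat" merge along features
    return [(x, y) for x, y in zip(a, b)]


seq = [1, 2, 3, 4]
fwd = forward_pass(seq)
bwd = backward_pass_go_backwards(seq)

# Naive merge: frame 0 is paired with a backward summary of only the LAST frame
naive = concat_per_step(fwd, bwd)
# Aligned merge: reverse the backward outputs so index t refers to frame t
aligned = concat_per_step(fwd, bwd[::-1])

print(naive)    # [(1, 4), (3, 7), (6, 9), (10, 10)]
print(aligned)  # [(1, 10), (3, 9), (6, 7), (10, 4)]
```

In the aligned pairs, the backward value at frame t summarizes frames t to the end, which is what a bidirectional frame classifier needs.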