train_lstm.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:taxi 作者: xuguanggen 项目源码 文件源码
def build_lstm(n_con,n_emb,vocabs_size,n_dis,emb_size,cluster_size):
    hidden_size = 800

    con = Sequential()
    con.add(Dense(input_dim=n_con,output_dim=emb_size))

    emb_list = []
    for i in range(n_emb):
        emb = Sequential()
        emb.add(Embedding(input_dim=vocabs_size[i],output_dim=emb_size,input_length=n_dis))
        emb.add(Flatten())
        emb_list.append(emb)


    in_dimension = 2
    seq = Sequential()
    seq.add(BatchNormalization(input_shape=((MAX_LENGTH,in_dimension))))
    seq.add(Masking([0]*in_dimension,input_shape=(MAX_LENGTH,in_dimension)))
    seq.add(LSTM(emb_size,return_sequences=False,input_shape=(MAX_LENGTH,in_dimension)))

    model = Sequential()
    model.add(Merge([con]+emb_list+[seq],mode='concat'))
    model.add(BatchNormalization())
    model.add(Dense(hidden_size,activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(cluster_size,activation='softmax'))
    model.add(Lambda(caluate_point,output_shape=[2]))
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号