train.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:temporal-attention 作者: dosht 项目源码 文件源码
def main():
    print("\n\nLoading data...")
    data_dir = "/data/translate"
    vocab_size = 20000
    en, fr = prepare_date(data_dir, vocab_size)

    print("\n\nbuilding the model...")
    embedding_size = 64
    hidden_size = 32
    model = Sequential()
    model.add(Embedding(en.max_features, embedding_size, input_length=en.max_length, mask_zero=True))
    model.add(Bidirectional(GRU(hidden_size), merge_mode='sum'))
    model.add(RepeatVector(fr.max_length))
    model.add(GRU(embedding_size))
    model.add(Dense(fr.max_length, activation="softmax"))
    model.compile('rmsprop', 'mse')
    print(model.get_config())

    print("\n\nFitting the model...")
    model.fit(en.examples, fr.examples)

    print("\n\nEvaluation...")
    #TODO
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号