rtn_test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:ikelos 作者: braingineer 项目源码 文件源码
def test1():
    seq_size = 10
    batch_size = 10 
    rnn_size = 1
    xin = Input(batch_shape=(batch_size, seq_size,1))
    xtop = Input(batch_shape=(batch_size, seq_size))
    xbranch, xsummary = RTTN(rnn_size, return_sequences=True)([xin, xtop])

    model = Model(input=[xin, xtop], output=[xbranch, xsummary])
    model.compile(loss='MSE', optimizer='SGD')
    data_gen = generate_data_batch(batch_size, seq_size)
    model.fit_generator(generator=data_gen, samples_per_epoch=1000, nb_epoch=100)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号