def test1():
    """Build a small two-input model around an RTTN layer and train it.

    The model takes two inputs with a fixed batch shape, feeds both to a
    single ``RTTN`` layer (called with ``return_sequences=True``), and
    exposes the layer's two results as the model outputs. It is compiled
    with an MSE loss and the SGD optimizer, then fitted on batches drawn
    from ``generate_data_batch``.

    Nothing is returned and nothing is asserted; the function only checks
    that graph construction, compilation and training run end to end.
    """
    seq_size = 10
    batch_size = 10
    rnn_size = 1

    # Both inputs pin the batch dimension via ``batch_shape``; ``xin``
    # carries one extra trailing axis of size 1 compared with ``xtop``.
    xin = Input(batch_shape=(batch_size, seq_size, 1))
    xtop = Input(batch_shape=(batch_size, seq_size))

    # The layer call on the pair of inputs unpacks into two tensors.
    xbranch, xsummary = RTTN(rnn_size, return_sequences=True)([xin, xtop])

    # NOTE(review): ``input=``/``output=`` here and ``samples_per_epoch``/
    # ``nb_epoch`` below look like Keras 1.x keyword names -- confirm the
    # installed Keras version before renaming them.
    model = Model(input=[xin, xtop], output=[xbranch, xsummary])
    model.compile(loss='MSE', optimizer='SGD')

    # The generator is built with the same batch and sequence sizes that
    # were used for the input shapes above.
    data_gen = generate_data_batch(batch_size, seq_size)
    model.fit_generator(generator=data_gen, samples_per_epoch=1000, nb_epoch=100)
# [web-page residue, not code] Comment list
# [web-page residue, not code] Article table of contents