# book_train.py (source file)

# Language: Python
# Views: 23 | Favorites: 0 | Likes: 0 | Comments: 0

# Project: deep-murasaki | Author: lazydroid | (project source, file source)
def train():
    """Train the model from ``make_model`` on the full dataset and save its weights.

    Steps:
      1. Load the packed inputs ``X`` and targets ``m`` via ``get_data``.
      2. Shuffle both with one shared permutation so rows stay paired.
      3. Unpack each row's bits into a boolean ``(28, 8, 8)`` array.
      4. Compile the model with Nesterov SGD and a mean-squared-error loss.
      5. Fit for 10 epochs, holding out 5% of the samples for validation.
      6. Save the weights to a timestamped copy of the model's file name.
      7. Print target vs. prediction (scaled by 100) for the last 20 samples.

    Runs under both Python 2 and Python 3. Progress messages use
    ``sys.stdout.write`` + ``flush`` rather than the Python-2-only
    ``print x,`` form. The flush also makes each message appear before its
    (potentially slow) step runs.
    """
    import sys  # local import: only used for unbuffered progress messages

    X, m = get_data(['x', 'm'])

    # Shuffle inputs and targets with the same permutation.
    start = time.time()
    sys.stdout.write('shuffling... ')
    sys.stdout.flush()
    # list(): random.shuffle needs a mutable sequence, and range() is not one on Python 3.
    idx = list(range(len(X)))
    random.shuffle(idx)
    X, m = X[idx], m[idx]
    print('%.2f sec' % (time.time() - start))

    # Unpack the bit-packed rows. Each row must expand to 28 * 8 * 8 = 1792 bits
    # for the reshape to succeed (presumably 28 feature planes over an 8x8 board).
    start = time.time()
    sys.stdout.write('unpacking... ')
    sys.stdout.flush()
    # Builtin ``bool`` rather than ``np.bool``: that alias was deprecated in
    # NumPy 1.20 and removed in 1.24. The resulting dtype is identical.
    X = np.array([np.unpackbits(x).reshape(28, 8, 8).astype(bool) for x in X])
    print('%.2f sec' % (time.time() - start))

    model, name = make_model()

    print('compiling...')
    # Learning-rate notes from earlier runs:
    #   - 1e-4 diverged (NaN loss);
    #   - 1e-5 gave a loss of ~137 after epoch 1;
    #   - 5e-5 gave ~121 after epoch 1, but a later note (written "5e5",
    #     presumably 5e-5) flags it as too high on 2017-09-06.
    # Hence 3e-5.
    sgd = SGD(lr=3e-5, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='mean_squared_error', optimizer=sgd)

    # Keras takes validation_split from the *last* samples of the (already
    # shuffled) arrays. To enable early stopping, pass
    # callbacks=[EarlyStopping(monitor='loss', patience=50)].
    model.fit(X, m, nb_epoch=10, batch_size=BATCH_SIZE, validation_split=0.05)

    suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
    model.save_weights(name.replace('.model', '_%s.model' % suffix), overwrite=True)

    # Spot-check the last 20 samples. These presumably fall inside the 5%
    # validation tail when there are at least 400 samples. Both arrays are
    # flattened so every pair is a plain scalar rather than a 1-element array.
    targets = np.ravel(m[-20:]) * 100.0
    predictions = np.ravel(model.predict(X[-20:], batch_size=5)) * 100.0
    for a, b in zip(targets, predictions):
        a, b = float(a), float(b)
        scale = max(abs(a), abs(b))
        # Relative difference in percent. It is defined as 0 when both values
        # are 0, which avoids a 0/0 division.
        rel_pct = abs(a - b) * 100.0 / scale if scale else 0.0
        print('%.4f %.4f %.2f%%' % (a, b, rel_pct))
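# Comment list
# Table of contents


# Questions


# Interview experiences


# Articles

# WeChat
# Official account

# Scan the QR code to follow the official account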