train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Particle-Picking-Cryo-EM 作者: hqythu 项目源码 文件源码
def main():
    parser = argparse.ArgumentParser(description='Train a neural network')

    parser.add_argument('--model', type=str)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--decay', type=float, default=1e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--batch', type=int, default=128)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--output', type=str, default='weight')
    args = parser.parse_args()

    model = importlib.import_module(args.model).build()

    six.print_('loading data')
    (train_x, train_y, val_x, val_y) = load_data()
    six.print_('load data complete')

    sgd = SGD(lr=args.lr,
              decay=args.decay,
              momentum=args.momentum,
              nesterov=True)
    model.compile(loss='binary_crossentropy', optimizer=sgd,
                  metrics=['accuracy'])
    six.print_('build model complete')

    six.print_('start training')
    model.fit(train_x, train_y, batch_size=args.batch, nb_epoch=args.epoch,
              verbose=2,
              shuffle=True,
              validation_data=(val_x, val_y))
    model.save_weights(args.output + '.hdf5')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号