predict.py source code

python

Project: speed  Author: keon
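import datetime

import torch as th

# Dataset, init_models and predict come from other modules of the speed
# project; their import paths are not shown in this excerpt.


def main(arg):
    resize = (200, 66)  # target frame size used when loading images

    # initialize the test dataset
    dataset = Dataset(arg.test_folder,
                      resize=resize,
                      batch_size=1,
                      timesteps=arg.timesteps,
                      windowsteps=1,
                      shift=0,
                      train=False)
    print('[!] testing dataset samples: %d' % len(dataset.data))

    # initialize models, on the GPU if one is available; restore=True
    # presumably reloads saved weights, and lr=0 since nothing is trained here
    cuda = th.cuda.is_available()
    models = init_models(arg.model, n=3, lr=0, restore=True, cuda=cuda)

    # run prediction and report the elapsed wall-clock time
    t0 = datetime.datetime.now()
    try:
        predict(models, dataset, arg, cuda=cuda)
    except KeyboardInterrupt:
        print('[!] KeyboardInterrupt: Stopped Prediction...')
    t1 = datetime.datetime.now()

    print('[!] Finished Prediction, Time Taken: %s' % (t1 - t0))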
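main() reads three attributes from `arg`: `test_folder`, `timesteps` and `model`, but the excerpt does not show how `arg` is built. The sketch below shows one way such an object could be produced with the standard-library `argparse` module, plus the timing and Ctrl-C handling pattern from main() factored into a small helper. The flag names, help texts, the `timesteps` default of 2, and the helpers `build_parser` and `run_timed` are all hypothetical and not taken from the speed project.

```python
import argparse
import datetime


def build_parser():
    # Hypothetical flags mirroring the attributes main() reads from `arg`.
    parser = argparse.ArgumentParser(description='Run prediction on a test folder.')
    parser.add_argument('--test_folder', required=True,
                        help='directory holding the test data')
    parser.add_argument('--timesteps', type=int, default=2,  # assumed default
                        help='number of consecutive frames per sample')
    parser.add_argument('--model', required=True,
                        help='model identifier passed to init_models()')
    return parser


def run_timed(task, *args, **kwargs):
    # Same pattern as main(): time the call and let Ctrl-C stop it cleanly.
    t0 = datetime.datetime.now()
    try:
        task(*args, **kwargs)
    except KeyboardInterrupt:
        print('[!] KeyboardInterrupt: Stopped Prediction...')
    elapsed = datetime.datetime.now() - t0
    print('[!] Finished Prediction, Time Taken: %s' % elapsed)
    return elapsed


if __name__ == '__main__':
    # A fixed argument list is parsed here only for illustration.
    arg = build_parser().parse_args(['--test_folder', 'data/test', '--model', 'demo'])
    print(arg.test_folder, arg.timesteps, arg.model)
    run_timed(sum, range(1000))  # stand-in workload instead of predict()
```

In a real script, `parse_args()` would be called with no arguments so it reads `sys.argv`, and the resulting namespace would be passed to `main()`.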