test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:sudoku 作者: Kyubyong 项目源码 文件源码
def test():
    """Iteratively solve the test puzzles with the latest checkpoint and save them.

    Loads the test split, restores the most recent checkpoint found in
    ``hp.logdir`` and then fills the puzzles in one cell per puzzle per pass:
    on every pass the model is run on the current partial solution, and for
    each puzzle the single cell with the highest masked probability is
    committed.  The loop ends once no zero entries remain in any puzzle.
    The inputs, the targets and the final predictions are handed to
    ``write_to_file`` with the path ``results/<model name>.txt``.

    Takes no arguments and returns nothing.
    """
    import copy

    x, y = load_data(type="test")

    g = Graph(is_training=False)
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            # Get model name: the first double-quoted string in the
            # `checkpoint` index file. Use a context manager so the handle
            # is closed deterministically.
            with open(hp.logdir + '/checkpoint', 'r') as f:
                mname = f.read().split('"')[1]

            # This must live inside the session block: everything below
            # needs `sess` and `mname`.
            if not os.path.exists('results'):
                os.mkdir('results')
            fout = 'results/{}.txt'.format(mname)

            # Work on a copy so the original puzzles in `x` stay intact for
            # the blank mask (x == 0) and for the final report.
            _preds = copy.copy(x)
            # NOTE(review): the loop only stops when every cell is non-zero;
            # it presumably relies on the model always committing a non-zero
            # digit to a blank cell -- confirm against the Graph definition.
            while 1:
                istarget, probs, preds = sess.run(
                    [g.istarget, g.probs, g.preds], {g.x: _preds, g.y: y})
                probs = probs.astype(np.float32)
                preds = preds.astype(np.float32)

                # Zero out every position that is not flagged by `istarget`
                # (presumably the still-blank cells -- TODO confirm).
                probs *= istarget  # (N, 9, 9)
                preds *= istarget  # (N, 9, 9)

                probs = np.reshape(probs, (-1, 9*9))  # (N, 9*9)
                preds = np.reshape(preds, (-1, 9*9))  # (N, 9*9)
                _preds = np.reshape(_preds, (-1, 9*9))

                # For each puzzle, pick the flattened cell index with the
                # highest masked probability.
                maxprob_ids = np.argmax(probs, axis=1)  # (N,)
                maxprobs = np.max(probs, axis=1, keepdims=False)  # (N,)

                # Commit that one cell per puzzle, skipping puzzles whose
                # best masked probability is exactly zero.
                rows = np.nonzero(maxprobs != 0)[0]
                cols = maxprob_ids[rows]
                _preds[rows, cols] = preds[rows, cols]

                _preds = np.reshape(_preds, (-1, 9, 9))
                # Fill in the non-blanks with correct numbers
                _preds = np.where(x == 0, _preds, y)

                # Done once no zero entries are left anywhere.
                if np.count_nonzero(_preds) == _preds.size:
                    break

            write_to_file(x.astype(np.int32), y, _preds.astype(np.int32), fout)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号