util.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:chess-deep-rl 作者: rajpurkar 项目源码 文件源码
def _top1_accuracy(y_true, y_hat):
    """Return the fraction of samples whose argmax prediction hits a positive target.

    For each sample, ``np.argmax`` over the (flattened) prediction picks an
    index. The sample counts as correct when the matching target row has a
    value > 0 at that index.

    Args:
        y_true: Per-sample target rows, indexable as ``y_true[i][index]``.
        y_hat: Per-sample model predictions, same number of samples as ``y_true``.

    Returns:
        Accuracy as a float, or NaN when there are no samples (instead of
        dividing by zero).
    """
    num_samples = len(y_hat)
    if num_samples == 0:
        return float("nan")
    num_correct = sum(
        1 for target, pred in zip(y_true, y_hat) if target[np.argmax(pred)] > 0
    )
    # float() keeps this a true division under Python 2 as well as Python 3.
    return float(num_correct) / num_samples


def validate(model_hdf5, net_type, generator_fn_str, dataset_file, featurized=True):
    """Evaluate saved Keras move-prediction model(s) on a dataset's test split.

    Loads ``dataset_file + 'test.pgn'`` through ``Dataset.load``. What happens
    next depends on ``net_type``:

    * ``"from"``: loads ``saved/<model_hdf5>`` and predicts on the features.
      Prints the fraction of samples whose argmax prediction lands on a
      positive entry of ``y_val[0]``.
    * ``"to"``: loads ``saved/<model_hdf5>`` and predicts from the features
      plus the ``y_val[0]`` targets reshaped to one plane per sample. Prints
      the analogous fraction against ``y_val[1]``.
    * ``"from_to"``: ``model_hdf5`` is a pair ``(from_file, to_file)``. For each
      sample, the argmax of the "from" model is one-hot encoded and fed to the
      "to" model. It prints "YAY" when the resulting move is legal on the
      reconstructed board and "BOO" otherwise. It then prints the attempted
      move and the move built from that sample's targets.

    Args:
        model_hdf5: File name under ``saved/`` ("from"/"to"), or a two-item
            sequence of file names ("from_to").
        net_type: One of ``"from"``, ``"to"``, ``"from_to"``.
        generator_fn_str: Passed through to ``Dataset.load``.
        dataset_file: Path prefix; ``'test.pgn'`` is appended to it.
        featurized: Passed through to ``Dataset.load``.

    Raises:
        ValueError: If ``net_type`` is not one of the supported values.
    """
    # Fail fast, before the (potentially slow) imports and dataset load.
    if net_type not in ("from", "to", "from_to"):
        raise ValueError("unknown net_type: %r" % (net_type,))

    from keras.models import load_model
    import data

    d_test = Dataset(dataset_file + 'test.pgn')
    X_val, y_val = d_test.load(generator_fn_str,
        featurized = featurized,
        refresh    = False,
        board      = "both")
    # NOTE(review): X_val.shape[2] / shape[3] are used as the board plane size,
    # which suggests a (N, C, H, W) layout, and y_val[0] / y_val[1] look like
    # flattened from/to targets of shape (N, H*W) -- confirm against Dataset.load.

    if net_type == "from":
        model_from = load_model("saved/" + model_hdf5)
        y_hat_from = model_from.predict(X_val)
        print(_top1_accuracy(y_val[0], y_hat_from))

    elif net_type == "to":
        model_to = load_model("saved/" + model_hdf5)
        # The "to" model takes the "from" targets as an extra single-plane input.
        from_planes = y_val[0].reshape(y_val[0].shape[0], 1, X_val.shape[2], X_val.shape[3])
        y_hat_to = model_to.predict([X_val, from_planes])
        print(_top1_accuracy(y_val[1], y_hat_to))

    else:  # net_type == "from_to"
        model_from = load_model("saved/" + model_hdf5[0])
        model_to = load_model("saved/" + model_hdf5[1])
        boards = data.board_from_state(X_val)
        y_hat_from = model_from.predict(X_val)
        plane_shape = (1, 1, X_val.shape[2], X_val.shape[3])

        for i, board in enumerate(boards):
            from_square = np.argmax(y_hat_from[i])
            # One-hot encode the predicted source square (flat index) as the
            # "to" model's second input.
            y_max_from = np.zeros(plane_shape)
            y_max_from.flat[from_square] = 1

            y_hat_to = model_to.predict([np.expand_dims(X_val[i], 0), y_max_from])
            to_square = np.argmax(y_hat_to)
            move_attempt = data.move_from_action(from_square, to_square)
            print("YAY" if board.is_legal(move_attempt) else "BOO")
            print(move_attempt)
            # Reference move from THIS sample's targets; the argmax must be taken
            # per sample, not over the whole target array.
            move = data.move_from_action(np.argmax(y_val[0][i]), np.argmax(y_val[1][i]))
            print(move)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号