def validate(model_hdf5, net_type, generator_fn_str, dataset_file, featurized=True):
    """Evaluate saved move-prediction model(s) on the ``test.pgn`` split and print results.

    Args:
        model_hdf5: File name of a saved Keras model, resolved relative to
            ``saved/``. For ``net_type == "from_to"`` this must be a pair
            ``(from_model_file, to_model_file)``.
        net_type: Which evaluation to run:
            ``"from"``    - print the top-1 accuracy of the model against ``y_val[0]``.
            ``"to"``      - print the top-1 accuracy of the model against ``y_val[1]``;
                            the model also receives the true ``y_val[0]`` targets,
                            reshaped to ``(N, 1, H, W)``, as a second input.
            ``"from_to"`` - chain both models. For every sample, print ``YAY`` if
                            the predicted move is legal on that board. Otherwise
                            print ``BOO``, then the predicted move, then the
                            reference move built from the sample's targets.
        generator_fn_str: Passed through unchanged to ``Dataset.load``.
        dataset_file: Path prefix; ``'test.pgn'`` is appended to it.
        featurized: Passed through unchanged to ``Dataset.load``.

    Raises:
        ValueError: If ``net_type`` is not one of the values above, or (for
            ``"from"``/``"to"``) if the test set is empty.
    """
    # Fail fast, before the slow Keras import and dataset load.
    if net_type not in ("from", "to", "from_to"):
        raise ValueError("unknown net_type %r; expected 'from', 'to' or 'from_to'"
                         % (net_type,))

    from keras.models import load_model
    import data

    d_test = Dataset(dataset_file + 'test.pgn')
    X_val, y_val = d_test.load(generator_fn_str,
                               featurized=featurized,
                               refresh=False,
                               board="both")
    boards = data.board_from_state(X_val)
    # NOTE(review): X_val is assumed to be (N, C, H, W), since shape[2]/shape[3]
    # are used as the board height/width below. y_val[0] and y_val[1] are assumed
    # to be (N, H*W) flattened per-square targets. Confirm against Dataset.load.

    if net_type == "from":
        model_from = load_model("saved/" + model_hdf5)
        print(_top1_accuracy(y_val[0], model_from.predict(X_val)))
    elif net_type == "to":
        model_to = load_model("saved/" + model_hdf5)
        # The second model input is the true from-target laid out as one plane.
        true_from = y_val[0].reshape(y_val[0].shape[0], 1, X_val.shape[2], X_val.shape[3])
        print(_top1_accuracy(y_val[1], model_to.predict([X_val, true_from])))
    else:  # net_type == "from_to": chain the two models.
        model_from = load_model("saved/" + model_hdf5[0])
        model_to = load_model("saved/" + model_hdf5[1])
        n_samples = X_val.shape[0]
        height, width = X_val.shape[2], X_val.shape[3]

        y_hat_from = model_from.predict(X_val)
        from_squares = np.argmax(np.reshape(y_hat_from, (n_samples, -1)), axis=1)

        # One-hot encode each predicted from-square as a (1, H, W) plane.
        # Building the planes for the whole batch lets the second model run once,
        # instead of once per sample.
        from_planes = np.zeros((n_samples, height * width))
        from_planes[np.arange(n_samples), from_squares] = 1
        from_planes = from_planes.reshape(n_samples, 1, height, width)
        y_hat_to = model_to.predict([X_val, from_planes])
        to_squares = np.argmax(np.reshape(y_hat_to, (n_samples, -1)), axis=1)

        for i, board in enumerate(boards):
            move_attempt = data.move_from_action(from_squares[i], to_squares[i])
            if board.is_legal(move_attempt):
                print("YAY")
            else:
                # NOTE(review): the original indentation was lost. Printing the
                # predicted and reference moves only for illegal predictions is
                # the reconstructed nesting; confirm it is the intent.
                print("BOO")
                print(move_attempt)
                # The reference move comes from sample i's own targets, not from
                # an arg-max over the whole target array.
                move = data.move_from_action(np.argmax(y_val[0][i]), np.argmax(y_val[1][i]))
                print(move)


def _top1_accuracy(y_true, y_hat):
    """Return the fraction of samples whose arg-max prediction hits a positive target.

    Each row of ``y_hat`` is flattened before taking the arg-max. ``y_true`` is
    indexed as ``y_true[i, k]``, so it must be 2-D with one column per flattened
    position.

    Raises:
        ValueError: If ``y_true`` has no rows.
    """
    n_samples = len(y_true)
    if n_samples == 0:
        raise ValueError("cannot compute accuracy on an empty test set")
    predicted = np.argmax(np.reshape(y_hat, (n_samples, -1)), axis=1)
    hits = np.asarray(y_true)[np.arange(n_samples), predicted] > 0
    return float(np.count_nonzero(hits)) / n_samples
评论列表
文章目录