def gen_valid_move(move_index, label_map, army_map, dims, action_mask=None,
                   *, mountain=None):
    """Generate the top valid move given an output from the network.

    Scans the candidate moves in ``move_index`` in order and returns the
    first one that stays on the board, does not target a mountain tile and
    starts from a tile holding more than one army.

    Each candidate is a flat index into an ``(8, dims[0], dims[1])`` array,
    decoded as ``move_type, y1, x1``. ``move_type % 4`` selects the target
    tile (0: y + 1, 1: x + 1, 2: y - 1, 3: x - 1) and ``move_type >= 4``
    marks a half-army move.

    Args:
        move_index: Iterable of flat move indices. Presumably ranked
            best-first by the caller (e.g. an argsort of network scores) --
            TODO confirm against callers.
        label_map: 2-D array indexed ``[y, x]`` holding tile labels.
        army_map: 2-D array indexed ``[y, x]`` holding army counts.
        dims: ``(rows, cols)`` of the board.
        action_mask: Optional array indexed by flat move index. Scanning
            stops at the first candidate whose mask entry is 0. ``None``
            disables masking.
        mountain: Label value marking an impassable tile. Defaults to
            ``generals.MOUNTAIN``.

    Returns:
        ``(x1, y1, x2, y2, move_half)`` for the first valid candidate, or
        ``(0, 0, 0, 0, False)`` when no candidate is valid.
    """
    if mountain is None:
        mountain = generals.MOUNTAIN
    rows, cols = dims[0], dims[1]
    # (dy, dx) offsets indexed by move_type % 4.
    deltas = ((1, 0), (0, 1), (-1, 0), (0, -1))
    for move in move_index:
        if action_mask is not None and action_mask[move] == 0:
            # Kept from the original: the first masked candidate ends the
            # scan (presumably masked moves are ranked last -- confirm).
            break
        move_type, y1, x1 = np.unravel_index(move, (8, rows, cols))
        dy, dx = deltas[move_type % 4]
        y2, x2 = y1 + dy, x1 + dx
        if not (0 <= y2 < rows and 0 <= x2 < cols):
            continue
        if label_map[y2, x2] != mountain and army_map[y1, x1] > 1:
            return int(x1), int(y1), int(x2), int(y2), bool(move_type >= 4)
    # No valid candidate: return a fixed sentinel rather than whichever
    # (possibly off-board or illegal) candidate happened to be tried last.
    return 0, 0, 0, 0, False
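# 评论列表 ("Comment list") -- web-page scraping residue, not code.
# 文章目录 ("Table of contents") -- web-page scraping residue, not code.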