def main():
    """Dump discriminator-approved state pairs built from generated states.

    Usage: ``<script> [directory]`` (exits with a usage message when no
    directory is given).

    Steps:
      1. Load the ``Discriminator`` stored under ``<directory>_ad/``.
      2. Compute a distance threshold with ``maxdiff`` over
         ``<directory>/actions.csv``.
      3. For every row ``s`` of ``<directory>/generated_states.csv``, select
         the generated states whose L1 distance from ``s`` is below that
         threshold, build one candidate row per neighbour as the
         concatenation ``[neighbor, s]``, score the candidates with the
         discriminator, and append the rows scoring above 0.8 to
         ``<directory>/generated_actions.csv`` (integer format).

    Pressing Ctrl-C stops the dump early; rows already written stay in the
    output file, which is closed by the ``with`` block.
    """
    import sys

    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))
    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)

    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")

    # The recorded actions are only used to calibrate the neighbourhood radius.
    recorded_actions = load("{}/actions.csv".format(directory))
    threshold = maxdiff(recorded_actions)
    print("maxdiff:", threshold)

    states = load("{}/generated_states.csv".format(directory))
    path = "{}/generated_actions.csv".format(directory)
    total = states.shape[0]
    N = states.shape[1]
    acc = 0  # running count of accepted rows written so far
    try:
        print(path)
        with open(path, 'wb') as f:
            for i, s in enumerate(states):
                # NOTE(review): i * total is printed as "base"; confirm this is
                # the intended offset (kept unchanged here).
                print("Iteration {}/{} base: {}".format(i, total, i * total), end=' ')
                # L1 distance from s to every generated state (s itself scores 0).
                diff = np.sum(np.abs(states - s), axis=1)
                # NOTE(review): strict '<' drops states exactly at the maxdiff
                # value; confirm against maxdiff's definition.
                neighbors = states[diff < threshold]
                # One candidate per neighbour: first N columns = neighbour,
                # last N columns = s (broadcast to every row).
                candidates = np.concatenate(
                    [neighbors, np.broadcast_to(s, (len(neighbors), N))], axis=1)
                ys = ad.discriminate(candidates, batch_size=400000)
                # assumes ys holds one score per candidate row — TODO confirm shape
                accepted = candidates[np.where(ys > 0.8)]
                acc += len(accepted)
                print(len(neighbors), len(accepted), acc)
                np.savetxt(f, accepted, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
评论列表
文章目录