def main():
    """Filter generated states with a trained discriminator and dump the survivors.

    Usage: ``<script> [directory]``

    Loads ``Discriminator("<directory>/_sd")`` and
    ``ConvolutionalGumbelAE2(<directory>)``, then reads
    ``<directory>/generated_states.csv`` as a uint8 matrix, one state per row.
    States are processed in batches of 500000 rows. For each batch:

    * the discriminator scores every row, and rows whose rounded score is > 0
      are kept;
    * up to 20 kept rows are decoded with the autoencoder and plotted to
      ``generated_states_filtered<i>.png``, where ``<i>`` is the batch index;
    * the kept rows are appended to the file at
      ``ae.local("generated_states2.csv")``.

    A KeyboardInterrupt stops the dump early. Rows from completed batches have
    already been written to the output file.
    """
    import sys

    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))
    directory = sys.argv[1]

    sd = Discriminator("{}/_sd".format(directory)).load()
    ae = ConvolutionalGumbelAE2(directory).load()

    input_name = "generated_states.csv"
    input_path = "{}/{}".format(directory, input_name)
    print("loading {}".format(input_path), end='...', flush=True)
    # ndmin=2 keeps a one-row file 2-D, so the row slicing and the `[rows, :]`
    # indexing below behave the same as for multi-row files.
    states = np.loadtxt(input_path, dtype=np.uint8, ndmin=2)
    print("done.")
    total = states.shape[0]

    batch = 500000
    output = "generated_states2.csv"
    try:
        output_path = ae.local(output)
        print(output_path)
        # Binary mode: np.savetxt encodes each row itself before writing.
        with open(output_path, 'wb') as f:
            print("original states:", total)
            # Step by `batch` so every slice is non-empty. The previous
            # range(total // batch + 1) produced an empty trailing batch
            # whenever total was a multiple of batch (and also for total == 0).
            for i, start in enumerate(range(0, total, batch)):
                _zs = states[start:start + batch]
                _result = sd.discriminate(_zs, batch_size=5000).round().astype(np.uint8)
                # Keep the rows the discriminator accepts (rounded score > 0).
                _zs_filtered = _zs[np.where(_result > 0)[0], :]
                print("reduced states:", len(_zs_filtered), "/", len(_zs))
                # Nothing to visualise if every state in the batch was rejected.
                if len(_zs_filtered) > 0:
                    _xs = ae.decode_binary(_zs_filtered[:20], batch_size=5000).round().astype(np.uint8)
                    ae.plot(_xs, path="generated_states_filtered{}.png".format(i))
                np.savetxt(f, _zs_filtered, "%d", delimiter=" ")
    except KeyboardInterrupt:
        print("dump stopped")
评论列表
文章目录