def prepare_oae_PU3(known_transisitons):
    """Build a PU-discriminator dataset from known and OAE-generated transitions.

    The known transitions form one class. The other set is produced by
    ``generate_oae_action`` and then filtered: only rows whose second half
    (the successor part) is scored above 0.5 by ``combined`` are kept. The log
    message below calls ``combined`` "the learned state discriminator".

    Args:
        known_transisitons: 2-D array of known transitions. Each row is split
            into two halves of width ``shape[1] // 2``; the second half is
            treated as the successor state. NOTE(review): assumes each row is
            ``[pre_state, successor_state]``; confirm against callers.

    Returns:
        A tuple whose first element is ``default_networks['PUDiscriminator']``,
        followed by every element returned by
        ``prepare_binary_classification_data(known_transisitons, y)``.
    """
    print("discriminate the correct transitions and the other transitions generated by OAE,",
          " filtered by the learned state discriminator",
          sep="\n")
    # Width of one half of a transition row; y[:, N:] is the successor half.
    N = known_transisitons.shape[1] // 2
    y = generate_oae_action(known_transisitons)

    print("removing invalid successors (sd3)")
    # np.squeeze turns a single-row score array of shape (1, 1) into a 0-d
    # array, and np.where/np.nonzero on 0-d input is deprecated (NumPy 1.17)
    # and rejected by newer NumPy. atleast_1d restores shape (1,) in that case
    # and leaves every other shape untouched.
    scores = np.atleast_1d(np.squeeze(combined(y[:, N:])))
    ind = np.where(scores > 0.5)[0]
    y = y[ind]

    # For larger datasets, cap the generated set at the number of known
    # transitions; smaller datasets keep every surviving generated row.
    if len(known_transisitons) > 100:
        y = y[:len(known_transisitons)]  # undersample

    print("valid:", len(known_transisitons), "mixed:", len(y),)
    print("creating binary classification labels")
    return (default_networks['PUDiscriminator'],
            *prepare_binary_classification_data(known_transisitons, y))
################################################################
# training parameters
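# NOTE(review): two lines of web-page residue stood here ("评论列表" = "comment
# list", "文章目录" = "table of contents"); they were not code and would have
# raised NameError at import, so they are preserved only as this comment.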