def prepare_oae_PU4(known_transisitons):
    """Build PU-discriminator training data from (pre-state, action label) rows.

    Prints a banner, then builds two row sets. Each row is the first half of a
    transition row concatenated with its rounded action code from
    ``oae.encode_action``:

    * positive -- built from every row of ``known_transisitons``;
    * negative -- built from the rows produced by ``generate_oae_action`` whose
      ``combined`` score, computed on the second half of the row, exceeds 0.5.

    Args:
        known_transisitons: 2-D array of transitions. Each row is split into two
            halves of width ``shape[1] // 2``; per the banner, the first half is
            used as the "pre" part. (The misspelled name is kept because callers
            may pass it by keyword.)

    Returns:
        ``(default_networks['PUDiscriminator'], *data)`` where ``data`` is the
        unpacked result of ``prepare_binary_classification_data(positive, negative)``.
    """
    print("Learn from pre + action label",
          "*** INCOMPATIBLE MODEL! ***",
          sep="\n")
    half = known_transisitons.shape[1] // 2

    def _flatten_rows(a):
        # Collapse every axis after the first into one column axis. Unlike
        # np.squeeze this never drops the batch axis, so a 1-row batch stays
        # 2-D; computing the width explicitly (instead of reshape(n, -1)) also
        # keeps a 0-row batch valid.
        a = np.asarray(a)
        return a.reshape(a.shape[0], int(np.prod(a.shape[1:])))

    def _pre_with_action(transitions):
        # First half of each row, followed by its rounded action encoding.
        actions = oae.encode_action(transitions, batch_size=1000).round()
        return np.concatenate((transitions[:, :half], _flatten_rows(actions)),
                              axis=1)

    candidates = generate_oae_action(known_transisitons)
    # Keep only candidate rows whose `combined` score on the second half is
    # above 0.5. Taking np.where(...)[0] on a 2-D array yields the row indices;
    # avoiding squeeze prevents a 0-d array (deprecated input to np.where /
    # np.nonzero) when there is a single candidate.
    scores = _flatten_rows(combined(candidates[:, half:]))
    candidates = candidates[np.where(scores > 0.5)[0]]

    positive = _pre_with_action(known_transisitons)
    negative = _pre_with_action(candidates)
    return (default_networks['PUDiscriminator'],
            *prepare_binary_classification_data(positive, negative))
评论列表
文章目录