action_discriminator.py source code

python

Project: latplan  Author: guicho271828
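import numpy as np

# NOTE: oae, combined, generate_oae_action, default_networks and
# prepare_binary_classification_data are module-level names that this
# excerpt does not define.

def prepare_oae_PU4(known_transisitons):
    print("Learn from pre + action label",
          "*** INCOMPATIBLE MODEL! ***",
          sep="\n")
    # Each row is a (pre, suc) pair concatenated; N is the width of one state.
    N = known_transisitons.shape[1] // 2

    # Candidate transitions generated from the known ones.
    y = generate_oae_action(known_transisitons)

    # Keep only candidates whose successor half scores above 0.5 under `combined`.
    ind = np.where(np.squeeze(combined(y[:,N:])) > 0.5)[0]

    y = y[ind]

    # Positive rows: pre-state plus the rounded action label of each known transition.
    actions = oae.encode_action(known_transisitons, batch_size=1000).round()
    positive = np.concatenate((known_transisitons[:,:N], np.squeeze(actions)), axis=1)
    # Negative rows: built the same way from the filtered candidates.
    actions = oae.encode_action(y, batch_size=1000).round()
    negative = np.concatenate((y[:,:N], np.squeeze(actions)), axis=1)
    # random.shuffle(negative)
    # negative = negative[:len(positive)]
    # normalize
    return (default_networks['PUDiscriminator'], *prepare_binary_classification_data(positive, negative))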
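The data assembly in the function above can be tried in isolation. The following is a minimal sketch on toy data, not the project's implementation: `build_pu_sets`, `toy_encode_action` and `toy_score` are hypothetical stand-ins for the excerpt's `oae.encode_action` and `combined`, and the action encoder is assumed to return an array of shape (batch, 1, num_actions).

```python
import numpy as np

def build_pu_sets(known, candidates, encode_action, score, threshold=0.5):
    """Hypothetical stand-in mirroring the array handling in prepare_oae_PU4."""
    n = known.shape[1] // 2  # width of one state; rows are (pre, suc) pairs
    # Keep candidates whose successor half scores above the threshold.
    keep = np.where(score(candidates[:, n:]).reshape(-1) > threshold)[0]
    candidates = candidates[keep]
    # Squeeze only axis 1 so a single-row batch keeps its batch dimension.
    pos_actions = np.squeeze(encode_action(known).round(), axis=1)
    neg_actions = np.squeeze(encode_action(candidates).round(), axis=1)
    positive = np.concatenate((known[:, :n], pos_actions), axis=1)
    negative = np.concatenate((candidates[:, :n], neg_actions), axis=1)
    return positive, negative

def toy_encode_action(transitions):
    # Assumed output shape: (batch, 1, 2), a soft two-way action label.
    p = np.where(transitions[:, 0] > 0.5, 0.9, 0.1)
    return np.stack((p, 1.0 - p), axis=1)[:, np.newaxis, :]

def toy_score(successors):
    # Assumed output shape: (batch, 1); here simply the mean of the bits.
    return successors.mean(axis=1, keepdims=True)

known = np.array([[0, 0, 1, 0, 1, 1],
                  [1, 0, 0, 1, 1, 0],
                  [0, 1, 0, 0, 1, 1],
                  [1, 1, 0, 1, 1, 1]], dtype=float)
candidates = np.array([[0, 0, 1, 1, 1, 1],
                       [1, 0, 0, 0, 0, 0],
                       [0, 1, 0, 1, 1, 0],
                       [1, 1, 0, 0, 0, 1],
                       [0, 0, 0, 1, 0, 1]], dtype=float)

positive, negative = build_pu_sets(known, candidates, toy_encode_action, toy_score)
print(positive.shape, negative.shape)
```

The sketch squeezes only axis 1, whereas the excerpt calls `np.squeeze` without an axis. With a bare `np.squeeze`, a batch that shrinks to one row would also lose its batch dimension and the subsequent `np.concatenate(..., axis=1)` would fail.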