action_discriminator.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def prepare_oae_PU3(known_transisitons):
    """Build a discriminator dataset from known transitions and filtered OAE-generated ones.

    The known transitions form one class. The other ("mixed") set comes from
    ``generate_oae_action(known_transisitons)``. It is then filtered by the learned
    state discriminator ``combined``, which scores the second half of each generated
    row; only rows scoring above 0.5 are kept.

    Parameters
    ----------
    known_transisitons : array-like, 2-D
        Known transitions. The code splits the columns at ``shape[1] // 2`` and
        treats the second half as the successor state. Presumably each row is
        [pre-state | successor] of equal width -- TODO confirm against callers.

    Returns
    -------
    tuple
        ``(default_networks['PUDiscriminator'],
        *prepare_binary_classification_data(known_transisitons, y))``, where ``y``
        is the filtered (and possibly undersampled) generated set.
    """
    print("discriminate the correct transitions and the other transitions generated by OAE,",
          " filtered by the learned state discriminator",
          sep="\n")
    # Column index where the successor half of each row begins.
    N = known_transisitons.shape[1] // 2
    y = generate_oae_action(known_transisitons)

    print("removing invalid successors (sd3)")
    # Score only the successor half with the state discriminator.
    # np.ravel (rather than np.squeeze) keeps the scores 1-D even when y has a
    # single row: squeeze would collapse a (1, 1) prediction to a 0-d array, and
    # np.where/nonzero on 0-d input is deprecated. The boolean mask selects the
    # same rows as np.where(...)[0] did.
    scores = np.ravel(combined(y[:, N:]))
    y = y[scores > 0.5]

    # Undersample the generated set to at most as many rows as there are known
    # transitions. This is only applied when more than 100 known transitions exist.
    if len(known_transisitons) > 100:
        y = y[:len(known_transisitons)]

    print("valid:", len(known_transisitons), "mixed:", len(y))
    print("creating binary classification labels")
    return (default_networks['PUDiscriminator'], *prepare_binary_classification_data(known_transisitons, y))

################################################################
# training parameters
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号