state_discriminator.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def prepare(data_valid):
    print(data_valid.shape)
    batch = data_valid.shape[0]
    N = data_valid.shape[1]
    data_invalid = np.random.randint(0,2,(batch,N),dtype=np.int8)
    print(data_valid.shape,data_invalid.shape)
    ai = data_invalid.view([('', data_invalid.dtype)] * N)
    av = data_valid.view  ([('', data_valid.dtype)]   * N)
    data_invalid = np.setdiff1d(ai, av).view(data_valid.dtype).reshape((-1, N))

    return prepare_binary_classification_data(data_valid, data_invalid)

# default values
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号