classifier.py source code

python

Project: SentEval  Author: facebookresearch
def prepare_split(self, X, y, validation_data=None, validation_split=None):
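        # Prepare the train/dev split: use validation_data when it is given,
        # otherwise hold out a random validation_split fraction of (X, y).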
        assert validation_split or validation_data
        if validation_data is not None:
            trainX, trainy = X, y
            devX, devy = validation_data
        else:
            permutation = np.random.permutation(len(X))
            trainidx = permutation[int(validation_split*len(X)):]
            devidx = permutation[0:int(validation_split*len(X))]
            trainX, trainy = X[trainidx], y[trainidx]
            devX, devy = X[devidx], y[devidx]

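        # When cudaEfficient is False, move the full tensors to the GPU up front;
        # otherwise leave them on the CPU.
        if not self.cudaEfficient: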
            trainX = torch.FloatTensor(trainX).cuda()
            trainy = torch.LongTensor(trainy).cuda()
            devX = torch.FloatTensor(devX).cuda()
            devy = torch.LongTensor(devy).cuda()
        else:
            trainX = torch.FloatTensor(trainX)
            trainy = torch.LongTensor(trainy)
            devX = torch.FloatTensor(devX)
            devy = torch.LongTensor(devy)

        return trainX, trainy, devX, devy
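
The random hold-out branch of prepare_split can be tried in isolation. The following is a minimal sketch, not SentEval code: the helper name split_indices and its seed parameter are hypothetical, and it uses the standard library's random module in place of np.random.permutation so that it runs without NumPy or PyTorch. The slicing mirrors the method above: the first int(validation_split * n) shuffled indices become the dev set and the rest become the training set.

```python
import random


def split_indices(n, validation_split, seed=0):
    """Return (train_idx, dev_idx) for a random hold-out split of n items.

    Hypothetical helper for illustration; not part of SentEval.
    """
    if not 0.0 < validation_split < 1.0:
        raise ValueError("validation_split must be strictly between 0 and 1")
    # Shuffle the indices 0..n-1 (stand-in for np.random.permutation(n)).
    permutation = list(range(n))
    random.Random(seed).shuffle(permutation)
    n_dev = int(validation_split * n)
    # Same slicing as prepare_split: the first n_dev shuffled indices form
    # the dev set, the remaining ones form the training set.
    return permutation[n_dev:], permutation[:n_dev]


train_idx, dev_idx = split_indices(10, 0.2)
print(len(train_idx), len(dev_idx))  # 8 2
```

With NumPy arrays, the resulting index lists could be used as X[train_idx] and X[dev_idx], as in the method above. Fixing the seed makes the split reproducible, which the original method leaves to the caller's global NumPy random state.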