def __init__(self, Nj, gpu, model_file, filename):
    """Set up the model and the dataset used for estimation.

    Args:
        Nj: Forwarded to ``AlexNet`` as its only constructor argument
            (presumably the number of joints -- confirm against AlexNet).
        gpu: GPU selector. Any value >= 0 enables CUDA; a negative value
            keeps the model on the CPU.
        model_file: Saved state dict, read with ``torch.load``.
        filename: Forwarded to ``PoseDataset`` as its first argument.

    Raises:
        GPUNotFoundError: If a GPU was requested but CUDA is unavailable.
    """
    # validate arguments.
    self.gpu = gpu >= 0
    if self.gpu and not torch.cuda.is_available():
        raise GPUNotFoundError('GPU is not found.')
    # initialize model to estimate.
    self.model = AlexNet(Nj)
    # Deserialize onto the CPU so that a checkpoint saved from a CUDA device
    # still loads on a CPU-only machine. load_state_dict copies the values
    # into the model's own parameters, so the result is the same either way.
    state_dict = torch.load(model_file, map_location='cpu')
    self.model.load_state_dict(state_dict)
    # prepare gpu.
    if self.gpu:
        # NOTE: cuda() is called without a device argument, so the numeric
        # value of `gpu` is not used as a device index here.
        self.model.cuda()
    # load dataset to estimate.
    self.dataset = PoseDataset(
        filename,
        input_transform=transforms.Compose([
            transforms.ToTensor(),
            RandomNoise()]),
        output_transform=Scale(),
        transform=Crop())
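# (web-page residue, not part of the program) Comment list
# (web-page residue, not part of the program) Table of contents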