def bn_model_pytorch():
    """Same as bn_model but with PyTorch.

    Builds a parameter-free ``nn.Module`` whose logits are the mean of the
    input over axes 3 and 2, and wraps it in ``PyTorchModel`` with bounds
    ``(0, 1)``, 10 classes and ``cuda=False``.

    Returns:
        The ``PyTorchModel`` instance wrapping the network.
    """
    import torch
    import torch.nn as nn

    bounds = (0, 1)
    num_classes = 10

    class Net(nn.Module):
        """Produces logits by averaging the input over axes 3 and 2."""

        def forward(self, x):
            # NOTE(review): presumably x is (batch, channels, height, width)
            # with channels == num_classes -- confirm against the callers.
            assert isinstance(x.data, torch.FloatTensor)
            # Reduce axis 3 first, then axis 2 (same order as before).
            # ``keepdim=False`` already drops the reduced axis, so no
            # follow-up ``squeeze`` is needed; squeezing an axis that no
            # longer exists raises "dimension out of range" in PyTorch
            # versions where ``keepdim`` defaults to False.
            x = torch.mean(x, 3, keepdim=False)
            x = torch.mean(x, 2, keepdim=False)
            logits = x
            return logits

    model = Net()
    model = PyTorchModel(
        model,
        bounds=bounds,
        num_classes=num_classes,
        cuda=False)
    return model
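# (Web-page navigation residue from the scraped source, not part of the code:
#  "评论列表" = comment list, "文章目录" = table of contents.)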