def __init__(self, model, data_generator, epochs, loss):
    """Store the training setup and build the loss selected by ``loss``.

    Args:
        model: The model to train; stored unchanged on ``self.model``.
        data_generator: Source of training data; stored unchanged on
            ``self.data_generator``.
        epochs: Number of epochs; stored unchanged on ``self.epochs``.
        loss: Name of the loss to use, stored on ``self.loss``. One of
            ``"smoothl1"`` (``nn.SmoothL1Loss``), ``"l1"`` (``nn.L1Loss``)
            or ``"l2"`` (``nn.MSELoss``). The matching loss module is
            instantiated with default arguments and stored on
            ``self.loss_fn``.

    Raises:
        ValueError: If ``loss`` is not one of the recognized names.
    """
    self.epochs = epochs
    self.model = model
    self.data_generator = data_generator
    self.loss = loss

    # Every option maps to an nn.Module class, so ``self.loss_fn`` has the
    # same type/API regardless of the choice. nn.SmoothL1Loss() with default
    # arguments computes the same value as F.smooth_l1_loss with defaults.
    loss_classes = {
        "smoothl1": nn.SmoothL1Loss,
        "l1": nn.L1Loss,
        "l2": nn.MSELoss,
    }
    try:
        loss_cls = loss_classes[loss]
    except (KeyError, TypeError):
        # TypeError covers unhashable values (e.g. a list), which must be
        # rejected with the same ValueError as unknown names.
        raise ValueError("Unrecognized loss type: {}".format(loss)) from None
    self.loss_fn = loss_cls()
评论列表
文章目录