def get_model(model_name, n_classes):
    """Initialize a classification model by backbone name.

    The selected backbone is instantiated with ``n_classes`` and wrapped in
    ``L.Classifier`` using ``F.softmax_cross_entropy`` as the loss function.
    (``L``/``F`` are presumably ``chainer.links``/``chainer.functions`` —
    imported elsewhere in the file; confirm.)

    Args:
        model_name: Backbone identifier; either ``"SimpleCNN"`` or
            ``"MiddleCNN"``.
        n_classes: Number of output classes, forwarded to the backbone
            constructor as the ``n_classes`` keyword.

    Returns:
        The ``L.Classifier`` instance wrapping the chosen backbone.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "SimpleCNN":
        backbone_cls = SimpleCNN
    elif model_name == "MiddleCNN":
        backbone_cls = MiddleCNN
    else:
        raise ValueError('Unknown model name: {}'.format(model_name))

    # Both backbones share the same wrapper and loss, so the wrapper is
    # built in one place rather than once per branch.
    return L.Classifier(backbone_cls(n_classes=n_classes),
                        lossfun=F.softmax_cross_entropy)
评论列表
文章目录