def get_batch_size(cnn_type):
if cnn_type == 'VGG19' or cnn_type == 'VGG19_KERAS':
return 18
if cnn_type == 'VGG16_DROPOUT':
return 15
if cnn_type == 'VGG16' or cnn_type == 'VGG16_KERAS':
return 20
if cnn_type == 'RESNET50':
return 20
if cnn_type == 'INCEPTION_V3':
return 22
if cnn_type == 'SQUEEZE_NET':
return 40
if cnn_type == 'DENSENET_161':
return 8
if cnn_type == 'DENSENET_121':
return 25
return -1
评论列表
文章目录