def get(activation):
    """Resolve ``activation`` to an activation-function instance.

    Parameters
    ----------
    activation : str or Activation
        Either the name of a built-in activation or an existing
        ``Activation`` instance. Several spellings are accepted per
        activation (e.g. ``'relu'``, ``'ReLU'``, ``'RELU'``); matching is
        exact, not case-insensitive.

    Returns
    -------
    Activation
        A freshly constructed instance when a name is given, or a deep copy
        of the given instance (so the result never aliases the argument).

    Raises
    ------
    ValueError
        If ``activation`` is a string that names no known activation, or is
        neither a string nor an ``Activation`` instance.
    """
    import copy

    # isinstance() rather than comparing __class__.__name__ to 'str':
    # the name check accepted any unrelated class called "str" and rejected
    # genuine str subclasses.
    if isinstance(activation, str):
        # Accepted spellings for each activation class. The alias groups
        # are disjoint, so lookup order does not affect the result. Built
        # at call time so the classes only need to exist when get() runs.
        aliases = (
            (('sigmoid', 'Sigmoid'), Sigmoid),
            (('tan', 'tanh', 'Tanh'), Tanh),
            (('relu', 'ReLU', 'RELU'), ReLU),
            (('linear', 'Linear'), Linear),
            (('softmax', 'Softmax'), Softmax),
            (('elliot', 'Elliot'), Elliot),
            (('symmetric_elliot', 'SymmetricElliot'), SymmetricElliot),
            (('SoftPlus', 'soft_plus', 'softplus'), SoftPlus),
            (('SoftSign', 'softsign', 'soft_sign'), SoftSign),
        )
        for names, cls in aliases:
            if activation in names:
                return cls()
        raise ValueError('Unknown activation name: {}.'.format(activation))

    if isinstance(activation, Activation):
        return copy.deepcopy(activation)

    raise ValueError("Unknown type: {}.".format(activation.__class__.__name__))
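# NOTE: the original text here ("评论列表" / "文章目录", i.e. "comment list" /
# "table of contents") was web-page navigation residue, not code.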