train.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tartarus 作者: sergiooramas 项目源码 文件源码
def build_model(config):
    """Build and compile the model described by ``config``.

    The network itself is created by the ``models.get_model_<architecture>``
    factory, where ``<architecture>`` is ``config.model_arch['architecture']``;
    the factory receives the whole ``config.model_arch`` mapping.

    Args:
        config: Object exposing two mappings:
            ``model_arch`` -- must contain ``'architecture'`` and
            ``'final_activation'``.
            ``training_params`` -- must contain ``'optimizer'`` (``'sgd'`` or
            ``'adam'``) and ``'loss_func'``. For ``'sgd'`` it must also
            contain ``'learning_rate'``, ``'decay'``, ``'momentum'`` and
            ``'nesterov'``.

    Returns:
        The model returned by the factory, after ``compile()`` has been
        called on it.

    Raises:
        AttributeError: If ``models`` has no factory for the architecture.
        ValueError: If ``training_params['optimizer']`` is not a supported
            optimizer name.
    """
    arch_params = config.model_arch
    get_model = getattr(models, 'get_model_' + str(arch_params['architecture']))
    model = get_model(arch_params)

    # Learning setup
    t_params = config.training_params

    # Select the optimizer with an explicit name check instead of eval(), so
    # that a config value can never be executed as arbitrary Python code.
    # Only the chosen optimizer is built, so an 'adam' config does not need
    # to provide the SGD-only keys.
    optimizer_name = t_params['optimizer']
    if optimizer_name == 'sgd':
        optimizer = SGD(lr=t_params["learning_rate"],
                        decay=t_params["decay"],
                        momentum=t_params["momentum"],
                        nesterov=t_params["nesterov"])
    elif optimizer_name == 'adam':
        # Adam hyperparameters are fixed here; they are not read from config.
        optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    else:
        raise ValueError(
            "Unsupported optimizer %r; expected 'sgd' or 'adam'"
            % (optimizer_name,))

    metrics = ['mean_squared_error']
    if arch_params["final_activation"] == 'softmax':
        metrics.append('categorical_accuracy')

    # 'cosine' refers to the object named ``cosine`` that must be in scope
    # for this function; any other loss value is handed to compile() as-is.
    loss_func = t_params['loss_func']
    if loss_func == 'cosine':
        loss_func = cosine

    model.compile(loss=loss_func, optimizer=optimizer, metrics=metrics)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号