test_keras.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:wtte-rnn 作者: ragulpr 项目源码 文件源码
def model_masking(discrete_time, init_alpha, max_beta):
    model = Sequential()

    model.add(Masking(mask_value=mask_value,
                      input_shape=(n_timesteps, n_features)))
    model.add(TimeDistributed(Dense(2)))
    model.add(Lambda(wtte.output_lambda, arguments={"init_alpha": init_alpha,
                                                    "max_beta_value": max_beta}))

    if discrete_time:
        loss = wtte.loss(kind='discrete', reduce_loss=False).loss_function
    else:
        loss = wtte.loss(kind='continuous', reduce_loss=False).loss_function

    model.compile(loss=loss, optimizer=RMSprop(
        lr=lr), sample_weight_mode='temporal')
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号