recurrent_highway_networks.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:recurrentshop 作者: farizrahman4u 项目源码 文件源码
def RHN(input_dim, hidden_dim, depth):
    """Build a one-timestep Recurrent Highway Network wrapped in a RecurrentModel.

    The wrapped step model takes one input vector plus two recurrent states:
    the hidden state and a dropout mask. The mask is inverted (``1 - mask``)
    and multiplied into the previous hidden state before that state is fed,
    together with the input, into an ``RHNCell``.

    Args:
        input_dim: Size of the per-timestep input vector.
        hidden_dim: Number of RHN units; also the width of the state and mask.
        depth: Recurrence depth passed to ``RHNCell``.

    Returns:
        A ``RecurrentModel`` whose initial states are ``[state, drop_mask]``
        and whose final states are ``[new_state, copy of drop_mask]``, so the
        same mask is re-used at every timestep.

    Relies on module-level names defined elsewhere in this file:
    ``batch_size``, ``weight_init``, ``weight_decay``, ``gradient_clip`` and
    ``transform_bias``.
    """
    # Wrapped (single-timestep) model
    inp = Input(batch_shape=(batch_size, input_dim))
    state = Input(batch_shape=(batch_size, hidden_dim))
    drop_mask = Input(batch_shape=(batch_size, hidden_dim))
    # The mask is applied inverted so that an all-zero mask state (e.g. a
    # zero-filled initial state, presumably the default -- verify) keeps the
    # hidden state intact instead of zeroing it and making gradients vanish.
    inverted_drop_mask = Lambda(lambda x: 1.0 - x, output_shape=lambda s: s)(drop_mask)
    # Carry the ORIGINAL (un-inverted) mask forward so every timestep applies
    # the same mask. Feeding back `inverted_drop_mask` instead would flip the
    # mask each step (1-m, m, 1-m, ...), and with a zero initial mask would
    # zero the hidden state on every other step. The `+ 0.` presumably exists
    # only to produce a tensor distinct from the `drop_mask` input -- verify.
    drop_mask_2 = Lambda(lambda x: x + 0., output_shape=lambda s: s)(drop_mask)
    dropped_state = multiply([state, inverted_drop_mask])
    # NOTE(review): `max_norm` constrains the weight norm; despite its name,
    # `gradient_clip` is used here as a weight max-norm, not a gradient clip.
    y, new_state = RHNCell(units=hidden_dim, recurrence_depth=depth,
                           kernel_initializer=weight_init,
                           kernel_regularizer=l2(weight_decay),
                           kernel_constraint=max_norm(gradient_clip),
                           bias_initializer=Constant(transform_bias),
                           recurrent_initializer=weight_init,
                           recurrent_regularizer=l2(weight_decay),
                           recurrent_constraint=max_norm(gradient_clip))([inp, dropped_state])
    return RecurrentModel(input=inp, output=y,
                          initial_states=[state, drop_mask],
                          final_states=[new_state, drop_mask_2])


# Learning-rate decay scheduler
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号