ram.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ram 作者: amasky 项目源码 文件源码
def __init__(
        self, g_size=8, n_steps=6, n_scales=1, var=0.03, use_lstm=False
    ):
        """Create and register every child link of the RAM chain.

        Fixed widths: glimpse-embedding width 128, core hidden width 256.

        Args:
            g_size: Side length of a glimpse patch. ``emb_x`` takes
                ``g_size * g_size * n_scales`` input features.
            n_steps: Stored as ``self.n_steps``; unused in this constructor.
                Presumably the number of glimpses per episode (TODO confirm).
            n_scales: Number of glimpse scales. It multiplies ``emb_x``'s
                input width and is stored as ``self.n_scales``.
            var: Stored as ``self.var``; unused in this constructor.
                Presumably the variance of the location policy (TODO confirm).
            use_lstm: If true, the core is a single ``L.LSTM`` registered as
                ``core_lstm``. Otherwise it is two ``L.Linear`` layers
                registered as ``core_hh`` and ``core_gh``.
        """
        glimpse_dim = 128  # width of the location / patch embeddings
        core_dim = 256     # width of the core hidden state

        # NOTE(review): passing links to Chain.__init__ as kwargs and calling
        # add_link is the Chainer v1-style registration API; newer Chainer
        # prefers init_scope(). Confirm the targeted Chainer version.
        # Keep the construction order stable: weight initializers presumably
        # draw from a shared RNG, so reordering could change the initial
        # parameters for a fixed seed.
        super(RAM, self).__init__(
            # presumably embeds a 2-D glimpse location (TODO confirm)
            emb_l=L.Linear(2, glimpse_dim),
            # embeds the flattened multi-scale glimpse patch
            emb_x=L.Linear(g_size * g_size * n_scales, glimpse_dim),
            fc_lg=L.Linear(glimpse_dim, core_dim),
            fc_xg=L.Linear(glimpse_dim, core_dim),
            # heads on the core state; presumably 10-way class scores,
            # a 2-D location output and a scalar baseline (TODO confirm)
            fc_ha=L.Linear(core_dim, 10),
            fc_hl=L.Linear(core_dim, 2),
            fc_hb=L.Linear(core_dim, 1),
        )

        if not use_lstm:
            # Plain recurrent core: two square Linear layers, registered in
            # the order core_hh, then core_gh.
            for link_name in ('core_hh', 'core_gh'):
                self.add_link(
                    name=link_name, link=L.Linear(core_dim, core_dim))
        else:
            self.add_link(name='core_lstm', link=L.LSTM(core_dim, core_dim))

        # Configuration kept for use by the rest of the class.
        self.use_lstm = use_lstm
        self.d_core = core_dim
        self.g_size, self.n_steps, self.n_scales = g_size, n_steps, n_scales
        self.var = var
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号