model_common.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:LSTMVAE 作者: ashwatthaman 项目源码 文件源码
def encode(self,xs):
        xs = [x + [2] for x in xs]  # 1?<s>????dec??<s>??????
        xs_f = self.makeEmbedBatch(xs)
        xs_b = self.makeEmbedBatch(xs, True)

        self.enc_f.reset_state()
        self.enc_b.reset_state()
        ys_f = self.enc_f(xs_f)
        ys_b = self.enc_b(xs_b)

        # VAE
        mu_arr = [self.le2_mu(F.concat((hx_f, cx_f, hx_b, cx_b))) for hx_f, cx_f, hx_b, cx_b in
                  zip(self.enc_f.hx, self.enc_f.cx, self.enc_b.hx, self.enc_b.cx)]
        var_arr = [self.le2_ln_var(F.concat((hx_f, cx_f, hx_b, cx_b))) for hx_f, cx_f, hx_b, cx_b in
                   zip(self.enc_f.hx, self.enc_f.cx, self.enc_b.hx, self.enc_b.cx)]
        return mu_arr,var_arr
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号