def encode(self, xs):
    """Encode a batch of sequences into VAE posterior parameters.

    Token id ``2`` is appended to every sequence. The batch is then embedded
    twice via ``self.makeEmbedBatch``: once plainly, and once with the extra
    flag ``True``. The two embeddings are fed to ``self.enc_f`` and
    ``self.enc_b`` respectively, both of which are reset beforehand. For each
    element of ``zip(enc_f.hx, enc_f.cx, enc_b.hx, enc_b.cx)``, the four
    states are concatenated and projected by ``self.le2_mu`` and
    ``self.le2_ln_var``.

    Args:
        xs: Batch of sequences. Each element must support ``x + [2]``;
            presumably a list of int token ids (confirm against callers).

    Returns:
        Tuple ``(mu_arr, var_arr)`` of lists with one entry per zipped
        ``(hx, cx)`` element. ``mu_arr[i]`` is ``self.le2_mu`` and
        ``var_arr[i]`` is ``self.le2_ln_var`` applied to the same
        concatenated state.
    """
    # Append token id 2 to every sequence. ``x + [2]`` builds a new list, so
    # the caller's sequences are not mutated.
    # NOTE(review): the original comment here was mis-encoded; 2 is presumably
    # the end-of-sequence id of the vocabulary — confirm.
    xs = [x + [2] for x in xs]
    xs_f = self.makeEmbedBatch(xs)
    # NOTE(review): the True flag presumably yields reversed-order embeddings
    # for the backward encoder — confirm against makeEmbedBatch.
    xs_b = self.makeEmbedBatch(xs, True)

    self.enc_f.reset_state()
    self.enc_b.reset_state()
    # The encoder outputs are not used. The calls are kept because they
    # presumably update the hx/cx state that is read below.
    self.enc_f(xs_f)
    self.enc_b(xs_b)

    # VAE: concatenate the forward/backward states once per element and share
    # the result between the mean head and the log-variance head.
    states = [
        F.concat((hx_f, cx_f, hx_b, cx_b))
        for hx_f, cx_f, hx_b, cx_b in zip(
            self.enc_f.hx, self.enc_f.cx, self.enc_b.hx, self.enc_b.cx
        )
    ]
    mu_arr = [self.le2_mu(state) for state in states]
    var_arr = [self.le2_ln_var(state) for state in states]
    return mu_arr, var_arr
评论列表
文章目录