model_common.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:LSTMVAE 作者: ashwatthaman 项目源码 文件源码
def decode(self,t_vec,t_pred,wei_arr=None):
        ys_d = self.dec(t_vec)
        ys_w = self.h2w(F.concat(ys_d, axis=0))
        t_all = []
        for t_each in t_pred: t_all += t_each.tolist()
        t_all = xp.array(t_all, dtype=xp.int32)
        if wei_arr is None:
            loss = F.softmax_cross_entropy(ys_w, t_all)  # /len(t_all)
        else:
            sec_arr = np.array([ys_d_e.data.shape[0] for ys_d_e in ys_d[:-1]])
            sec_arr = np.cumsum(sec_arr)
            loss = weighted_cross_entropy(ys_w,t_all,wei_arr,sec_arr)
        # print("t:{}".format([self.vocab.itos(tp_e) for tp_e in t_pred[0].tolist()]))
        # print("y:{}\n".format([self.vocab.itos(int(ys_w.data[ri].argmax())) for ri in range(len(t_pred[0]))]))
        return loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号