rass.py source code

python

Project: monogreedy  Author: jinjunqi
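import theano.tensor as T  # assumed: T is the usual Theano tensor alias


def compute(self, state, w_idx, feat, scene):
    """One decoding step: embed the previous word, attend over `feat`,
    advance the LSTM, and predict the next-word distribution.

    Returns the packed new state, the prediction p, and the attention weights alpha.
    """
    # embed the previous word w_{t-1}
    word_vec = self.embedding.compute(w_idx)
    # unpack the previous state into attention context, LSTM cell, LSTM hidden
    # (split_state is a helper defined elsewhere in the project)
    e_tm1, c_tm1, h_tm1 = split_state(state, scheme=[(1, self.config['na']), (2, self.config['nh'])])
    # attention over feat, conditioned on the previous context, hidden state, and word
    e_t, alpha = self.attention.compute(feat, T.concatenate([e_tm1, h_tm1, word_vec], axis=1))
    # LSTM step on [attended context, word, scene]
    e_w_s = T.concatenate([e_t, word_vec, scene], axis=-1)
    c_t, h_t = self.lstm.compute(e_w_s, c_tm1, h_tm1)
    # pack the new state in the same order it was unpacked
    new_state = T.concatenate([e_t, c_t, h_t], axis=-1)
    # add w_{t-1} (and the scene vector) as features for the predictor
    e_h_w_s = T.concatenate([e_t, h_t, word_vec, scene], axis=-1)
    # predict probability
    p = self.pred_mlp.compute(e_h_w_s)
    return new_state, p, alpha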
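
The method carries its recurrent state as one packed row per example, laid out as [e, c, h]: it is split at the start of the step and concatenated again at the end. The project's `split_state` helper is not shown here, so the sketch below is only a guess at its behavior. It assumes each `(count, width)` pair in `scheme` means "take `count` consecutive chunks of `width` columns". `split_state_sketch` and `merge_state_sketch` are hypothetical names, and plain Python lists stand in for Theano tensors.

```python
def split_state_sketch(state, scheme):
    """Slice each row of a packed state into chunks.

    `state` is a list of rows (one per batch element).
    `scheme` is a list of (count, width) pairs; reading it as
    "count chunks of width columns each" is an assumption.
    """
    total = sum(count * width for count, width in scheme)
    for row in state:
        if len(row) != total:
            raise ValueError("row width does not match scheme")

    parts = []
    start = 0
    for count, width in scheme:
        for _ in range(count):
            # take columns [start, start + width) from every row
            parts.append([row[start:start + width] for row in state])
            start += width
    return parts


def merge_state_sketch(parts):
    """Concatenate chunks row by row (the list analogue of concatenating on the last axis)."""
    merged = []
    for rows in zip(*parts):
        packed = []
        for chunk in rows:
            packed.extend(chunk)
        merged.append(packed)
    return merged


# Hypothetical sizes: na = attention context width, nh = LSTM width.
na, nh = 3, 2
state = [
    [0, 1, 2, 3, 4, 5, 6],
    [7, 8, 9, 10, 11, 12, 13],
]
e_tm1, c_tm1, h_tm1 = split_state_sketch(state, scheme=[(1, na), (2, nh)])
round_trip = merge_state_sketch([e_tm1, c_tm1, h_tm1])
```

Under that assumption, splitting and re-merging without changing anything gives back the original rows. This is the property that lets `new_state` be fed straight into the next call of `compute`.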