rnn.py source code

python

Project: kaggle-review, author: daxiongshu
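import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.variable_scope)


def _build(self):  # method of the model class
    V = self.V                      # vocabulary size
    M = self.flags.embedding_size   # embedding dimension (e.g. 64)
    H = self.flags.num_units        # hidden units per RNN cell
    C = self.flags.classes          # number of output classes

    # Embedding layer: token ids [B, S] -> vectors [B, S, M]
    netname = "CBOW"
    with tf.variable_scope(netname):
        self.inputs = tf.placeholder(dtype=tf.int32, shape=[None, None])  # [B, S]
        layer_name = "{}/embedding".format(netname)
        x = self._get_embedding(layer_name, self.inputs, V, M, reuse=False)  # [B, S, M]

    # Bidirectional RNN. "num_proj": C assumes the cell supports an output
    # projection (e.g. LSTMCell), so each step's output has C dims.
    netname = "RNN"
    cell_name = self.flags.cell
    with tf.variable_scope(netname):
        args = {"num_units": H, "num_proj": C}
        cell_f = self._get_rnn_cell(cell_name=cell_name, args=args)
        cell_b = self._get_rnn_cell(cell_name=cell_name, args=args)
        (out_f, out_b), _ = tf.nn.bidirectional_dynamic_rnn(
            cell_f, cell_b, x, dtype=tf.float32)  # out_f, out_b: [B, S, C]
        # Alternative head: average the last time step of both directions.
        # out_b is returned in time order, so out_b[:, -1, :] has seen only the last token.
        # logit = (out_f[:, -1, :] + out_b[:, -1, :]) * 0.5  # [B, C]
        # Sum both directions, then average over time steps (axis=1).
        # The result is already [B, C]; a following tf.squeeze would also drop
        # the batch dimension when B == 1, so it is omitted.
        self.logit = tf.reduce_mean(out_f + out_b, axis=1)  # [B, C]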
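To make the data flow of `_build` concrete, here is a minimal pure-Python sketch of the same pipeline for a single sequence: embedding lookup, a bidirectional RNN whose backward outputs are flipped back into time order, and the two pooling heads (mean over time, and the commented-out last-step average). It is an illustration under stated assumptions, not the author's implementation: a plain tanh RNN stands in for whatever `_get_rnn_cell` actually builds (that helper is not shown), a projection matrix `P` plays the role of `num_proj`, and all helper names (`matvec`, `embed`, `rnn`, `bidirectional`, `mean_pool`, `last_step_average`) and weight values are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]
Weights = Tuple[Matrix, Matrix, Matrix]  # hypothetical (W, U, P) for one direction


def matvec(A: Matrix, v: Sequence[float]) -> Vector:
    """Matrix-vector product A @ v."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def embed(ids: Sequence[int], table: Matrix) -> List[Vector]:
    """Embedding lookup: each token id selects one row, [S] -> [S, M]."""
    return [list(table[i]) for i in ids]


def rnn(xs: List[Vector], W: Matrix, U: Matrix, P: Matrix) -> List[Vector]:
    """Hypothetical tanh RNN: h_t = tanh(W x_t + U h_{t-1}), y_t = P h_t.

    P projects each hidden state to C outputs (the role of num_proj).
    """
    h = [0.0] * len(U)  # initial hidden state of size H
    ys = []
    for x in xs:
        h = [math.tanh(a + b) for a, b in zip(matvec(W, x), matvec(U, h))]
        ys.append(matvec(P, h))
    return ys  # [S, C]


def bidirectional(xs: List[Vector], fwd: Weights, bwd: Weights):
    """Run both directions; backward outputs are reversed back to time order."""
    out_f = rnn(xs, *fwd)
    out_b = rnn(xs[::-1], *bwd)[::-1]
    return out_f, out_b


def mean_pool(out_f: List[Vector], out_b: List[Vector]) -> Vector:
    """Like reduce_mean(out_f + out_b, axis=1) for one sequence: [S, C] -> [C]."""
    steps = len(out_f)
    return [sum(f[c] + b[c] for f, b in zip(out_f, out_b)) / steps
            for c in range(len(out_f[0]))]


def last_step_average(out_f: List[Vector], out_b: List[Vector]) -> Vector:
    """Like (out_f[:, -1, :] + out_b[:, -1, :]) * 0.5 for one sequence."""
    return [(f + b) * 0.5 for f, b in zip(out_f[-1], out_b[-1])]


# Toy sizes: V=4 tokens, M=2 embedding dims, H=2 hidden units, C=2 classes.
table = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.0, 0.5]]
fwd = ([[0.5, -0.3], [0.2, 0.4]], [[0.1, 0.0], [0.0, 0.1]], [[1.0, 0.0], [0.0, 1.0]])
bwd = ([[-0.4, 0.1], [0.3, 0.3]], [[0.2, 0.0], [0.0, 0.2]], [[0.0, 1.0], [1.0, 0.0]])

xs = embed([1, 3, 2], table)                # [S=3, M=2]
out_f, out_b = bidirectional(xs, fwd, bwd)  # each [S=3, C=2]
logit = mean_pool(out_f, out_b)             # [C=2]
print(logit, last_step_average(out_f, out_b))
```

Because the backward outputs are flipped back into time order, `out_b[-1]` depends only on the last token (it equals the backward cell's output for that token alone). This is one reason the last-step variant may summarize the backward direction less well than the mean over time.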