DMN.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DynamicMemoryNetworks 作者: swstarlab 项目源码 文件源码
def __init__(self, incomings, hid_state_size, voc_size,
                 resetgate=GRU_Gate(), updategate=GRU_Gate(),
                 hid_update=GRU_Gate(nonlinearity=nonlin.tanh),
                 W=Normal(), max_answer_word=1, **kwargs):
        """Register the GRU gate parameters and the output weight matrix.

        Parameters
        ----------
        incomings
            Forwarded unchanged to the parent constructor.
        hid_state_size
            Hidden state size; determines the shape of every gate's
            hidden-to-hidden weights and bias, and the first dimension
            of ``W``.
        voc_size
            Added to the width of the first input (the previous
            prediction is concatenated to the input) and used as the
            second dimension of ``W``.
        resetgate, updategate, hid_update
            Gate specifications; each must expose ``W_in``, ``W_hid``,
            ``b`` and ``nonlinearity``.
        W
            Specification for the ``(hid_state_size, voc_size)``
            parameter registered under the name ``"W"``.
        max_answer_word
            Stored on the instance as given.
        **kwargs
            Forwarded unchanged to the parent constructor.
        """
        super(AnswerModule, self).__init__(incomings, **kwargs)

        self.hid_state_size = hid_state_size
        self.max_answer_word = max_answer_word

        # GRU input width: second dimension of the first incoming shape,
        # widened by voc_size for the concatenated previous prediction.
        first_input_shape = self.input_shapes[0]
        gru_input_width = np.prod(first_input_shape[1]) + voc_size

        in_weight_shape = (gru_input_width, hid_state_size)
        hid_weight_shape = (hid_state_size, hid_state_size)
        bias_shape = (hid_state_size,)

        def register_gate(spec, label):
            """Register one gate's three parameters; return them with its nonlinearity."""
            w_in = self.add_param(spec.W_in, in_weight_shape,
                                  name="W_in_to_{}".format(label))
            w_hid = self.add_param(spec.W_hid, hid_weight_shape,
                                   name="W_hid_to_{}".format(label))
            bias = self.add_param(spec.b, bias_shape,
                                  name="b_{}".format(label),
                                  regularizable=False)
            return w_in, w_hid, bias, spec.nonlinearity

        # Registration order (update, reset, hidden update, then W) is kept
        # fixed because it determines the order in which parameters are added.
        w_in, w_hid, bias, activation = register_gate(updategate, 'updategate')
        self.W_in_to_updategate = w_in
        self.W_hid_to_updategate = w_hid
        self.b_updategate = bias
        self.nonlinearity_updategate = activation

        w_in, w_hid, bias, activation = register_gate(resetgate, 'resetgate')
        self.W_in_to_resetgate = w_in
        self.W_hid_to_resetgate = w_hid
        self.b_resetgate = bias
        self.nonlinearity_resetgate = activation

        w_in, w_hid, bias, activation = register_gate(hid_update, 'hid_update')
        self.W_in_to_hid_update = w_in
        self.W_hid_to_hid_update = w_hid
        self.b_hid_update = bias
        self.nonlinearity_hid = activation

        self.W = self.add_param(W, (hid_state_size, voc_size), name="W")

        # Random stream seeded from NumPy's global random state.
        stream_seed = np.random.randint(1, 2147462579)
        self.rand_stream = RandomStreams(stream_seed)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号