mps_base.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def __init__(self, fea_size, dropout=False, gate_width=128, use_region=True, use_kernel_function=False):
        """Create the message-passing gates and recurrent update units.

        Args:
            fea_size: feature size handed to every gate and every
                ``Gated_Recurrent_Unit``.
            dropout: forwarded unchanged to each ``Gated_Recurrent_Unit``.
            gate_width: second constructor argument of every message-passing
                unit.
            use_region: if truthy, additionally build ``gate_pred2reg``,
                ``gate_reg2pred`` and ``GRU_region``; if falsy those
                attributes are never assigned.
            use_kernel_function: if truthy, gates are built from
                ``Message_Passing_Unit_v2``, otherwise from
                ``Message_Passing_Unit_v1``.
        """
        super(Hierarchical_Message_Passing_Structure_base, self).__init__()

        unit_cls = Message_Passing_Unit_v2 if use_kernel_function else Message_Passing_Unit_v1

        def new_gate():
            # All gates are built with identical constructor arguments.
            return unit_cls(fea_size, gate_width)

        # NOTE(review): the assignment order below is deliberately unchanged;
        # if this is an nn.Module, it fixes submodule registration order and
        # the order in which parameters are initialised.

        # Gates feeding the predicate from subject / object.
        self.gate_sub2pred = new_gate()
        self.gate_obj2pred = new_gate()
        # Gates feeding subject / object from the predicate.
        self.gate_pred2sub = new_gate()
        self.gate_pred2obj = new_gate()

        # Recurrent updaters for object and phrase features.
        self.GRU_object = Gated_Recurrent_Unit(fea_size, dropout)
        self.GRU_phrase = Gated_Recurrent_Unit(fea_size, dropout)

        if not use_region:
            return

        # Optional predicate <-> region branch.
        self.gate_pred2reg = new_gate()
        self.gate_reg2pred = new_gate()
        self.GRU_region = Gated_Recurrent_Unit(fea_size, dropout)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号