def __init__(self, fea_size, dropout=False, gate_width=128, use_region=True, use_kernel_function=False):
    """Create the gating and recurrent sub-modules of this structure.

    Args:
        fea_size: Forwarded as the first argument to every message-passing
            unit and every ``Gated_Recurrent_Unit`` built here.
        dropout: Forwarded as the second argument to each
            ``Gated_Recurrent_Unit``.
        gate_width: Forwarded as the second argument to each
            message-passing unit.
        use_region: When true, the region-related modules
            (``gate_pred2reg``, ``gate_reg2pred``, ``GRU_region``) are
            created as well; otherwise those attributes are never set.
        use_kernel_function: Selects ``Message_Passing_Unit_v2`` when true
            and ``Message_Passing_Unit_v1`` otherwise.
    """
    super(Hierarchical_Message_Passing_Structure_base, self).__init__()

    # One unit implementation is used for every gate created below.
    unit_cls = Message_Passing_Unit_v2 if use_kernel_function else Message_Passing_Unit_v1

    # Attributes are created in a fixed order: the four subject/object <->
    # predicate gates first, then the object and phrase recurrent units.
    for gate_name in ("gate_sub2pred", "gate_obj2pred", "gate_pred2sub", "gate_pred2obj"):
        setattr(self, gate_name, unit_cls(fea_size, gate_width))

    # The original author noted nn.GRUCell(fea_size, fea_size) as an
    # alternative to Gated_Recurrent_Unit here.
    for gru_name in ("GRU_object", "GRU_phrase"):
        setattr(self, gru_name, Gated_Recurrent_Unit(fea_size, dropout))

    # NOTE(review): the source lost its indentation; GRU_region is assumed to
    # belong to the use_region branch together with the two region gates --
    # confirm against the upstream file.
    if use_region:
        for gate_name in ("gate_pred2reg", "gate_reg2pred"):
            setattr(self, gate_name, unit_cls(fea_size, gate_width))
        self.GRU_region = Gated_Recurrent_Unit(fea_size, dropout)
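# Comment list (评论列表) -- web-page navigation text, not part of the code.
# Table of contents (文章目录) -- web-page navigation text, not part of the code.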