def __call__(self, v, h, label):
    """Apply one label-conditioned gated block to the vertical and horizontal stacks.

    Each stack computes two convolution branches (``*_t`` and ``*_s``) that
    are combined with the gate ``tanh(t) * sigmoid(s)``. The vertical
    branches are also projected into the horizontal stack through
    ``self.v_to_h_conv_t`` / ``self.v_to_h_conv_s``. The label projection
    ``self.label(label)`` is added as a bias to every branch of both stacks.

    Args:
        v: Input to the vertical stack.
        h: Input to the horizontal stack.
        label: Conditioning input passed to ``self.label``; presumably a
            batch of class ids or one-hot vectors — TODO confirm.

    Returns:
        Tuple ``(v, h)``: the gated vertical output, and the gated
        horizontal activation passed through ``self.horizontal_output``.
    """
    # Vertical stack: the two branches of the gate.
    v_t = self.vertical_conv_t(v)
    v_s = self.vertical_conv_s(v)
    # Vertical -> horizontal links. They are taken from the vertical branches
    # *before* the label bias is added, so the bias is not counted twice in
    # the horizontal stack.
    v_to_h_t = self.v_to_h_conv_t(v_t)
    v_to_h_s = self.v_to_h_conv_s(v_s)

    # Append two trailing singleton axes so the label projection can be
    # broadcast over the spatial dims. Assumes self.label returns a
    # (batch, channels) array — TODO confirm.
    label_bias = F.expand_dims(F.expand_dims(self.label(label), -1), -1)

    # Label bias is added to both vertical branches.
    v_label = F.broadcast_to(label_bias, v_t.shape)
    v_t, v_s = v_t + v_label, v_s + v_label
    v = F.tanh(v_t) * F.sigmoid(v_s)

    # Horizontal stack: its own branches plus the vertical links and the bias.
    h_t = self.horizontal_conv_t(h)
    h_s = self.horizontal_conv_s(h)
    # Broadcast to the horizontal activation's own shape rather than reusing
    # the vertical one, so the two stacks need not share a shape.
    h_label = F.broadcast_to(label_bias, h_t.shape)
    h_t, h_s = h_t + v_to_h_t + h_label, h_s + v_to_h_s + h_label
    # NOTE(review): the Gated PixelCNN paper adds a residual connection
    # (h + output) on the horizontal stack; none is applied here. Confirm
    # whether that is intended or handled by the caller.
    h = self.horizontal_output(F.tanh(h_t) * F.sigmoid(h_s))
    return v, h
评论列表
文章目录