def construct(self, *, weight_path='pretrained/vgg-face.mat',
              network_name='vgg_face_net', apply_dropout=True):
    """Build the VGG-Face backbone, the domain/class heads and the total loss.

    ``self.x`` is normalized by the backbone and pushed through it. The FC6
    activations are used as the shared embedding that is fine-tuned. Every
    embedded sample is passed to ``self.dom_network``. Only the samples
    selected by ``self.with_class_idx`` are passed to ``self.class_network``.

    Args:
        weight_path: Pretrained VGG-Face weights, assigned to
            ``VGG_Face.OPTS.weight_path``. The default keeps the previously
            hard-coded ``'pretrained/vgg-face.mat'``. That path is relative,
            so it resolves against the current working directory.
        network_name: Assigned to ``VGG_Face.OPTS.network_name``.
        apply_dropout: Assigned to ``VGG_Face.OPTS.apply_dropout``.

    Sets ``self.vgg_net``, ``self.keep_prob``, ``self.embedded``,
    ``self.embedded_with_class`` and ``self.loss``.
    """
    net_opts = VGG_Face.OPTS()
    net_opts.network_name = network_name
    net_opts.weight_path = weight_path
    net_opts.apply_dropout = apply_dropout
    self.vgg_net = VGG_Face(net_opts)

    x_normalized = self.vgg_net.normalize_input(self.x)
    self.vgg_net.network(x_normalized)
    # Re-export the backbone's keep_prob so callers can control dropout.
    # NOTE(review): presumably a placeholder fed at run time -- confirm
    # in VGG_Face.
    self.keep_prob = self.vgg_net.keep_prob

    # Fine-tune from the FC6 layer of VGG-Face.
    self.embedded = self.vgg_net.fc6
    # Keep only the samples indexed by with_class_idx. tf.gather selects
    # along the first axis by default.
    self.embedded_with_class = tf.gather(
        self.embedded, self.with_class_idx, name='embedded_with_class')

    self.dom_network(self.embedded)
    self.class_network(self.embedded_with_class)
    # Assumes dom_network / class_network set self.dom_loss /
    # self.class_loss; neither body is visible here -- confirm.
    self.loss = self.dom_loss + self.class_loss
# Construct Domain Discriminator Network
评论列表
文章目录