def forward(self, anchor, positive, negative):
    """Compute a triplet hinge loss over a batch.

    The per-sample loss is ``clamp(d(anchor, positive) - d(anchor, negative)
    + self.margin, min=0)``, where ``d`` is selected by ``self.dist_type``:

    * ``0`` -- ``F.pairwise_distance``
    * ``1`` -- ``cosine_similarity`` (a helper resolved from the enclosing
      module)

    When ``self.use_ohem`` is truthy, only the ``self.ohem_bs`` largest
    per-sample losses are averaged; otherwise all of them are.

    Args:
        anchor, positive, negative: embedding batches; presumably
            ``(batch, dim)`` tensors -- TODO confirm against the caller.

    Returns:
        The mean of the selected hinge terms.

    Raises:
        ValueError: if ``self.dist_type`` is neither 0 nor 1.
    """
    if self.dist_type == 0:
        dist_p = F.pairwise_distance(anchor, positive)
        dist_n = F.pairwise_distance(anchor, negative)
    elif self.dist_type == 1:
        # NOTE(review): a similarity grows as samples get closer, which is the
        # opposite of a distance; confirm that the module's cosine_similarity
        # helper really returns a distance-like quantity.
        dist_p = cosine_similarity(anchor, positive)
        # This used to be assigned to a misspelled name ("disp_n"), which left
        # dist_n undefined for this branch.
        dist_n = cosine_similarity(anchor, negative)
    else:
        raise ValueError('unsupported dist_type: {!r}'.format(self.dist_type))

    dist_hinge = torch.clamp(dist_p - dist_n + self.margin, min=0.0)
    if self.use_ohem:
        # Flatten first so the sort ranks samples even when the distance
        # function keeps a trailing singleton dimension.
        hardest, _ = torch.sort(dist_hinge.contiguous().view(-1), descending=True)
        return torch.mean(hardest[0:self.ohem_bs])
    return torch.mean(dist_hinge)
python类pairwise_distance()的实例源码
def _algo_1_horiz_comp(self, sent1_block_a, sent2_block_a):
    """Compare the two sentences' block-A features pooling-by-pooling.

    For every pooling type in ``('max', 'min', 'mean')`` and every width in
    ``self.filter_widths``, the features ``sent*_block_a[width][pool]`` of the
    two sentences are compared with cosine similarity and pairwise distance.

    Args:
        sent1_block_a, sent2_block_a: nested mappings indexed as
            ``[filter_width][pool]``; each entry is presumably a
            ``(batch, features)`` tensor -- TODO confirm.

    Returns:
        All comparison columns concatenated along dim 1, i.e. two columns per
        (pool, width) pair.
    """
    comparison_feats = []
    for pool in ('max', 'min', 'mean'):
        for ws in self.filter_widths:
            x1 = sent1_block_a[ws][pool]
            x2 = sent2_block_a[ws][pool]
            batch_size = x1.size()[0]
            comparison_feats.append(
                F.cosine_similarity(x1, x2).contiguous().view(batch_size, 1))
            # pairwise_distance yields a 1-D tensor unless keepdim is set
            # (older releases returned (batch, 1)); reshape explicitly so the
            # column can be concatenated along dim 1 either way.
            comparison_feats.append(
                F.pairwise_distance(x1, x2).contiguous().view(batch_size, 1))
    return torch.cat(comparison_feats, dim=1)
def _algo_2_vert_comp(self, sent1_block_a, sent2_block_a, sent1_block_b, sent2_block_b):
    """Compare the two sentences across filter widths (block A) and per filter (block B).

    Block A: for each pooling type in ``('max', 'min', 'mean')`` and each pair
    of widths ``(ws1, ws2)`` from ``self.filter_widths`` that are either both
    finite or both infinite, appends cosine similarity, pairwise distance and
    the element-wise absolute difference of the two feature tensors.

    Block B: for each pooling type in ``('max', 'min')``, each finite width
    and each of the ``self.n_per_dim_filters`` slices along the last axis,
    appends the same three comparisons.

    Args:
        sent1_block_a, sent2_block_a: nested mappings ``[width][pool]``;
            entries are presumably ``(batch, features)`` tensors -- TODO confirm.
        sent1_block_b, sent2_block_b: nested mappings ``[width][pool]`` whose
            entries are indexed as ``[:, :, i]`` (so at least 3-D).

    Returns:
        All comparison features concatenated along dim 1.
    """
    comparison_feats = []
    ws_no_inf = [w for w in self.filter_widths if not np.isinf(w)]

    for pool in ('max', 'min', 'mean'):
        for ws1 in self.filter_widths:
            x1 = sent1_block_a[ws1][pool]
            batch_size = x1.size()[0]
            for ws2 in self.filter_widths:
                x2 = sent2_block_a[ws2][pool]
                # Only compare widths of the same kind: finite with finite,
                # infinite with infinite.
                if bool(np.isinf(ws1)) == bool(np.isinf(ws2)):
                    comparison_feats.append(
                        F.cosine_similarity(x1, x2).contiguous().view(batch_size, 1))
                    # Reshape explicitly: pairwise_distance is 1-D unless
                    # keepdim is set, and every feature must be 2-D for the
                    # final concatenation.
                    comparison_feats.append(
                        F.pairwise_distance(x1, x2).contiguous().view(batch_size, 1))
                    comparison_feats.append(torch.abs(x1 - x2))

    for pool in ('max', 'min'):
        for ws in ws_no_inf:
            oG_1B = sent1_block_b[ws][pool]
            oG_2B = sent2_block_b[ws][pool]
            for i in range(0, self.n_per_dim_filters):
                x1 = oG_1B[:, :, i]
                x2 = oG_2B[:, :, i]
                batch_size = x1.size()[0]
                comparison_feats.append(
                    F.cosine_similarity(x1, x2).contiguous().view(batch_size, 1))
                comparison_feats.append(
                    F.pairwise_distance(x1, x2).contiguous().view(batch_size, 1))
                comparison_feats.append(torch.abs(x1 - x2))

    return torch.cat(comparison_feats, dim=1)
tripletnet.py 文件源码
项目:conditional-similarity-networks
作者: andreasveit
项目源码
文件源码
阅读 25
收藏 0
点赞 0
评论 0
def forward(self, x, y, z, c):
    """Embed a triplet under condition ``c`` and return both anchor distances.

    Args:
        x: Anchor image.
        y: Distant (negative) image.
        z: Close (positive) image.
        c: Integer indicating according to which notion of similarity the
            images are compared.

    Returns:
        ``(dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm)`` where
        ``dist_a`` is the anchor/``y`` distance, ``dist_b`` the anchor/``z``
        distance, and the three norms are the averages over the triplet of the
        2nd, 3rd and 4th values returned by ``self.embeddingnet``.
    """
    # One embedding pass per image, in anchor / distant / close order.
    results = [self.embeddingnet(image, c) for image in (x, y, z)]
    embeddings, mask_norms, embed_norms, total_norms = zip(*results)
    anchor, distant, close = embeddings

    mask_norm = (mask_norms[0] + mask_norms[1] + mask_norms[2]) / 3
    embed_norm = (embed_norms[0] + embed_norms[1] + embed_norms[2]) / 3
    mask_embed_norm = (total_norms[0] + total_norms[1] + total_norms[2]) / 3

    dist_a = F.pairwise_distance(anchor, distant, 2)
    dist_b = F.pairwise_distance(anchor, close, 2)
    return dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm
def test_pairwise_distance(self):
    """Check the gradient of ``F.pairwise_distance`` with ``gradcheck``."""
    # gradcheck's default tolerances are designed for double precision, so the
    # random inputs are cast explicitly rather than relying on whatever the
    # process-wide default tensor type happens to be.
    input1 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    input2 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
def test_pairwise_distance(self):
    """Check the gradient of ``F.pairwise_distance`` with ``gradcheck``."""
    # gradcheck's default tolerances are designed for double precision, so the
    # random inputs are cast explicitly rather than relying on whatever the
    # process-wide default tensor type happens to be.
    input1 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    input2 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
def test_pairwise_distance(self):
    """Check the gradient of ``F.pairwise_distance`` with ``gradcheck``."""
    # gradcheck's default tolerances are designed for double precision, so the
    # random inputs are cast explicitly rather than relying on whatever the
    # process-wide default tensor type happens to be.
    input1 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    input2 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
def test_pairwise_distance(self):
    """Check the gradient of ``F.pairwise_distance`` with ``gradcheck``."""
    # gradcheck's default tolerances are designed for double precision, so the
    # random inputs are cast explicitly rather than relying on whatever the
    # process-wide default tensor type happens to be.
    input1 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    input2 = Variable(torch.randn(4, 4).double(), requires_grad=True)
    self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
def forward(self, x, y, z):
    """Embed three inputs and return the distances from the first to the others.

    Returns:
        ``(dist_a, dist_b, embedded_x, embedded_y, embedded_z)`` where
        ``dist_a`` is the pairwise distance (p=2) between the embeddings of
        ``x`` and ``y`` and ``dist_b`` the one between ``x`` and ``z``.
    """
    # Generator unpacking keeps the x, y, z evaluation order of the network.
    anchor_emb, second_emb, third_emb = (
        self.embeddingnet(sample) for sample in (x, y, z)
    )
    distances = tuple(
        F.pairwise_distance(anchor_emb, other, 2)
        for other in (second_emb, third_emb)
    )
    return distances + (anchor_emb, second_emb, third_emb)
def test(model, test_loader, epoch, margin, threshlod, is_cuda=True, log_interval=1000):
    """Evaluate threshold accuracy of a triplet model over ``test_loader``.

    A positive pair counts as correct when its distance is ``<= threshlod``
    and a negative pair when its distance is ``> threshlod`` (the same
    convention the sibling ``test_vis`` / ``best_test`` routines use).

    Args:
        model: callable mapping a batch to embeddings; put into eval mode.
        test_loader: sized iterable of ``(anchor, positive, negative, target)``
            batches; ``target`` is not used.
        epoch: epoch number, only used in the progress line.
        margin: margin passed to ``F.triplet_margin_loss``.
        threshlod: distance threshold separating positive from negative pairs.
        is_cuda: move the image batches to the GPU before the forward pass.
        log_interval: print a progress line (and reset the running loss)
            every this many batches.

    Returns:
        Fraction of correctly classified pairs, or ``0.0`` when the loader
        produced no samples.
    """
    model.eval()
    test_loss = AverageMeter()
    num_p = 0.0
    total_num = 0
    batch_num = len(test_loader)

    for batch_idx, (data_a, data_p, data_n, _target) in enumerate(test_loader):
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()

        # NOTE(review): ``volatile`` only disables autograd on PyTorch < 0.4;
        # newer releases ignore it, so callers there should wrap this call in
        # ``torch.no_grad()``.
        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        loss = F.triplet_margin_loss(out_a, out_p, out_n, margin)
        dist1 = F.pairwise_distance(out_a, out_p)
        dist2 = F.pairwise_distance(out_a, out_n)

        correct = (dist1 <= threshlod).float().sum() + (dist2 > threshlod).float().sum()
        num_p += _to_scalar(correct)
        total_num += data_a.size()[0] * 2
        test_loss.update(_to_scalar(loss))

        if (batch_idx + 1) % log_interval == 0:
            accuracy_tmp = num_p / total_num
            print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'
                  .format(epoch, batch_idx + 1, batch_num, accuracy_tmp, test_loss.avg))
            test_loss.reset()

    if total_num == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return num_p / total_num


def _to_scalar(value):
    """Return a one-element tensor/Variable as a plain Python number."""
    # ``.item()`` exists from PyTorch 0.4 on, where indexing a 0-dim tensor is
    # an error; older releases only support the ``.data[0]`` idiom.
    if hasattr(value, 'item'):
        return value.item()
    return value.data[0]
def test_vis(model, test_loader, model_path, threshlod,
             margin=1.0, is_cuda=True, output_dir='output', is_visualization=True):
    """Evaluate threshold accuracy and optionally save the misclassified pairs.

    A positive pair is correct when its distance is ``<= threshlod``; a
    negative pair when its distance is ``> threshlod``. With
    ``is_visualization`` every misclassified pair is handed to
    ``combine_and_save`` (tag ``'1-1'`` for positives, ``'1-0'`` for
    negatives).

    Args:
        model: triplet embedding model; weights are loaded through
            ``model.load_full_weights`` when ``model_path`` is given.
        test_loader: iterable of ``(anchor, positive, negative, target,
            img_paths)`` batches, where ``img_paths[j][k]`` is the path of
            sample ``k`` for the anchor (0), positive (1) or negative (2).
        model_path: checkpoint to load, or ``None`` to keep current weights.
        threshlod: distance threshold separating positives from negatives.
        margin: retained for interface compatibility; no loss is computed.
        is_cuda: run the model and the image batches on the GPU.
        output_dir: directory forwarded to ``combine_and_save``.
        is_visualization: save misclassified pairs when true.

    Returns:
        Fraction of correctly classified pairs, or ``0.0`` when the loader
        produced no samples.
    """
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {:s}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()

    num_p = 0.0
    total_num = 0
    for data_a, data_p, data_n, _target, img_paths in test_loader:
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()

        # NOTE(review): ``volatile`` is ignored from PyTorch 0.4 on; wrap the
        # call in ``torch.no_grad()`` there to skip graph construction.
        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        dist1 = F.pairwise_distance(out_a, out_p)
        dist2 = F.pairwise_distance(out_a, out_n)

        batch_size = data_a.size()[0]
        pos_flag = (dist1 <= threshlod).float()
        neg_flag = (dist2 > threshlod).float()

        if is_visualization:
            for k in range(batch_size):
                # A flag of 0 marks a pair on the wrong side of the threshold.
                if _to_scalar(pos_flag[k]) == 0:
                    combine_and_save(img_paths[0][k], img_paths[1][k], dist1[k], output_dir, '1-1')
                if _to_scalar(neg_flag[k]) == 0:
                    combine_and_save(img_paths[0][k], img_paths[2][k], dist2[k], output_dir, '1-0')

        pos_count = _to_scalar(pos_flag.sum())
        neg_count = _to_scalar(neg_flag.sum())
        num = pos_count + neg_count
        print('{:f}, {:f}, {:f}'.format(num, pos_count, neg_count))
        num_p += num
        total_num += batch_size * 2
        print('num_p = {:f}, total = {:f}'.format(num_p, total_num))
        print('dist1 = {:f}, dist2 = {:f}'.format(_to_scalar(dist1[0]), _to_scalar(dist2[0])))

    if total_num == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return num_p / total_num


def _to_scalar(value):
    """Return a one-element tensor/Variable as a plain Python number."""
    # ``.item()`` exists from PyTorch 0.4 on, where indexing a 0-dim tensor is
    # an error; older releases only support the ``.data[0]`` idiom.
    if hasattr(value, 'item'):
        return value.item()
    return value.data[0]
def best_test(model, _loader, model_path=None, is_cuda=True):
    """Search for the distance threshold that best separates the triplets.

    All anchor/positive and anchor/negative distances are collected, then
    thresholds between the two mean distances are scanned in steps of 0.05.
    A threshold scores one point for every positive distance ``<=`` it and
    every negative distance ``>`` it; the first threshold with the highest
    score wins.

    Args:
        model: triplet embedding model; weights are loaded through
            ``model.load_full_weights`` when ``model_path`` is given.
        _loader: iterable of ``(anchor, positive, negative, target)`` batches;
            ``target`` is not used.
        model_path: checkpoint to load, or ``None`` to keep current weights.
        is_cuda: run the model and the image batches on the GPU.

    Returns:
        ``(best_thre, accuracy, mean_d_a_p, mean_d_a_n)``: the chosen
        threshold, the fraction of pairs it classifies correctly, and the mean
        positive / negative distances.

    Raises:
        ValueError: if ``_loader`` yields no batches.
    """
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {:s}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()

    total_num = 0
    pos_chunks = []
    neg_chunks = []
    for data_a, data_p, data_n, _target in _loader:
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()

        # NOTE(review): ``volatile`` is ignored from PyTorch 0.4 on; wrap the
        # call in ``torch.no_grad()`` there to skip graph construction.
        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        pos_chunks.append(F.pairwise_distance(out_a, out_p))
        neg_chunks.append(F.pairwise_distance(out_a, out_n))
        total_num += 2 * data_a.size()[0]

    if not pos_chunks:
        raise ValueError('best_test needs at least one batch to pick a threshold')

    # Concatenate once instead of re-copying the running tensor every batch.
    d_a_p = torch.cat(pos_chunks, 0)
    d_a_n = torch.cat(neg_chunks, 0)

    mean_d_a_p = _to_scalar(d_a_p.mean())
    mean_d_a_n = _to_scalar(d_a_n.mean())
    start = min(mean_d_a_n, mean_d_a_p)
    end = max(mean_d_a_n, mean_d_a_p)

    best_thre = 0
    best_num = 0
    thre_step = 0.05
    # tolist() yields plain floats, so the returned threshold is a Python
    # number regardless of how the installed PyTorch iterates tensors.
    for val in torch.arange(start, end + thre_step, thre_step).tolist():
        hits = (d_a_p <= val).float().sum() + (d_a_n > val).float().sum()
        num = _to_scalar(hits)
        if num > best_num:
            best_num = num
            best_thre = val

    return best_thre, best_num / total_num, mean_d_a_p, mean_d_a_n


def _to_scalar(value):
    """Return a one-element tensor/Variable as a plain Python number."""
    # ``.item()`` exists from PyTorch 0.4 on, where indexing a 0-dim tensor is
    # an error; older releases only support the ``.data[0]`` idiom.
    if hasattr(value, 'item'):
        return value.item()
    return value.data[0]