def forward(self, x):
    # Run the backbone modules up to (but not including) the final average pooling.
    for name, module in self.base._modules.items():
        if name == 'avgpool':
            break
        x = module(x)
    if self.cut_at_pooling:
        return x
    # Global average pooling, then the optional embedding / BN / normalization head.
    x = F.avg_pool2d(x, x.size()[2:])
    x = x.view(x.size(0), -1)
    if self.has_embedding:
        x = self.feat(x)
        x = self.feat_bn(x)
    if self.norm:
        x = F.normalize(x)
    elif self.has_embedding:
        x = F.relu(x)
    if self.dropout > 0:
        x = self.drop(x)
    if self.num_classes > 0:
        x = self.classifier(x)
    return x
Python normalize() example source code
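Before the project examples, a minimal sketch of what F.normalize itself does: it rescales the input so that each slice along the chosen dimension has unit p-norm. The tensor values below are illustrative only.

import torch
import torch.nn.functional as F

# A small batch of two 3-dimensional feature vectors (values chosen for illustration).
x = torch.tensor([[3.0, 4.0, 0.0],
                  [1.0, 1.0, 1.0]])

# L2-normalize each row (dim=1): every row now has unit Euclidean norm.
unit = F.normalize(x, p=2, dim=1)
print(unit.norm(p=2, dim=1))          # tensor([1., 1.])

# Equivalent to dividing by the row norms (clamped to avoid division by zero).
manual = x / x.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
print(torch.allclose(unit, manual))   # True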
def predict(self, x):
    batch_size, dims = x.size()
    query = F.normalize(self.query_proj(x), dim=1)

    # Find the k-nearest neighbors of the query
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # softmax of cosine similarities - embedding
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

    # retrieve memory values - prediction
    y_hat_indices = topk_indices_var.data[:, 0]
    y_hat = self.values[y_hat_indices]

    return y_hat, softmax_score
def normalize(w):
    """Normalizes weight tensor over full filter."""
    return F.normalize(w.view(w.size(0), -1)).view_as(w)
def forward(self, input):
    return F.conv1d(input, self.alpha * Variable(self.delta) + self.beta * normalize(self.weight),
                    self.bias, self.stride, self.padding, self.dilation)

def forward(self, input):
    return F.conv2d(input, self.alpha * Variable(self.delta) + self.beta * normalize(self.weight),
                    self.bias, self.stride, self.padding, self.dilation)

def forward(self, input):
    return F.conv3d(input, self.alpha * Variable(self.delta) + self.beta * normalize(self.weight),
                    self.bias, self.stride, self.padding, self.dilation)
def block(o, params, stats, base, mode, j):
    w = params[base + '.conv']
    alpha = params[base + '.alpha']
    beta = params[base + '.beta']
    delta = Variable(stats[size2name(w.size())])
    w = beta * F.normalize(w.view(w.size(0), -1)).view_as(w) + alpha * delta
    o = F.conv2d(ncrelu(o), w, stride=1, padding=1)
    o = batch_norm(o, params, stats, base + '.bn', mode)
    return o
def forward(self, input_enc, input_attW_enc, input_dec, lengths_enc,
            hidden_att=None, hidden_dec1=None, hidden_dec2=None):
    N = input_dec.size(0)

    out_att = self.prenet(input_dec).unsqueeze(1)                            # N x O_dec -> N x 1 x H
    out_att, hidden_att = self.gru_att(out_att, hidden_att)                  # N x 1 x 2H
    in_attW_dec = self.linear_dec(out_att.squeeze(1)).unsqueeze(1).expand_as(input_enc)
    in_attW_dec = rnn.pack_padded_sequence(in_attW_dec, lengths_enc, True)   # N*T_enc x 2H

    self.attn_weights = torch.add(input_attW_enc, in_attW_dec.data).tanh()   # N*T_enc x 2H
    self.attn_weights = self.attn(self.attn_weights).exp()                   # N*T_enc x 1
    self.attn_weights = rnn.PackedSequence(self.attn_weights, in_attW_dec.batch_sizes)
    self.attn_weights, _ = rnn.pad_packed_sequence(self.attn_weights, True)
    self.attn_weights = F.normalize(self.attn_weights, 1, 1)                 # N x T_enc x 1

    attn_applied = torch.bmm(self.attn_weights.transpose(1, 2), input_enc)   # N x 1 x 2H

    out_dec = torch.cat((attn_applied, out_att), 2)                          # N x 1 x 4H
    residual = self.short_cut(out_dec.squeeze(1)).unsqueeze(1)               # N x 1 x 2H

    out_dec, hidden_dec1 = self.gru_dec1(out_dec, hidden_dec1)
    residual = residual + out_dec

    out_dec, hidden_dec2 = self.gru_dec2(residual, hidden_dec2)
    residual = residual + out_dec

    output = self.out(residual.squeeze(1)).view(N, self.r_factor, -1)
    return output, hidden_att, hidden_dec1, hidden_dec2
def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.pool3(x)
    x = self.inception4a(x)
    x = self.inception4b(x)
    x = self.inception5a(x)
    x = self.inception5b(x)
    x = self.inception6a(x)
    x = self.inception6b(x)

    if self.cut_at_pooling:
        return x

    x = self.avgpool(x)
    x = x.view(x.size(0), -1)

    if self.has_embedding:
        x = self.feat(x)
        x = self.feat_bn(x)
    if self.norm:
        x = F.normalize(x)
    elif self.has_embedding:
        x = F.relu(x)
    if self.dropout > 0:
        x = self.drop(x)
    if self.num_classes > 0:
        x = self.classifier(x)
    return x
def forward(self, input):
    # Angle between the L2-normalized input and each normalized weight vector, plus the input's L2 norm.
    ang = self._backend.Linear.apply(F.normalize(input), F.normalize(self.weight))
    ang = ang.clamp(-1, 1).acos()
    xnorm = input.norm(p=2, dim=1)
    return ang, xnorm
def test_normalize(self):
    inputs = Variable(torch.randn(1, 3, 4, 4), requires_grad=True)
    self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,)))
    self.assertTrue(gradcheck(lambda x: F.normalize(x, p=2, dim=-2), (inputs,)))
def forward(self, x):
    if isinstance(x, Variable):
        return F.normalize(x, self.p, self.dim, eps=1e-10)
    elif isinstance(x, (tuple, list)):
        return my_data_parallel(self, x)
    else:
        raise RuntimeError('unknown input type')
def at(x):
    # Attention map: mean of squared activations over channels, flattened and L2-normalized per sample.
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))
def build(self):
    self.keys = F.normalize(random_uniform((self.memory_size, self.key_dim), -0.001, 0.001, cuda=True), dim=1)
    self.keys_var = ag.Variable(self.keys, requires_grad=False)
    self.values = torch.zeros(self.memory_size, 1).long().cuda()
    self.age = torch.zeros(self.memory_size, 1).cuda()
def update(self, query, y, y_hat, y_hat_indices):
    batch_size, dims = query.size()

    # 1) Untouched: increment memory age by 1
    self.age += 1

    # Divide batch by correctness
    result = torch.squeeze(torch.eq(y_hat, torch.unsqueeze(y.data, dim=1))).float()
    incorrect_examples = torch.squeeze(torch.nonzero(1 - result))
    correct_examples = torch.squeeze(torch.nonzero(result))

    incorrect = len(incorrect_examples.size()) > 0
    correct = len(correct_examples.size()) > 0

    # 2) Correct: if V[n1] == v
    #    Update key K[n1] <- normalize(q + K[n1]), reset age A[n1] <- 0
    if correct:
        correct_indices = y_hat_indices[correct_examples]
        correct_keys = self.keys[correct_indices]
        correct_query = query.data[correct_examples]

        new_correct_keys = F.normalize(correct_keys + correct_query, dim=1)
        self.keys[correct_indices] = new_correct_keys
        self.age[correct_indices] = 0

    # 3) Incorrect: if V[n1] != v
    #    Select the item with the oldest age, add a random offset - n' = argmax_i(A[i]) + r_i
    #    K[n'] <- q, V[n'] <- v, A[n'] <- 0
    if incorrect:
        incorrect_size = incorrect_examples.size()[0]
        incorrect_query = query.data[incorrect_examples]
        incorrect_values = y.data[incorrect_examples]

        age_with_noise = self.age + random_uniform((self.memory_size, 1), -self.age_noise, self.age_noise, cuda=True)
        topk_values, topk_indices = torch.topk(age_with_noise, incorrect_size, dim=0)
        oldest_indices = torch.squeeze(topk_indices)

        self.keys[oldest_indices] = incorrect_query
        self.values[oldest_indices] = incorrect_values
        self.age[oldest_indices] = 0
def query(self, x, y, predict=False):
    """
    Compute the nearest neighbor of the input queries.

    Arguments:
        x: A normalized matrix of queries of size (batch_size x key_dim)
        y: A matrix of correct labels (batch_size x 1)
    Returns:
        y_hat: A (batch_size x 1) matrix
            - the nearest neighbor to the query in memory
        softmax_score: A (batch_size x 1) matrix
            - a normalized score measuring the similarity between query and nearest neighbor
        loss: average loss for the memory module
    """
    batch_size, dims = x.size()
    query = F.normalize(self.query_proj(x), dim=1)
    #query = F.normalize(torch.matmul(x, self.query_proj), dim=1)

    # Find the k-nearest neighbors of the query
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # softmax of cosine similarities - embedding
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

    # retrieve memory values - prediction
    topk_indices = topk_indices_var.detach().data
    y_hat_indices = topk_indices[:, 0]
    y_hat = self.values[y_hat_indices]

    loss = None
    if not predict:
        # Loss function
        # topk_indices = (batch_size x topk)
        # topk_values  = (batch_size x topk x value_size)

        # collect the memory values corresponding to the topk scores
        batch_size, topk_size = topk_indices.size()
        flat_topk = flatten(topk_indices)
        flat_topk_values = self.values[topk_indices]
        topk_values = flat_topk_values.resize_(batch_size, topk_size)

        correct_mask = torch.eq(topk_values, torch.unsqueeze(y.data, dim=1)).float()
        correct_mask_var = ag.Variable(correct_mask, requires_grad=False)

        pos_score, pos_idx = torch.topk(torch.mul(cosine_similarity, correct_mask_var), 1, dim=1)
        neg_score, neg_idx = torch.topk(torch.mul(cosine_similarity, 1 - correct_mask_var), 1, dim=1)

        # zero-out correct scores if there are no correct values in the topk values
        mask = 1.0 - torch.eq(torch.sum(correct_mask_var, dim=1), 0.0).float()
        pos_score = torch.mul(pos_score, torch.unsqueeze(mask, dim=1))

        #print(pos_score, neg_score)
        loss = MemoryLoss(pos_score, neg_score, self.margin)

    # Update memory
    self.update(query, y, y_hat, y_hat_indices)

    return y_hat, softmax_score, loss
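Taken together, build/predict/update/query above implement a key-value memory in which F.normalize keeps queries and keys on the unit sphere, lookup is a cosine-similarity top-k, and correctly matched keys are refreshed by renormalizing the sum of key and query. A minimal, self-contained sketch of that lookup-and-update cycle follows; the tensor names, shapes, and values are illustrative assumptions, not the module's actual API.

import torch
import torch.nn.functional as F

memory_size, key_dim = 8, 4
keys = F.normalize(torch.randn(memory_size, key_dim), dim=1)   # unit-norm memory keys
values = torch.randint(0, 3, (memory_size, 1))                 # stored labels (hypothetical)

# Lookup: with unit vectors, cosine similarity reduces to a plain matrix product.
query = F.normalize(torch.randn(2, key_dim), dim=1)
scores = query @ keys.t()                                      # (batch x memory_size)
cosine_similarity, topk_idx = torch.topk(scores, k=3, dim=1)
y_hat = values[topk_idx[:, 0]]                                 # nearest neighbor's label

# Update for a "correct" slot: fold the query into its key and renormalize,
# so the key stays on the unit sphere (mirrors new_correct_keys in update()).
slot = topk_idx[0, 0]
keys[slot] = F.normalize(keys[slot] + query[0], dim=0)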