def fwd_split(self, input, batch, depth,
              random_split=False, mode='train', epoch=0):
    """Run the split phase over `depth` scales, binarily partitioning points.

    At each scale the split network produces per-point Bernoulli
    probabilities; a binary sample per point is drawn (uniformly at random
    when `random_split`, otherwise from the probabilities) and folded into
    the integer embedding `embed` as one more binary digit.

    Args:
        input: tensor whose first feature channel marks padding — entries
            with input[:, :, 0] < 0 are treated as invalid (masked out).
        batch: unused here; kept for interface compatibility with callers.
        depth: number of split scales to apply.
        random_split: if True, ignore the learned probabilities and split
            each point by a fair coin flip.
        mode, epoch: unused here; kept for interface compatibility.

    Returns:
        (variance_sum, Phis, Bs, Inputs_N, embed, Log_Probs) where `embed`
        encodes each point's split path as an integer (padding pushed to
        +infinity so argsort ignores it) and Log_Probs are the action
        log-probabilities needed by the policy gradient.
    """
    n_points = self.split.n
    variance_sum = 0.0
    # Integer code of the split path taken by each point, built digit by digit.
    embed = Variable(torch.zeros(self.batch_size, n_points)).type(dtype)
    # 1.0 where the point is valid, 0.0 where it is padding.
    mask = (input[:, :, 0] >= 0).type(dtype).squeeze()
    Phis, Bs, Inputs_N, Samples = [], [], [], []
    for scale in xrange(depth):
        logits, probs, input_n, Phi = self.split(embed, input,
                                                 mask, scale=scale)
        # One uniform draw per point; consumed by either sampling mode,
        # so the RNG call sequence is identical in both branches.
        noise = (Variable(torch.zeros(self.batch_size, n_points))
                 .type(dtype))
        init.uniform(noise)
        if random_split:
            split_bits = (noise > 0.5).type(dtype)
        else:
            split_bits = (probs > noise).type(dtype)
        # Shift in the new binary digit.
        embed = 2 * embed + split_bits
        Samples.append(split_bits)
        Phis.append(Phi)
        Bs.append(probs)
        Inputs_N.append(input_n)
        # Accumulate the variance of the Bernoulli probabilities.
        variance_sum += self.compute_variance(probs, mask)
    # Log-probabilities of the sampled binary actions (policy gradient).
    Log_Probs = self.log_probabilities(Bs, Samples, mask, depth)
    # Send padded entries to +infinity so they sort last under argsort.
    infty = 1e6
    embed = embed * mask + (1 - mask) * infty
    return variance_sum, Phis, Bs, Inputs_N, embed, Log_Probs
###########################################################################
# Merge Phase #
###########################################################################