def process_batch(self, batch, unk_id=1):
    """Execute one update step on ``batch`` with the matching computational graph.

    The computational-graph id (``cg_id``) is inferred from the first entries
    of the batch's source/target selectors. The compiled update function(s)
    registered for that id are then applied to the batch.

    Args:
        batch: Mapping from input names to values. Must contain
            ``'src_selector'``, ``'trg_selector'`` and ``'source'``, plus every
            name listed in ``self.algorithms[cg_id].inputs``.
            ``batch['source']`` is expected to be a numpy array (it is indexed
            with ``.size``/``.shape``).
        unk_id: Token id written into ``batch['source']`` at the positions
            selected by input dropout. Defaults to ``1``, the value previously
            hard-coded here. It should match the ``<UNK>`` id in the config.

    Side effects:
        When ``self.drop_input[cg_id] > 0``, modifies ``batch['source']`` in
        place. Always sets ``self._cost`` to ``('cost_' + cg_id, cost)``.
    """
    cg_id = self.get_cg_id_from_selectors(batch['src_selector'][0],
                                          batch['trg_selector'][0])

    # Input dropout: replace a random subset of source tokens with unk_id.
    drop_rate = self.drop_input[cg_id]
    if drop_rate > 0.0:
        source = batch['source']
        num_els = source.size
        # An empty source has nothing to replace, and numpy.random.choice
        # cannot draw a non-empty sample from an empty population.
        if num_els > 0:
            # At least one token is always replaced. Clamp to num_els so a
            # rate >= 1.0 cannot request more distinct samples than exist
            # (replace=False would raise ValueError).
            num_reps = min(num_els, max(1, int(num_els * drop_rate)))
            replace_idx = numpy.random.choice(num_els, num_reps,
                                              replace=False)
            source[numpy.unravel_index(replace_idx, source.shape)] = unk_id

    ordered_batch = [batch[v.name] for v in self.algorithms[cg_id].inputs]

    # To save memory, f_update and f_grad_shared may be combined into one
    # function. In that case f_grad_shareds[cg_id] is None, and
    # f_updates[cg_id] takes the learning rate followed by the batch inputs.
    f_grad_shared = self.f_grad_shareds[cg_id]
    combined = f_grad_shared is None
    if combined:
        cost = self.f_updates[cg_id](self.learning_rate, *ordered_batch)
    else:
        cost = f_grad_shared(*ordered_batch)
    # Record the cost before the separate update step, matching the
    # original ordering.
    self._cost = ('cost_' + cg_id, cost)
    if not combined:
        self.f_updates[cg_id](self.learning_rate)
评论列表
文章目录