def inference_graph(self, data):
    """Build a chain of fully connected layers and return their concatenated outputs.

    The first layer consumes ``data``; every later layer consumes the output of
    the layer just before it. One layer is always built, plus one more for each
    value in ``range(1, self.params.num_layers)``, so an integer ``num_layers``
    below 2 still yields exactly one layer. Each layer is built with
    ``self.params.layer_size`` as its output size.

    Args:
      data: Input tensor fed to the first layer. Presumably shaped
        (batch, features) -- TODO confirm against callers.

    Returns:
      The outputs of all layers concatenated along dimension 1, created with
      the op name "flattened_nn_activations".
    """
    # Every op below (the layers and the final concat) is created under the
    # device that the assigner picks for this layer index.
    target_device = self.device_assigner.get_device(self.layer_num)
    with ops.device(target_device):
        width = self.params.layer_size
        hidden = layers.fully_connected(data, width)
        per_layer_outputs = [hidden]
        for _ in range(1, self.params.num_layers):
            # Chain the stack: each new layer is fed the previous layer's output.
            hidden = layers.fully_connected(hidden, width)
            per_layer_outputs.append(hidden)
        # NOTE(review): this is the pre-TF-1.0 argument order,
        # concat(concat_dim, values, name); TF >= 1.0 expects
        # concat(values, axis, name). Confirm against the TensorFlow version
        # in use before upgrading.
        return array_ops.concat(
            1, per_layer_outputs, name="flattened_nn_activations")
评论列表
文章目录