def _set_variable_and_publish(self, sess, iteration_id, transaction_id,
                              group_id):
    """Freeze the current variable values, store them, and announce it.

    The variables in ``self.variables`` are converted to constants in a
    copy of ``sess``'s GraphDef. The serialized result is stored under
    ``transaction_id`` via ``self.rc.set``. A JSON ``set_variable`` message
    describing the update is then published via ``self.r.publish``.

    Args:
        sess: Session whose graph and current variable values are frozen.
        iteration_id: Iteration identifier; included in the published
            message and in the debug log line.
        transaction_id: Key under which the serialized GraphDef is stored;
            also included in the published message.
        group_id: Included verbatim in the published message.

    Returns:
        int: Size in bytes of the serialized GraphDef that was stored.
    """
    # Op names identify the variables inside the GraphDef.
    variable_names = [var.op.name for var in self.variables]

    # The variable op names are passed as ``output_node_names``, so the
    # frozen GraphDef keeps only what is needed to produce those nodes,
    # with each variable replaced by a constant holding its current value.
    graph_def = sess.graph.as_graph_def()
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, variable_names)
    payload = frozen_graph_def.SerializeToString()

    parallel_count = self.infra_info['parallel_count']

    # The payload is stored before the message is published, presumably so
    # that subscribers reacting to the message can already read it.
    # NOTE(review): self.rc / self.r look like Redis clients — confirm.
    self.rc.set(transaction_id, payload)
    message = json.dumps({
        'key': 'set_variable',
        'transaction_id': transaction_id,
        'group_id': group_id,
        'variables': variable_names,
        'worker_id': self.worker_id,
        'train_id': self.train_id,
        'iteration_id': iteration_id,
        'parallel_count': parallel_count,
    })
    # NOTE(review): ``channel`` is not defined in this method; presumably
    # a module-level constant naming the pub/sub channel — confirm.
    self.r.publish(channel=channel, message=message)

    # Lazy %-args: the string is only built when DEBUG is enabled, and a
    # tuple ``iteration_id`` can no longer break the format operation.
    log.debug('pub %s', iteration_id)
    return len(payload)
评论列表
文章目录