def _calculate_average_and_put(self, group_id, item, m):
    """Average same-named tensors across stored graphs and store the result.

    For every name ``v`` in ``item['variables']``, the tensor ``<key>/<v>:0``
    is looked up for each ``key`` in ``item['keys']`` after restoring that
    key's graph, and the element-wise mean of those tensors becomes a new
    variable named ``average_<v>``.  The new variables are then frozen into
    constants with ``graph_util.convert_variables_to_constants`` and the
    serialized GraphDef is written to ``self.rc`` under ``group_id``.

    Args:
        group_id: Key passed to ``self.rc.set`` for the serialized result.
        item: Mapping with ``'keys'`` (keys in ``self.rc`` holding the
            serialized source graphs) and ``'variables'`` (tensor names to
            average across those graphs).
        m: Forwarded to ``SimpleMeasurement`` for the ``'cal_and_put'``,
            ``'init'`` and ``'cal'`` timings.

    Raises:
        KeyError: if ``self.rc.get`` returns ``None`` for one of the keys.
        ValueError: if variables are requested but ``item['keys']`` is empty.
    """
    # Materialize both so they can be iterated more than once below.
    keys = list(item['keys'])
    variables = list(item['variables'])

    tf.reset_default_graph()
    sess = tf.Session()
    m_cal_and_put = SimpleMeasurement('cal_and_put', m)
    try:
        m_init = SimpleMeasurement('init', m)
        # No variables have been created in this function since the reset
        # above, so this initializes nothing; it is kept so the 'init'
        # measurement still times the session's first run.
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        m_init.end_measure()

        if variables and not keys:
            raise ValueError(
                'cannot average %r: item has no keys' % (variables,))

        graph = sess.graph

        # Restore every source graph exactly once.  The original code
        # re-fetched and re-restored each graph once per variable.
        # NOTE(review): assumes util.restore_graph(key, raw) imports the graph
        # into the default graph under name scope `key`, as implied by the
        # '<key>/<v>:0' lookup below -- confirm.
        for key in keys:
            raw = self.rc.get(key)
            if raw is None:
                raise KeyError('no stored graph for key %r' % (key,))
            util.restore_graph(key, raw)

        new_vars = []
        for v in variables:
            ts = [graph.get_tensor_by_name('%s/%s:0' % (key, v))
                  for key in keys]
            m_cal = SimpleMeasurement('cal', m)
            # tf.add_n sums a list of same-shape tensors directly.  In newer
            # TF 1.x releases tf.foldl treats a Python list as a nested
            # structure, and so does not sum the list elements.
            avg = tf.add_n(ts) / len(ts)
            new_var = tf.Variable(avg, name='average_%s' % v)
            sess.run(new_var.initializer)
            # Use the actual node name in case TF uniquified the requested one.
            new_vars.append(new_var.op.name)
            m_cal.end_measure()

        constants = graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), new_vars)
        self.rc.set(group_id, constants.SerializeToString())
    finally:
        # Always release the session, even if a lookup or restore fails.
        sess.close()
    m_cal_and_put.end_measure()
评论列表
文章目录