def train(self):
    """Train the model, logging the loss to TensorBoard and checkpointing periodically.

    Creates a new ``tf.Session`` (stored on ``self.sess`` and left open, so it
    outlives this call), initializes all global variables, then runs
    ``self.config.epoch_num`` epochs over the batches produced by
    ``self.datasrc.get_iter(16, self.config.batch_size)``. The loss of every
    step is written as a scalar summary to ``./log``. After every
    ``self.config.nsave``-th epoch a checkpoint is saved to
    ``./log/model_epoch_<epoch + 1>.ckpt`` via ``self.saver``.
    """
    # minimize() returns the parameter-update op, not the optimizer object.
    train_op = tf.train.AdamOptimizer(
        learning_rate=self.config.learning_rate, beta1=0.9, beta2=0.999
    ).minimize(self.loss)

    self.sess = tf.Session()
    s = self.sess
    writer = tf.summary.FileWriter("./log", graph=s.graph)
    tf.summary.scalar("loss", self.loss)
    merged_summary = tf.summary.merge_all()
    s.run(tf.global_variables_initializer())

    cnt_total = 0  # step counter across all epochs; used as the summary step
    try:
        for epoch in range(self.config.epoch_num):
            print("In epoch %d " % epoch)
            # NOTE(review): the meaning of the literal 16 passed to get_iter is
            # not visible here — confirm against the data source.
            batches = self.datasrc.get_iter(16, self.config.batch_size)
            for cnt, (img, label, seq_len) in enumerate(batches):
                loss, _, summary = s.run(
                    [self.loss, train_op, merged_summary],
                    feed_dict={
                        self.input: img,
                        # label is unpacked as SparseTensorValue's positional
                        # args — presumably (indices, values, dense_shape); confirm.
                        self.output: tf.SparseTensorValue(*label),
                        # Every sequence gets the fixed length split_num; only
                        # the batch size is taken from seq_len.
                        self.seq_len: [self.config.split_num] * len(seq_len),
                        # NOTE(review): if keep_prob feeds a dropout layer, 1.0
                        # disables dropout during training — confirm intended.
                        self.keep_prob: 1.0,
                    })
                writer.add_summary(summary, cnt_total)
                # "\r" rewrites the same console line; flush so it appears now.
                sys.stdout.write("Current loss: %.3e, current batch: %d \r" % (loss, cnt))
                sys.stdout.flush()
                cnt_total += 1
            if epoch % self.config.nsave == self.config.nsave - 1:
                self.saver.save(s, "./log/model_epoch_%d.ckpt" % (epoch + 1))
            print("")  # terminate the "\r" progress line for this epoch
    finally:
        # Flush pending summaries to disk and release the event file,
        # even if training is interrupted.
        writer.close()
# Example source code: usages of the Python class SparseTensorValue()
def process(self, element):
    """Run the transformation graph on batched input data and un-batch the result.

    Args:
      element: list of csv strings, representing one batch input to the TF graph.

    Yields:
      One dict per input line, mapping each transformed feature name to a plain
      Python list (produced with ``tolist()``). For a sparse feature the list
      holds the values whose first (batch) index equals that line's position.
      If running or un-batching the graph raises, no per-line dicts are yielded
      for the batch. Instead a single tagged output
      ``('errors', (str(e), element))`` is emitted.
    """
    import apache_beam as beam
    import six
    import tensorflow as tf
    # This function is invoked by a separate sub-process so setting the logging level
    # does not affect Datalab's kernel process.
    tf.logging.set_verbosity(tf.logging.ERROR)

    # Newer Apache Beam releases name this class TaggedOutput (SideOutputValue
    # no longer exists there); fall back to the old name on older releases.
    tagged_output = getattr(beam.pvalue, 'TaggedOutput', None)
    if tagged_output is None:
        tagged_output = beam.pvalue.SideOutputValue

    try:
        clean_element = [line.rstrip() for line in element]
        # batch_result mirrors the structure of self._transformed_features: a
        # dict from feature name to either a numpy array with one row per
        # input line (dense) or a tf.SparseTensorValue (sparse).
        batch_result = self._session.run(
            fetches=self._transformed_features,
            feed_dict={self._input_placeholder_tensor: clean_element})
        # ex batch_result.
        # Dense tensor:  {'col1': array([[batch_1], [batch_2]])}
        # Sparse tensor: {'col1': tf.SparseTensorValue(
        #                   indices=array([[batch_1, 0], [batch_1, 1], ...,
        #                                  [batch_2, 0], [batch_2, 1], ...]),
        #                   values=array([value, value, value, ...]))}

        # Un-batch fully before yielding anything, so a failure part-way
        # through cannot emit some rows AND an error record for the same batch.
        rows = []
        for i in range(len(clean_element)):
            transformed_features = {}
            for name, value in six.iteritems(batch_result):
                if isinstance(value, tf.SparseTensorValue):
                    # Select the entries whose batch (first) index is row i.
                    batch_i_indices = value.indices[:, 0] == i
                    transformed_features[name] = value.values[batch_i_indices].tolist()
                else:
                    transformed_features[name] = value[i].tolist()
            rows.append(transformed_features)
    except Exception as e:  # pylint: disable=broad-except
        yield tagged_output('errors', (str(e), element))
        return

    for row in rows:
        yield row
def process(self, element):
    """Run the transformation graph on batched input data and un-batch the result.

    Args:
      element: list of csv strings, representing one batch input to the TF graph.

    Yields:
      One dict per input line, mapping each transformed feature name to a plain
      Python list (produced with ``tolist()``). For a sparse feature the list
      holds the values whose first (batch) index equals that line's position.
      If running or un-batching the graph raises, no per-line dicts are yielded
      for the batch. Instead a single tagged output
      ``('errors', (str(e), element))`` is emitted.
    """
    import apache_beam as beam
    import six
    import tensorflow as tf
    # This function is invoked by a separate sub-process so setting the logging level
    # does not affect Datalab's kernel process.
    tf.logging.set_verbosity(tf.logging.ERROR)

    # Newer Apache Beam releases name this class TaggedOutput (SideOutputValue
    # no longer exists there); fall back to the old name on older releases.
    tagged_output = getattr(beam.pvalue, 'TaggedOutput', None)
    if tagged_output is None:
        tagged_output = beam.pvalue.SideOutputValue

    try:
        clean_element = [line.rstrip() for line in element]
        # batch_result mirrors the structure of self._transformed_features: a
        # dict from feature name to either a numpy array with one row per
        # input line (dense) or a tf.SparseTensorValue (sparse).
        batch_result = self._session.run(
            fetches=self._transformed_features,
            feed_dict={self._input_placeholder_tensor: clean_element})
        # ex batch_result.
        # Dense tensor:  {'col1': array([[batch_1], [batch_2]])}
        # Sparse tensor: {'col1': tf.SparseTensorValue(
        #                   indices=array([[batch_1, 0], [batch_1, 1], ...,
        #                                  [batch_2, 0], [batch_2, 1], ...]),
        #                   values=array([value, value, value, ...]))}

        # Un-batch fully before yielding anything, so a failure part-way
        # through cannot emit some rows AND an error record for the same batch.
        rows = []
        for i in range(len(clean_element)):
            transformed_features = {}
            for name, value in six.iteritems(batch_result):
                if isinstance(value, tf.SparseTensorValue):
                    # Select the entries whose batch (first) index is row i.
                    batch_i_indices = value.indices[:, 0] == i
                    transformed_features[name] = value.values[batch_i_indices].tolist()
                else:
                    transformed_features[name] = value[i].tolist()
            rows.append(transformed_features)
    except Exception as e:  # pylint: disable=broad-except
        yield tagged_output('errors', (str(e), element))
        return

    for row in rows:
        yield row