Python usage examples of the tf.SparseTensorValue() class

model.py — project: deeplearning, author: zxjzxj9 (22 reads, 0 favorites, 0 likes, 0 comments)
def train(self):
        """Train the model with Adam, logging summaries and saving checkpoints.

        Side effects:
            - Creates and stores a tf.Session on ``self.sess`` (left open for
              later use by callers).
            - Writes TensorBoard event files to ``./log``.
            - Saves a checkpoint every ``self.config.nsave`` epochs.
        """

        # NOTE: ``optimizer`` is actually the minimize *op* returned by
        # AdamOptimizer.minimize(), not the optimizer object itself.
        optimizer = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate,
                                           beta1=0.9, beta2=0.999).minimize(self.loss)

        self.sess = tf.Session()
        s = self.sess

        writer = tf.summary.FileWriter("./log", graph=s.graph)
        # The scalar summary must be registered before merge_all() below so
        # that it is included in the merged summary op.
        tf.summary.scalar("loss", self.loss)

        merged_summary = tf.summary.merge_all()
        cnt_total = 0  # global step counter across all epochs
        s.run(tf.global_variables_initializer())

        try:
            for epoch in range(self.config.epoch_num):
                print("In epoch %d " % epoch)
                cnt = 0  # per-epoch batch counter

                for img, label, seq_len in self.datasrc.get_iter(16, self.config.batch_size):

                    loss, _, summary = s.run(
                        [self.loss, optimizer, merged_summary],
                        feed_dict={
                            self.input: img,
                            # label is an (indices, values, dense_shape) triple
                            # — presumably produced by self.datasrc; verify.
                            self.output: tf.SparseTensorValue(*label),
                            # Every sequence is fed as split_num long, ignoring
                            # the true seq_len values — assumes inputs are
                            # padded to split_num; TODO confirm.
                            self.seq_len: [self.config.split_num] * len(seq_len),
                            self.keep_prob: 1.0,
                        })

                    writer.add_summary(summary, cnt_total)
                    sys.stdout.write("Current loss: %.3e, current batch: %d \r" % (loss, cnt))
                    # Fix: the '\r' progress line is not newline-terminated, so
                    # it never appears without an explicit flush.
                    sys.stdout.flush()
                    cnt += 1
                    cnt_total += 1

                if epoch % self.config.nsave == self.config.nsave - 1:
                    self.saver.save(s, "./log/model_epoch_%d.ckpt" % (epoch + 1))
        finally:
            # Fix: the FileWriter was never flushed/closed, so trailing
            # summaries could be lost on exit or exception.
            writer.close()
        print("")
transform.py — project: pydatalab, author: googledatalab (25 reads, 0 favorites, 0 likes, 0 comments)
def process(self, element):
    """Run the transformation graph on batched input data

    Args:
      element: list of csv strings, representing one batch input to the TF graph.

    Returns:
      dict containing the transformed data. Results are un-batched. Sparse
      tensors are converted to lists.
    """
    import apache_beam as beam
    import six
    import tensorflow as tf

    # This runs in a separate sub-process, so raising the log threshold here
    # does not affect Datalab's kernel process.
    tf.logging.set_verbosity(tf.logging.ERROR)
    try:
      stripped = [line.rstrip() for line in element]

      # One numpy array (or SparseTensorValue) per feature, each holding
      # batch_size rows. Example:
      #   Dense:  {'col1': array([[batch_1], [batch_2]])}
      #   Sparse: {'col1': tf.SparseTensorValue(indices=..., values=...)}
      # Assumes sparse indices are 2-D with the batch row in column 0 —
      # TODO confirm against the graph that produced them.
      batch_result = self._session.run(
          fetches=self._transformed_features,
          feed_dict={self._input_placeholder_tensor: stripped})

      # Un-batch: emit one feature dict per input row.
      for row in range(len(stripped)):
        features = {}
        for name, value in six.iteritems(batch_result):
          if isinstance(value, tf.SparseTensorValue):
            row_mask = value.indices[:, 0] == row
            features[name] = value.values[row_mask].tolist()
          else:
            features[name] = value[row].tolist()

        yield features

    except Exception as e:  # pylint: disable=broad-except
      yield beam.pvalue.SideOutputValue('errors',
                                        (str(e), element))
transform.py — project: pydatalab, author: googledatalab (38 reads, 0 favorites, 0 likes, 0 comments)
def process(self, element):
    """Run the transformation graph on batched input data

    Args:
      element: list of csv strings, representing one batch input to the TF graph.

    Returns:
      dict containing the transformed data. Results are un-batched. Sparse
      tensors are converted to lists.
    """
    import apache_beam as beam
    import six
    import tensorflow as tf

    # Invoked from a separate sub-process; changing verbosity here leaves
    # Datalab's kernel process untouched.
    tf.logging.set_verbosity(tf.logging.ERROR)
    try:
      cleaned = [raw.rstrip() for raw in element]

      # batch_result maps each feature name to a numpy array (dense) or a
      # tf.SparseTensorValue (sparse) with batch_size many rows, e.g.
      #   {'col1': array([[batch_1], [batch_2]])}
      #   {'col1': tf.SparseTensorValue(indices=..., values=...)}
      # Sparse handling below assumes indices[:, 0] is the batch row —
      # TODO confirm.
      batch_result = self._session.run(
          fetches=self._transformed_features,
          feed_dict={self._input_placeholder_tensor: cleaned})

      # Split the batched result back into per-row dicts.
      for idx in range(len(cleaned)):
        out = {}
        for name, value in six.iteritems(batch_result):
          if isinstance(value, tf.SparseTensorValue):
            selector = value.indices[:, 0] == idx
            out[name] = value.values[selector].tolist()
          else:
            out[name] = value[idx].tolist()

        yield out

    except Exception as e:  # pylint: disable=broad-except
      yield beam.pvalue.SideOutputValue('errors',
                                        (str(e), element))


Questions


Interview experiences


Articles

WeChat
Official account

Scan the QR code to follow the official account