utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:AssociativeRetrieval 作者: jxwufan 项目源码 文件源码
def read(self, shuffle=True, num_epochs=None):
    with tf.name_scope('input'):
      reader = tf.TFRecordReader()
      filename_queue = tf.train.string_input_producer([self.filename], num_epochs=num_epochs)
      _, serialized_input = reader.read(filename_queue)
      inputs = tf.parse_single_example(serialized_input,
                                       features={
                                       'inputs_seq': tf.FixedLenFeature([self.seq_len * 2 + 3], tf.int64),
                                       'output': tf.FixedLenFeature([1], tf.int64)
                                       })
      inputs_seq = inputs['inputs_seq']
      output = inputs['output']
      min_after_dequeue = 100
      if shuffle:
        inputs_seqs, outputs = tf.train.shuffle_batch([inputs_seq, output], batch_size=self.batch_size, num_threads=2, capacity=min_after_dequeue + 3 * self.batch_size, min_after_dequeue=min_after_dequeue)
      else:
        inputs_seqs, outputs = tf.train.batch([inputs_seq, output], batch_size=self.batch_size)
      return inputs_seqs, outputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号