def _read_record(self, filename_queue):
    """Read one serialized frame sequence and crop a random sub-sequence.

    Each record read from ``filename_queue`` is decoded as raw ``uint8``
    bytes and reshaped to
    ``[self._serialized_sequence_length, height, width, depth]``.  A
    contiguous window of ``input_shape[0] + target_shape[0]`` frames is
    then sliced out at a uniformly random start frame.

    Args:
        filename_queue: Queue of file names passed to
            ``tf.FixedLengthRecordReader.read``.

    Returns:
        A ``FrameSeqRecord`` with the attributes ``height``, ``width``,
        ``depth`` (taken from ``self._data_img_size``), ``key`` (the key
        returned by the reader) and ``data`` (the sliced tensor of shape
        ``[total_seq_length, height, width, depth]``).

    Raises:
        ValueError: If the requested window is longer than the serialized
            sequence, so no valid slice exists.
    """
    class FrameSeqRecord(object):
        """Plain attribute container for the values returned to the caller."""

    record = FrameSeqRecord()
    record.height = self._data_img_size[0]
    record.width = self._data_img_size[1]
    record.depth = self._data_img_size[2]

    total_seq_length = self.input_shape[0] + self.target_shape[0]
    serialized_length = self._serialized_sequence_length

    # Fail early with a clear message instead of an opaque TensorFlow
    # run-time error from the random op or the slice below.
    if total_seq_length > serialized_length:
        raise ValueError(
            'Requested sequence length %d exceeds the serialized sequence '
            'length %d.' % (total_seq_length, serialized_length))

    frame_bytes = record.height * record.width * record.depth
    # One reader record is the complete serialized sequence of a file.
    total_file_bytes = frame_bytes * serialized_length

    # NOTE(review): the extent of this name scope is reconstructed; the
    # indentation of the original source was lost.
    with tf.name_scope('read_record'):
        reader = tf.FixedLengthRecordReader(total_file_bytes)
        record.key, value = reader.read(filename_queue)

        decoded_record_bytes = tf.decode_raw(value, tf.uint8)
        decoded_record_bytes = tf.reshape(
            decoded_record_bytes,
            [serialized_length, record.height, record.width, record.depth])

        # `maxval` is exclusive, so add 1: the last valid start index is
        # serialized_length - total_seq_length.  Without the +1 the final
        # window is never drawn, and a file holding exactly
        # total_seq_length frames yields an empty range (minval == maxval).
        rnd_start_index = tf.random_uniform(
            [1], 0, serialized_length - total_seq_length + 1,
            dtype=tf.int32)

        # Build the begin vector [start, 0, 0, 0] for tf.slice.
        seq_start_offset = tf.SparseTensor(
            indices=[[0]], values=rnd_start_index, dense_shape=[4])
        sequence_start = tf.sparse_tensor_to_dense(seq_start_offset)

        # Take a random contiguous run of frames (input + target).
        record.data = tf.slice(
            decoded_record_bytes, sequence_start,
            [total_seq_length, record.height, record.width, record.depth])

    return record
评论列表
文章目录