data_reader.py 文件源码

python
阅读 53 收藏 0 点赞 0 评论 0

项目:Quantum_machine_learning 作者: kchng 项目源码 文件源码
def next_dose(self, batch_size=50):
    """Return the next dose (mini-batch) of ``(images, labels)``.

    Data files are read one at a time from ``self.full_file_path % index``
    for ``index`` running from ``self.start_file_index`` to
    ``self.end_file_index`` (inclusive).  In every file the last column is
    the integer label and the remaining columns are the image features.
    When ``self.convert_to_one_hot`` is truthy the labels are expanded to a
    two-column one-hot array; otherwise the raw ``(rows, 1)`` label column
    is kept.

    When the reader is positioned at the very start (start file, offset 0)
    ``batch_size`` is stored in ``self.batch_size``; while it is not a
    positive divisor of ``self.nrows`` the user is prompted for a new value.
    In every other position ``batch_size`` is ignored and
    ``self.batch_size`` is reused.

    Every file opened after the initial one is row-shuffled with
    ``self.shuffle_index_dose``.  After the last file is exhausted,
    ``self._epochs_completed`` is incremented and reading restarts at
    ``self.start_file_index``.

    Args:
        batch_size: requested dose size (see above for when it is used).

    Returns:
        Tuple ``(images, labels)`` holding rows ``[start, end)`` of the
        currently loaded file.

    Raises:
        ValueError: if ``self.batch_size`` exceeds ``self.nrows`` when a new
            file is opened.
    """

    def to_one_hot(label):
        # Two-column one-hot encoding (assumes labels are 0/1); an index
        # outside the two columns raises IndexError.
        flat = np.asarray(label).reshape(-1)
        one_hot = np.zeros((len(flat), 2))
        one_hot[np.arange(len(flat)), flat] = 1
        return one_hot

    def load_file(file_index):
        # Read one data file and split it into (images, labels).
        data = np.genfromtxt(self.full_file_path % (file_index), dtype=int,
                             skip_header=0, skip_footer=0)
        # A single-row file comes back 1-D; keep it 2-D so the column
        # slicing below still works.
        data = np.atleast_2d(data)
        images = data[:, :-1].astype('int')
        labels = data[:, -1:].astype('int')
        if self.convert_to_one_hot:
            labels = to_one_hot(labels)
        return images, labels

    start = self._index_in_datafile
    if self._file_index == self.start_file_index and start == 0:
        self.batch_size = batch_size
        # Every dose must come from a single file, so the dose size has to
        # be a positive divisor of the number of rows per file.
        while self.batch_size <= 0 or self.nrows % self.batch_size != 0:
            print('Warning! Number of data per file/ dose size must be an integer.')
            print('number of data per file: %d' % self.nrows)
            print('dose size: %d' % self.batch_size)
            # NOTE(review): under Python 2, input() evaluates the typed text.
            self.batch_size = int(input('Input new dose size: '))
        print('dose size : %d' % self.batch_size)
        print('number of data: %d' % self._ndata)
        # Read in one file at a time.
        self._images, self._labels = load_file(self._file_index)

    self._index_in_datafile += self.batch_size
    if self._index_in_datafile > self.nrows:
        # Current file exhausted: advance to the next one.  The epoch wrap
        # must happen BEFORE loading, otherwise a file past end_file_index
        # would be read.
        self._file_index += 1
        if self._file_index > self.end_file_index:
            # Number of training epochs completed.
            self._epochs_completed += 1
            self._file_index = self.start_file_index
        if self.batch_size > self.nrows:
            raise ValueError('dose size %d exceeds number of data per file %d'
                             % (self.batch_size, self.nrows))
        start = 0
        self._index_in_datafile = self.batch_size
        self._images, self._labels = load_file(self._file_index)
        # Shuffle images and labels with the same row permutation.
        # Assumes shuffle_index_dose indexes all rows of a file - TODO confirm.
        random.shuffle(self.shuffle_index_dose)
        self._images = self._images[self.shuffle_index_dose]
        self._labels = self._labels[self.shuffle_index_dose]

    end = self._index_in_datafile
    return self._images[start:end], self._labels[start:end]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号