Example source code for Python's fromstring()
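
The snippets below are collected from a range of open-source projects. As a quick orientation (a generic sketch, not taken from any of the projects below), numpy.fromstring is used in two ways: to reinterpret raw bytes as a typed array, and to parse delimited text. For the binary case NumPy has deprecated it in favour of numpy.frombuffer:

import numpy as np

# Binary mode: reinterpret a bytes buffer as a typed array.
raw = np.arange(4, dtype=np.uint32).tobytes()
a = np.fromstring(raw, dtype=np.uint32)   # deprecated alias, still common in the code below
b = np.frombuffer(raw, dtype=np.uint32)   # the replacement recommended by NumPy
assert (a == b).all()

# Text mode: parse a delimited string of numbers (no frombuffer equivalent).
c = np.fromstring("1.0, 2.0, 3.0", dtype=float, sep=',')   # -> array([1., 2., 3.])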

Source: images2gif.py (project: RasterFairy, author: Quasimondo)
def __init__(self, image, samplefac=10, colors=256):

        # Check Numpy
        if np is None:
            raise RuntimeError("Need Numpy for the NeuQuant algorithm.")

        # Check image
        if image.size[0] * image.size[1] < NeuQuant.MAXPRIME:
            raise IOError("Image is too small")
        if image.mode != "RGBA":
            raise IOError("Image mode should be RGBA.")

        # Initialize
        self.setconstants(samplefac, colors)
        self.pixels = np.fromstring(image.tostring(), np.uint32)
        self.setUpArrays()

        self.learn()
        self.fix()
        self.inxbuild()
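
A note on the snippet above: Image.tostring() has been removed from recent Pillow releases (tobytes() replaces it) and np.fromstring() is deprecated for binary input, so the same pixel grab would today look roughly like this minimal sketch (throwaway image, not code from RasterFairy):

import numpy as np
from PIL import Image

img = Image.new("RGBA", (64, 64))                        # any RGBA image
pixels = np.frombuffer(img.tobytes(), dtype=np.uint32)   # one packed uint32 per RGBA pixel
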
Source: lsun_bedroom_line2color.py (project: chainer-cyclegan, author: Aixile)
def get_example(self, i):
        id = self.all_keys[i]
        img = None
        val = self.db.get(id.encode())

        img = cv2.imdecode(np.fromstring(val, dtype=np.uint8), 1)
        img = self.do_augmentation(img)

        img_color = img
        img_color = self.preprocess_image(img_color)

        img_line = XDoG(img)
        img_line = cv2.cvtColor(img_line, cv2.COLOR_GRAY2RGB)
        #if img_line.ndim == 2:
        #    img_line = img_line[:, :, np.newaxis]
        img_line = self.preprocess_image(img_line)

        return img_line, img_color
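
The pattern used here, encoded image bytes (read from LMDB) flattened into a uint8 array and handed to OpenCV for decoding, works with any encoded image buffer. A minimal file-based sketch (file name illustrative):

import numpy as np
import cv2

with open("example.jpg", "rb") as fh:
    raw = fh.read()
buf = np.fromstring(raw, dtype=np.uint8)   # np.frombuffer(raw, dtype=np.uint8) also works
img = cv2.imdecode(buf, 1)                 # 1 == cv2.IMREAD_COLOR, gives an HxWx3 BGR array
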
Source: MetaArray.py (project: NeoAnalysis, author: neoanalysis)
def _readData1(self, fd, meta, mmap=False, **kwds):
        ## Read array data from the file descriptor for MetaArray v1 files
        ## read in axis values for any axis that specifies a length
        frameSize = 1
        for ax in meta['info']:
            if 'values_len' in ax:
                ax['values'] = np.fromstring(fd.read(ax['values_len']), dtype=ax['values_type'])
                frameSize *= ax['values_len']
                del ax['values_len']
                del ax['values_type']
        self._info = meta['info']
        if not kwds.get("readAllData", True):
            return
        ## the remaining data is the actual array
        if mmap:
            subarr = np.memmap(fd, dtype=meta['type'], mode='r', shape=meta['shape'])
        else:
            subarr = np.fromstring(fd.read(), dtype=meta['type'])
            subarr.shape = meta['shape']
        self._data = subarr
Source: pyxis.py (project: ml-pyxis, author: vicolab)
def decode_data(obj):
    """Decode a serialised data object.

    Parameters
    ----------
    obj : Python dictionary
        A dictionary describing a serialised data object.
    """
    try:
        if TYPES['str'] == obj[b'type']:
            return decode_str(obj[b'data'])
        elif TYPES['ndarray'] == obj[b'type']:
            return np.fromstring(obj[b'data'], dtype=np.dtype(
                obj[b'dtype'])).reshape(obj[b'shape'])
        else:
            # Assume the user knows what they are doing
            return obj
    except KeyError:
        # Assume the user knows what they are doing
        return obj
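
For the ndarray branch, the serialised object is assumed to carry the raw bytes together with the dtype string and shape needed to rebuild the array. A rough round-trip sketch of that convention (plain string keys for readability; the actual pyxis objects use the byte keys shown above):

import numpy as np

arr = np.arange(12, dtype=np.float32).reshape(3, 4)
obj = {'type': 'ndarray',
       'data': arr.tobytes(),
       'dtype': str(arr.dtype),
       'shape': arr.shape}

restored = np.fromstring(obj['data'],
                         dtype=np.dtype(obj['dtype'])).reshape(obj['shape'])
assert (restored == arr).all()
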
Source: proposal.py (project: mx-rfcn, author: giorking)
def __init__(self, feat_stride, scales, ratios, is_train=False, output_score=False):
        super(ProposalOperator, self).__init__()
        self._feat_stride = float(feat_stride)
        self._scales = np.fromstring(scales[1:-1], dtype=float, sep=',')
        self._ratios = np.fromstring(ratios[1:-1], dtype=float, sep=',').tolist()
        self._anchors = generate_anchors(base_size=self._feat_stride, scales=self._scales, ratios=self._ratios)
        self._num_anchors = self._anchors.shape[0]
        self._output_score = output_score

        if DEBUG:
            print 'feat_stride: {}'.format(self._feat_stride)
            print 'anchors:'
            print self._anchors

        if is_train:
            self.cfg_key = 'TRAIN'
        else:
            self.cfg_key = 'TEST'
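
In this operator, scales and ratios arrive as bracketed strings such as "(8, 16, 32)", and the [1:-1] slice strips the parentheses so np.fromstring can split on commas:

import numpy as np

scales = "(8, 16, 32)"
np.fromstring(scales[1:-1], dtype=float, sep=',')   # -> array([ 8., 16., 32.])
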
Source: utils.py (project: lopocs, author: Oslandia)
def read_uncompressed_patch(pcpatch_wkb, schema):
    '''
    Patch binary structure uncompressed:
    byte:         endianness (1 = NDR, 0 = XDR)
    uint32:       pcid (key to POINTCLOUD_SCHEMAS)
    uint32:       0 = no compression
    uint32:       npoints
    pointdata[]:  interpret relative to pcid
    '''
    patchbin = unhexlify(pcpatch_wkb)
    npoints = unpack("I", patchbin[9:13])[0]
    dt = schema_dtype(schema)
    patch = np.fromstring(patchbin[13:], dtype=dt)
    # debug
    # print(patch[:10])
    return patch, npoints
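
To see the byte layout from the docstring in action, one can build a tiny synthetic uncompressed patch and run it through the same slicing. A sketch with the point schema reduced to a single float64 field (in lopocs the real dtype comes from schema_dtype, and the plain "I" unpack assumes a little-endian machine, matching the NDR header):

from binascii import hexlify, unhexlify
from struct import pack, unpack
import numpy as np

pts = np.array([1.0, 2.0, 3.0])                         # three fake points, one float64 each
header = pack('<BIII', 1, 1, 0, len(pts))               # endianness=NDR, pcid=1, no compression, npoints=3
pcpatch_wkb = hexlify(header + pts.tobytes())

patchbin = unhexlify(pcpatch_wkb)
npoints = unpack("I", patchbin[9:13])[0]                # -> 3
data = np.fromstring(patchbin[13:], dtype=np.float64)   # -> array([1., 2., 3.])
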
Source: utils.py (project: lopocs, author: Oslandia)
def decompress(points, schema):
    """
    Decode patch encoded with lazperf.
    'points' is a pcpatch in wkb
    """

    # retrieve number of points in wkb pgpointcloud patch
    npoints = patch_numpoints(points)
    hexbuffer = unhexlify(points[34:])
    hexbuffer += hexa_signed_int32(npoints)

    # uncompress
    s = json.dumps(schema).replace("\\", "")
    dtype = buildNumpyDescription(json.loads(s))
    lazdata = bytes(hexbuffer)

    arr = np.fromstring(lazdata, dtype=np.uint8)
    d = Decompressor(arr, s)
    output = np.zeros(npoints * dtype.itemsize, dtype=np.uint8)
    decompressed = d.decompress(output)

    return decompressed
Source: images2gif.py (project: CycleGAN-Tensorflow-PyTorch-Simple, author: LynnHo)
def __init__(self, image, samplefac=10, colors=256):

        # Check Numpy
        if np is None:
            raise RuntimeError("Need Numpy for the NeuQuant algorithm.")

        # Check image
        if image.size[0] * image.size[1] < NeuQuant.MAXPRIME:
            raise IOError("Image is too small")
        if image.mode != "RGBA":
            raise IOError("Image mode should be RGBA.")

        # Initialize
        self.setconstants(samplefac, colors)
        self.pixels = np.fromstring(image.tostring(), np.uint32)
        self.setUpArrays()

        self.learn()
        self.fix()
        self.inxbuild()
Source: Utils.py (project: ISLES2017, author: MiguelMonteiro)
def get_original_image(tfrecords_dir, is_training_data=False):
    record = tf.python_io.tf_record_iterator(tfrecords_dir).next()
    example = tf.train.Example()
    example.ParseFromString(record)

    shape = np.fromstring(example.features.feature['shape'].bytes_list.value[0], dtype=np.int32)
    image = np.fromstring(example.features.feature['img_raw'].bytes_list.value[0], dtype=np.float32)
    image = image.reshape(shape)

    if is_training_data:
        ground_truth = np.fromstring(example.features.feature['gt_raw'].bytes_list.value[0], dtype=np.uint8)
        ground_truth = ground_truth.reshape(shape[:-1])
    else:
        ground_truth = None

    return image, ground_truth
Source: prep_wikiqa_data.py (project: answer-triggering, author: jiez-osu)
def load_bin_vec(self, fname, vocab):
        """
        Loads 300x1 word vecs from Google (Mikolov) word2vec
        """
        word_vecs = {}
        with open(fname, "rb") as f:
            header = f.readline()
            vocab_size, layer1_size = map(int, header.split())
            binary_len = np.dtype('float32').itemsize * layer1_size
            for line in xrange(vocab_size):
                word = []
                while True:
                    ch = f.read(1)
                    if ch == ' ':
                        word = ''.join(word)
                        break
                    if ch != '\n':
                        word.append(ch)
                if word in vocab:
                   word_vecs[word] = np.fromstring(f.read(binary_len), dtype='float32')
                else:
                    f.read(binary_len)
        logger.info("num words already in word2vec: " + str(len(word_vecs)))
        return word_vecs
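
The format being parsed is the classic word2vec binary layout: a text header "vocab_size dim", then for each entry the word, a single space, and dim raw float32 values. A small sketch that writes such a file for testing (purely illustrative, and written with Python 3 I/O, whereas the loader above is Python 2 code):

import numpy as np

dim = 5
vecs = {"cat": np.random.rand(dim).astype('float32'),
        "dog": np.random.rand(dim).astype('float32')}

with open("toy_w2v.bin", "wb") as f:
    f.write("{} {}\n".format(len(vecs), dim).encode())
    for word, vec in vecs.items():
        f.write(word.encode() + b' ' + vec.tobytes() + b'\n')
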
Source: vec2bin.py (project: hadan-gcloud, author: youkpan)
def vec2bin(input_path, output_path):
    input_fd  = open(input_path, "rb")
    output_fd = open(output_path, "wb")

    header = input_fd.readline()
    output_fd.write(header)

    vocab_size, vector_size = map(int, header.split())

    for line in tqdm(range(vocab_size)):
        word = []
        while True:
            ch = input_fd.read(1)
            output_fd.write(ch)
            if ch == b' ':
                word = b''.join(word).decode('utf-8')
                break
            if ch != b'\n':
                word.append(ch)
        vector = np.fromstring(input_fd.readline(), sep=' ', dtype='float32')
        output_fd.write(vector.tostring())

    input_fd.close()
    output_fd.close()
Source: models.py (project: SentEval, author: facebookresearch)
def get_glove_k(self, K):
        assert hasattr(self, 'glove_path'), 'warning : \
            you need to set_glove_path(glove_path)'
        # create word_vec with k first glove vectors
        k = 0
        word_vec = {}
        with io.open(self.glove_path) as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if k <= K:
                    word_vec[word] = np.fromstring(vec, sep=' ')
                    k += 1
                if k > K:
                    if word in ['<s>', '</s>']:
                        word_vec[word] = np.fromstring(vec, sep=' ')

                if k>K and all([w in word_vec for w in ['<s>', '</s>']]):
                    break
        return word_vec
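
Each GloVe line is the token followed by space-separated floats, so line.split(' ', 1) plus np.fromstring(..., sep=' ') is the whole parse. On a single made-up line:

import numpy as np

line = "the 0.418 0.24968 -0.41242 0.1217"
word, vec = line.split(' ', 1)
emb = np.fromstring(vec, sep=' ')   # -> array([ 0.418, 0.24968, -0.41242, 0.1217])
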
Source: utils.py (project: main_loop_tf, author: fvisin)
def fig2array(fig):
    """Convert a Matplotlib figure to a 4D numpy array

    Params
    ------
    fig:
        A matplotlib figure

    Return
    ------
        A numpy 3D array of RGB values, shape (height, width, 3)

    Modified version of: http://www.icare.univ-lille1.fr/node/1141
    """
    # draw the renderer
    fig.canvas.draw()

    # Get the RGBA buffer from the figure
    w, h = fig.canvas.get_width_height()
    buf = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8)
    buf.shape = (h, w, 3)

    return buf
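
A quick usage sketch for the helper above, assuming fig2array is in scope (the Agg backend keeps draw() working without a display; note that newer Matplotlib releases deprecate tostring_rgb in favour of buffer_rgba):

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig = plt.figure()
plt.plot([0, 1], [0, 1])
arr = fig2array(fig)          # uint8 array of shape (height, width, 3)
print(arr.shape, arr.dtype)
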
Source: io_methods.py (project: mss_pytorch, author: Js-Mim)
def _wav2array(nchannels, sampwidth, data):
        """data must be the string containing the bytes from the wav file."""
        num_samples, remainder = divmod(len(data), sampwidth * nchannels)
        if remainder > 0:
            raise ValueError('The length of data is not a multiple of '
                             'sampwidth * num_channels.')
        if sampwidth > 4:
            raise ValueError("sampwidth must not be greater than 4.")

        if sampwidth == 3:
            a = np.empty((num_samples, nchannels, 4), dtype = np.uint8)
            raw_bytes = np.fromstring(data, dtype = np.uint8)
            a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
            a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
            result = a.view('<i4').reshape(a.shape[:-1])
        else:
            # 8 bit samples are stored as unsigned ints; others as signed ints.
            dt_char = 'u' if sampwidth == 1 else 'i'
            a = np.fromstring(data, dtype='<%s%d' % (dt_char, sampwidth))
            result = a.reshape(-1, nchannels)
        return result
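
A sketch of how such a helper is typically fed, using the standard wave module and calling _wav2array as a free function (in the project it sits inside an I/O class; file name illustrative):

import wave

with wave.open("example.wav", "rb") as wf:
    nchannels = wf.getnchannels()
    sampwidth = wf.getsampwidth()
    data = wf.readframes(wf.getnframes())          # raw interleaved sample bytes

samples = _wav2array(nchannels, sampwidth, data)   # shape (num_samples, nchannels)
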
Source: _tifffile.py (project: radar, author: amoose136)
def read_array(self, dtype, count=-1, sep=""):
        """Return numpy array from file.

        Work around numpy issue #2230, "numpy.fromfile does not accept
        StringIO object" https://github.com/numpy/numpy/issues/2230.

        """
        try:
            return numpy.fromfile(self._fh, dtype, count, sep)
        except IOError:
            if count < 0:
                size = self._size
            else:
                size = count * numpy.dtype(dtype).itemsize
            data = self._fh.read(size)
            return numpy.fromstring(data, dtype, count, sep)
Source: process_sst2_data.py (project: crnn, author: ultimate010)
def load_bin_vec(fname, vocab):
    """
    Loads 300x1 word vecs from Google (Mikolov) word2vec
    """
    word_vecs = {}
    with open(fname, "rb") as f:
        header = f.readline()
        vocab_size, layer1_size = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * layer1_size
        for line in xrange(vocab_size):
            word = []
            while True:
                ch = f.read(1)
                if ch == ' ':
                    word = ''.join(word)
                    break
                if ch != '\n':
                    word.append(ch)
            if word in vocab:
               word_vecs[word] = np.fromstring(f.read(binary_len), dtype='float32')
            else:
                f.read(binary_len)
    return word_vecs
Source: speech_data.py (project: skill-voice-recognition, author: TREE-Edu)
def load_wav_file(name):
    f = wave.open(name, "rb")
    # print("loading %s"%name)
    chunk = []
    data0 = f.readframes(CHUNK)
    while data0:  # f.getnframes()
        # data=numpy.fromstring(data0, dtype='float32')
        # data = numpy.fromstring(data0, dtype='uint16')
        data = numpy.fromstring(data0, dtype='uint8')
        data = (data + 128) / 255.  # 0-1 for Better convergence
        # chunks.append(data)
        chunk.extend(data)
        data0 = f.readframes(CHUNK)
    # finally trim:
    chunk = chunk[0:CHUNK * 2]  # should be enough for now -> cut
    chunk.extend(numpy.zeros(CHUNK * 2 - len(chunk)))  # fill with padding 0's
    # print("%s loaded"%name)
    return chunk
Source: utils.py (project: self-augmented-net, author: msraig)
def pfmFromBuffer(buffer, reverse = 1):
    sStream = cStringIO.StringIO(buffer)

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = sStream.readline().rstrip()
    color = (header == 'PF')

    width, height = map(int, sStream.readline().strip().split(' '))
    scale = float(sStream.readline().rstrip())
    endian = '<' if(scale < 0) else '>'
    scale = abs(scale)


    rawdata = np.fromstring(sStream.read(), endian + 'f')
    shape = (height, width, 3) if color else (height, width)
    sStream.close()
    if(len(shape) == 3):
        return rawdata.reshape(shape).astype(np.float32)[:,:,::-1]
    else:
        return rawdata.reshape(shape).astype(np.float32)
Source: common.py (project: HyperGAN, author: 255BITS)
def sample(self, filename, save_samples):
        gan = self.gan
        generator = gan.generator.sample

        sess = gan.session
        config = gan.config
        x_v, z_v = sess.run([gan.inputs.x, gan.encoder.z])

        sample = sess.run(generator, {gan.inputs.x: x_v, gan.encoder.z: z_v})

        plt.clf()
        fig = plt.figure(figsize=(3,3))
        plt.scatter(*zip(*x_v), c='b')
        plt.scatter(*zip(*sample), c='r')
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.ylabel("z")
        fig.canvas.draw()
        data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        #plt.savefig(filename)
        self.plot(data, filename, save_samples)
        return [{'image': filename, 'label': '2d'}]
Source: IOMethods.py (project: aes_wimp, author: Js-Mim)
def _wav2array(nchannels, sampwidth, data):
        """data must be the string containing the bytes from the wav file."""
        num_samples, remainder = divmod(len(data), sampwidth * nchannels)
        if remainder > 0:
            raise ValueError('The length of data is not a multiple of '
                             'sampwidth * num_channels.')
        if sampwidth > 4:
            raise ValueError("sampwidth must not be greater than 4.")

        if sampwidth == 3:
            a = np.empty((num_samples, nchannels, 4), dtype = np.uint8)
            raw_bytes = np.fromstring(data, dtype = np.uint8)
            a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
            a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
            result = a.view('<i4').reshape(a.shape[:-1])
        else:
            # 8 bit samples are stored as unsigned ints; others as signed ints.
            dt_char = 'u' if sampwidth == 1 else 'i'
            a = np.fromstring(data, dtype='<%s%d' % (dt_char, sampwidth))
            result = a.reshape(-1, nchannels)
        return result
Source: odometry.py (project: canshi, author: hungsing92)
def load_poses(self):
        """Load ground truth poses from file."""
        print('Loading poses for sequence ' + self.sequence + '...')

        pose_file = os.path.join(self.pose_path, self.sequence + '.txt')

        # Read and parse the poses
        try:
            self.T_w_cam0 = []
            with open(pose_file, 'r') as f:
                for line in f.readlines():
                    T = np.fromstring(line, dtype=float, sep=' ')
                    T = T.reshape(3, 4)
                    T = np.vstack((T, [0, 0, 0, 1]))
                    self.T_w_cam0.append(T)
            print('done.')

        except FileNotFoundError:
            print('Ground truth poses are not available for sequence ' +
                  self.sequence + '.')
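
Each line of a KITTI pose file holds twelve space-separated floats, a flattened 3x4 [R | t] matrix, which the loop above pads into a 4x4 homogeneous transform. Parsing a single line by hand (values made up):

import numpy as np

line = "1 0 0 0.5  0 1 0 0.0  0 0 1 -0.2"
T = np.fromstring(line, dtype=float, sep=' ').reshape(3, 4)
T = np.vstack((T, [0, 0, 0, 1]))    # 4x4 homogeneous transform
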
Source: mnist_training.py (project: ngraph, author: NervanaSystems)
def loadData(src, cimg):
    gzfname, h = urlretrieve(src, './delete.me')
    try:
        with gzip.open(gzfname) as gz:
            n = struct.unpack('I', gz.read(4))
            if n[0] != 0x3080000:
                raise Exception('Invalid file: unexpected magic number.')
            n = struct.unpack('>I', gz.read(4))[0]
            if n != cimg:
                raise Exception('Invalid file: expected {0} entries.'.format(cimg))
            crow = struct.unpack('>I', gz.read(4))[0]
            ccol = struct.unpack('>I', gz.read(4))[0]
            if crow != 28 or ccol != 28:
                raise Exception('Invalid file: expected 28 rows/cols per image.')
            res = np.fromstring(gz.read(cimg * crow * ccol), dtype=np.uint8)
    finally:
        os.remove(gzfname)
    return res.reshape((cimg, crow * ccol))
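
The magic-number test can look odd at first: the IDX image magic is 0x00000803 stored big-endian, but the first struct.unpack uses the native (little-endian) 'I' format, so those four bytes read back as 0x3080000. A short demonstration (assuming a little-endian machine):

import struct

magic = struct.pack('>I', 0x00000803)   # the four magic bytes as they sit in the file
hex(struct.unpack('I', magic)[0])       # '0x3080000' with native little-endian order
hex(struct.unpack('>I', magic)[0])      # '0x803' when read back big-endian
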
Source: mnist_softmax_cntk.py (project: ai-gym, author: tuzzer)
def get_mnist_data(filename, num_samples, local_data_dir):

    gzfname = load_or_download_mnist_files(filename, num_samples, local_data_dir)

    with gzip.open(gzfname) as gz:
        n = struct.unpack('I', gz.read(4))
        # Read magic number.
        if n[0] != 0x3080000:
            raise Exception('Invalid file: unexpected magic number.')
        # Read number of entries.
        n = struct.unpack('>I', gz.read(4))[0]
        if n != num_samples:
            raise Exception('Invalid file: expected {0} entries.'.format(num_samples))
        crow = struct.unpack('>I', gz.read(4))[0]
        ccol = struct.unpack('>I', gz.read(4))[0]
        if crow != 28 or ccol != 28:
            raise Exception('Invalid file: expected 28 rows/cols per image.')
        # Read data.
        res = np.fromstring(gz.read(num_samples * crow * ccol), dtype = np.uint8)

        return res.reshape((num_samples, crow * ccol))
Source: mnist_softmax_cntk.py (project: ai-gym, author: tuzzer)
def get_mnist_labels(filename, num_samples, local_data_dir):

    gzfname = load_or_download_mnist_files(filename, num_samples, local_data_dir)

    with gzip.open(gzfname) as gz:
        n = struct.unpack('I', gz.read(4))
        # Read magic number.
        if n[0] != 0x1080000:
            raise Exception('Invalid file: unexpected magic number.')
        # Read number of entries.
        n = struct.unpack('>I', gz.read(4))
        if n[0] != num_samples:
            raise Exception('Invalid file: expected {0} rows.'.format(num_samples))
        # Read labels.
        res = np.fromstring(gz.read(num_samples), dtype = np.uint8)

        return res.reshape((num_samples, 1))
Source: tensor.py (project: PySyft, author: OpenMined)
def shape(self, as_list=True):
        """
        Returns the size of the self tensor as a FloatTensor (or as List).
        Note:
            The returned value currently is a FloatTensor because it leverages
            the messaging mechanism with Unity.
        Parameters
        ----------
        as_list : bool
            Value returned as list if true; else as tensor
        Returns
        -------
        FloatTensor
            Output tensor
        (or)
        Iterable
            Output list
        """
        if (as_list):
            return list(np.fromstring(self.get("shape")[:-1], sep=",").astype('int'))
        else:
            shape_tensor = self.no_params_func("shape", return_response=True)
            return shape_tensor
Source: tensor.py (project: PySyft, author: OpenMined)
def stride(self, dim=-1):
        """
        Returns the stride of tensor.
        Parameters
        ----------
        dim : int
            dimension of expected return

        Returns
        -------
        FloatTensor
            Output tensor.
        (or)
        numpy.ndarray
            NumPy Array as Long
        """
        if dim == -1:
            return self.no_params_func("stride", return_response=True, return_type=None)
        else:
            strides = self.params_func("stride", [dim], return_response=True, return_type=None)
            return np.fromstring(strides, sep=' ').astype('long')
Source: proposal.py (project: focal-loss, author: unsky)
def __init__(self, feat_stride, scales, ratios, output_score,
                 rpn_pre_nms_top_n, rpn_post_nms_top_n, threshold, rpn_min_size):
        super(ProposalOperator, self).__init__()
        self._feat_stride = feat_stride
        self._scales = np.fromstring(scales[1:-1], dtype=float, sep=',')
        self._ratios = np.fromstring(ratios[1:-1], dtype=float, sep=',')
        self._anchors = generate_anchors(base_size=self._feat_stride, scales=self._scales, ratios=self._ratios)
        self._num_anchors = self._anchors.shape[0]
        self._output_score = output_score
        self._rpn_pre_nms_top_n = rpn_pre_nms_top_n
        self._rpn_post_nms_top_n = rpn_post_nms_top_n
        self._threshold = threshold
        self._rpn_min_size = rpn_min_size

        if DEBUG:
            print 'feat_stride: {}'.format(self._feat_stride)
            print 'anchors:'
            print self._anchors

