def uploadFile(current_user):
    timestamp_format = "%Y-%m-%dT%H:%M:%S"
    now = datetime.datetime.utcnow().strftime(timestamp_format)
    try:
        file = request.files['file']
    except KeyError:
        file = None
    try:
        url = request.form['url']
    except KeyError:
        url = None
    # Default to "nothing uploaded" so both return values are always defined.
    file_uploaded = False
    filename = None
    if file and allowed_file(file.filename):
        filename = now + '_' + str(current_user) + '_' + file.filename
        filename = secure_filename(filename)
        file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
        file_uploaded = True
    elif url:
        file = urllib.urlopen(url)
        filename = url.split('/')[-1]
        filename = now + '_' + str(current_user) + '_' + filename
        filename = secure_filename(filename)
        if file and allowed_file(filename):
            with open(os.path.join(app.config['UPLOAD_FOLDER'], filename), 'wb') as outfd:
                outfd.write(file.read())
            file_uploaded = True
    return file_uploaded, filename
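The allowed_file helper used above is not part of this snippet. A minimal sketch of the usual Flask upload pattern it refers to, with an illustrative (not original) extension whitelist:

# Hypothetical helper; the extension set is an assumption, not taken from the original project.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'pdf'}

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS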
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
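extract_images depends on a _read32 helper that is not shown here. A short sketch, assuming the standard IDX/MNIST convention of big-endian unsigned 32-bit integers:

def _read32(bytestream):
    # Read one big-endian uint32 from the stream; MNIST headers are big-endian.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]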
def lcdict_to_pickle(lcdict, outfile=None):
    '''This just writes the lcdict to a pickle.

    If outfile is None, the name is taken from lcdict['objectid'] and the
    pickle is written to <objectid>-hplc.pkl. If that fails, it is written to
    a file named hplc.pkl.
    '''
if not outfile and lcdict['objectid']:
outfile = '%s-hplc.pkl' % lcdict['objectid']
elif not outfile and not lcdict['objectid']:
outfile = 'hplc.pkl'
with open(outfile,'wb') as outfd:
pickle.dump(lcdict, outfd, protocol=pickle.HIGHEST_PROTOCOL)
if os.path.exists(outfile):
LOGINFO('lcdict for object: %s -> %s OK' % (lcdict['objectid'],
outfile))
return outfile
else:
LOGERROR('could not make a pickle for this lcdict!')
return None
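A minimal usage sketch for lcdict_to_pickle; the object id and values below are made up for illustration:

lcdict = {'objectid': 'HAT-123-0001234',
          'times': [2455010.1, 2455010.2],
          'mags': [10.10, 10.12]}
pklpath = lcdict_to_pickle(lcdict)   # writes HAT-123-0001234-hplc.pkl
with open(pklpath, 'rb') as infd:
    restored = pickle.load(infd)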
Source: test_run_no_updates_available.py (project: pyupdater-wx-demo, author: wettenhj)
def setUp(self):
tempFile = tempfile.NamedTemporaryFile()
self.fileServerDir = tempFile.name
tempFile.close()
os.mkdir(self.fileServerDir)
os.environ['PYUPDATER_FILESERVER_DIR'] = self.fileServerDir
privateKey = ed25519.SigningKey(PRIVATE_KEY.encode('utf-8'),
encoding='base64')
signature = privateKey.sign(six.b(json.dumps(VERSIONS, sort_keys=True)),
encoding='base64').decode()
VERSIONS['signature'] = signature
keysFilePath = os.path.join(self.fileServerDir, 'keys.gz')
    with gzip.open(keysFilePath, 'wb') as keysFile:
        # encode to bytes so the binary gzip handle also accepts it on Python 3
        keysFile.write(json.dumps(KEYS, sort_keys=True).encode('utf-8'))
    versionsFilePath = os.path.join(self.fileServerDir, 'versions.gz')
    with gzip.open(versionsFilePath, 'wb') as versionsFile:
        versionsFile.write(json.dumps(VERSIONS, sort_keys=True).encode('utf-8'))
os.environ['WXUPDATEDEMO_TESTING'] = 'True'
from wxupdatedemo.config import CLIENT_CONFIG
self.clientConfig = CLIENT_CONFIG
self.clientConfig.PUBLIC_KEY = PUBLIC_KEY
def setUp(self):
tempFile = tempfile.NamedTemporaryFile()
self.fileServerDir = tempFile.name
tempFile.close()
os.mkdir(self.fileServerDir)
os.environ['PYUPDATER_FILESERVER_DIR'] = self.fileServerDir
privateKey = ed25519.SigningKey(PRIVATE_KEY.encode('utf-8'),
encoding='base64')
signature = privateKey.sign(six.b(json.dumps(VERSIONS, sort_keys=True)),
encoding='base64').decode()
VERSIONS['signature'] = signature
keysFilePath = os.path.join(self.fileServerDir, 'keys.gz')
    with gzip.open(keysFilePath, 'wb') as keysFile:
        keysFile.write(json.dumps(KEYS, sort_keys=True).encode('utf-8'))
    versionsFilePath = os.path.join(self.fileServerDir, 'versions.gz')
    with gzip.open(versionsFilePath, 'wb') as versionsFile:
        versionsFile.write(json.dumps(VERSIONS, sort_keys=True).encode('utf-8'))
os.environ['WXUPDATEDEMO_TESTING'] = 'True'
from wxupdatedemo.config import CLIENT_CONFIG
self.clientConfig = CLIENT_CONFIG
self.clientConfig.PUBLIC_KEY = PUBLIC_KEY
self.clientConfig.APP_NAME = APP_NAME
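The setUp above signs the VERSIONS manifest with the ed25519 package. A sketch of the matching verification a client could perform, assuming PUBLIC_KEY is the base64-encoded verifying key for PRIVATE_KEY (this step is not part of the original test):

verifying_key = ed25519.VerifyingKey(PUBLIC_KEY.encode('utf-8'), encoding='base64')
payload = dict(VERSIONS)
signature = payload.pop('signature')  # the signature covers the manifest without itself
try:
    verifying_key.verify(signature,
                         six.b(json.dumps(payload, sort_keys=True)),
                         encoding='base64')
except ed25519.BadSignatureError:
    raise RuntimeError('versions manifest failed signature verification')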
def main():
args = get_args()
logging.basicConfig(
format='%(asctime)s %(message)s',
filename=os.path.join(args.outdir, "NanoQC.log"),
level=logging.INFO)
logging.info("NanoQC started.")
sizeRange = length_histogram(
fqin=gzip.open(args.fastq, 'rt'),
name=os.path.join(args.outdir, "SequenceLengthDistribution.png"))
fq = get_bin(gzip.open(args.fastq, 'rt'), sizeRange)
logging.info("Using {} reads for plotting".format(len(fq)))
fqbin = [dat[0] for dat in fq]
qualbin = [dat[1] for dat in fq]
logging.info("Creating plots...")
per_base_sequence_content_and_quality(fqbin, qualbin, args.outdir, args.format)
logging.info("per base sequence content and quality completed.")
logging.info("Finished!")
def __init__(self, source,
source_dict,
batch_size=128,
maxlen=100,
minlen=0,
n_words_source=-1):
if source.endswith('.gz'):
self.source = gzip.open(source, 'r')
else:
self.source = open(source, 'r')
self.source_dict = {'1': 1, '0': 0, 0: 3}
self.batch_size = batch_size
self.maxlen = maxlen
self.minlen = 10
self.n_words_source = n_words_source
self.end_of_data = False
def export(metadata, start, end, container_image_pattern):
queries = []
metadata["start"] = start.isoformat() + "Z"
metadata["end"] = end.isoformat() + "Z"
metadata["services"] = []
ts = datetime.utcnow().strftime("%Y%m%d%H%M%S-")
path = os.path.join(metadata["metrics_export"], ts + metadata["measurement_name"])
if not os.path.isdir(path):
os.makedirs(path)
for app in APPS:
metadata["services"].append(dump_app(app, path, start, end, container_image_pattern))
with open(os.path.join(path, "metadata.json"), "w+") as f:
json.dump(metadata, f, cls=Encoder, sort_keys=True, indent=4)
f.flush()
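export serialises the metadata with a custom Encoder that is not included in this snippet. A plausible sketch, assuming it only needs to handle datetime objects and that json and datetime are imported as in the surrounding code (hypothetical, not the project's actual class):

class Encoder(json.JSONEncoder):
    def default(self, obj):
        # Assumed behaviour: render datetimes as ISO-8601 UTC strings,
        # defer everything else to the base encoder.
        if isinstance(obj, datetime):
            return obj.isoformat() + "Z"
        return super(Encoder, self).default(obj)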
def check_fastq(fastq):
# Check if fastq is readable
if not os.access(fastq, os.R_OK):
martian.exit("Do not have file read permission for FASTQ file: %s" % fastq)
# Check if fastq is gzipped
is_gzip_fastq = True
try:
with gzip.open(fastq) as f:
f.read(1)
    except (IOError, OSError):
        is_gzip_fastq = False
    if is_gzip_fastq and not fastq.endswith(cr_constants.GZIP_SUFFIX):
        martian.exit("Input FASTQ file is gzipped but filename does not have %s suffix: %s" % (cr_constants.GZIP_SUFFIX, fastq))
    if not is_gzip_fastq and fastq.endswith(cr_constants.GZIP_SUFFIX):
        martian.exit("Input FASTQ file is not gzipped but filename has %s suffix: %s" % (cr_constants.GZIP_SUFFIX, fastq))
def get_run_data(fn):
""" Parse flowcell + lane from the first FASTQ record.
NOTE: we don't check whether there are multiple FC / lanes in this file.
NOTE: taken from longranger/mro/stages/reads/setup_chunks
"""
if fn[-2:] == 'gz':
reader = gzip.open(fn)
else:
reader = open(fn, 'r')
gen = read_generator_fastq(reader)
try:
(name, seq, qual) = gen.next()
(flowcell, lane) = re.split(':', name)[2:4]
return (flowcell, lane)
except StopIteration:
# empty fastq
raise ValueError('Could not extract flowcell and lane from FASTQ file. File is empty: %s' % fn)
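get_run_data relies on read_generator_fastq, which is not shown. A rough sketch under the assumption of well-formed four-line FASTQ records (the real longranger helper may differ):

def read_generator_fastq(reader):
    # Yield (name, seq, qual) tuples until the handle is exhausted.
    while True:
        header = reader.readline().rstrip()
        if not header:
            return
        seq = reader.readline().rstrip()
        reader.readline()  # '+' separator line
        qual = reader.readline().rstrip()
        yield (header.lstrip('@'), seq, qual)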
def load_primary_contigs(reference_path):
    '''Load the set of primary contigs for variant and SV calling from reference_path.
    If no primary_contigs.txt file is present, return all contigs. If reference_path
    is a known 10x reference genome and has no primary_contigs.txt, filter out the known bad contigs.'''
    if reference_path is not None and os.path.exists(get_primary_contigs(reference_path)):
# If we have a primary_contigs.txt file, use it
with open(get_primary_contigs(reference_path), 'r') as f:
primary_contigs = set([line.strip() for line in f.readlines()])
else:
# Default is to include all contigs
# Otherwise implement the old contig filters
ref = open_reference(reference_path)
primary_contigs = set(ref.keys())
if is_tenx(reference_path):
primary_contigs = set(chrom for chrom in primary_contigs if not ('random' in chrom or 'U' in chrom or 'hap' in chrom or chrom == 'hs37d5'))
return primary_contigs
def load_fastq(filename):
reads = []
if get_compression_type(filename) == 'gz':
open_func = gzip.open
else: # plain text
open_func = open
with open_func(filename, 'rb') as fastq:
for line in fastq:
stripped_line = line.strip()
if len(stripped_line) == 0:
continue
if not stripped_line.startswith(b'@'):
continue
name = stripped_line[1:].split()[0]
sequence = next(fastq).strip()
_ = next(fastq)
qualities = next(fastq).strip()
reads.append((name, sequence, qualities))
return reads
def get_compression_type(filename):
"""
Attempts to guess the compression (if any) on a file using the first few bytes.
http://stackoverflow.com/questions/13044562
"""
    magic_dict = {'gz': b'\x1f\x8b\x08',
                  'bz2': b'\x42\x5a\x68',
                  'zip': b'\x50\x4b\x03\x04'}
    # Use the full magic byte strings so startswith() matches the complete
    # signature rather than any single byte, and read enough bytes for the
    # longest one.
    max_len = max(len(x) for x in magic_dict.values())
    with open(filename, 'rb') as unknown_file:
        file_start = unknown_file.read(max_len)
compression_type = 'plain'
for file_type, magic_bytes in magic_dict.items():
if file_start.startswith(magic_bytes):
compression_type = file_type
if compression_type == 'bz2':
sys.exit('Error: cannot use bzip2 format - use gzip instead')
if compression_type == 'zip':
sys.exit('Error: cannot use zip format - use gzip instead')
return compression_type
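A short usage sketch for get_compression_type, mirroring how load_fastq above chooses an opener (the filename is illustrative):

compression = get_compression_type('reads.fastq.gz')  # 'plain' or 'gz'; exits on bz2/zip
open_func = gzip.open if compression == 'gz' else open
with open_func('reads.fastq.gz', 'rb') as fh:
    first_header = fh.readline()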
def _make_writer(self):
"""
:return:
"""
self._buffer = StringIO()
self._bytes_written = 0
now = datetime.now()
self.fname = self.log_folder + '/' + now.strftime('%Y%m%d_%H%M%S_{}.json'.format(self.make_random(6)))
self.fname = str(pathlib.Path(self.fname))
self._out_fh = open(self.fname, 'w')
self.write_pid()
logging.warning("Writing to {} ({} bytes)".format(self._out_fh.name, self.max_bytes))
# compress any old files still lying around
for fname in glob(self.log_folder+"/*.json"):
if fname != self.fname:
self._compress(fname)
def test_save_svgz_filename():
import gzip
qr = segno.make_qr('test')
f = tempfile.NamedTemporaryFile('wb', suffix='.svgz', delete=False)
f.close()
qr.save(f.name)
f = open(f.name, mode='rb')
expected = b'\x1f\x8b\x08' # gzip magic number
val = f.read(len(expected))
f.close()
f = gzip.open(f.name)
try:
content = f.read(6)
finally:
f.close()
os.unlink(f.name)
assert expected == val
assert b'<?xml ' == content
def QuASAR_rep_wrapper(outdir,parameters,samplename1,samplename2,running_mode):
script_comparison_file=outdir+'/scripts/QuASAR-Rep/'+samplename1+'.vs.'+samplename2+'/'+samplename1+'.vs.'+samplename2+'.QuASAR-Rep.sh'
subp.check_output(['bash','-c','mkdir -p '+os.path.dirname(script_comparison_file)])
script_comparison=open(script_comparison_file,'w')
script_comparison.write("#!/bin/sh"+'\n')
script_comparison.write('. '+bashrc_file+'\n')
outpath=outdir+'/results/reproducibility/'+samplename1+'.vs.'+samplename2+'/QuASAR-Rep/'+samplename1+'.vs.'+samplename2+'.QuASAR-Rep.scores.txt'
subp.check_output(['bash','-c','mkdir -p '+os.path.dirname(outpath)])
quasar_data=outdir+'/data/forQuASAR'
quasar_transform1=quasar_data+'/'+samplename1+'.quasar_transform'
quasar_transform2=quasar_data+'/'+samplename2+'.quasar_transform'
script_comparison.write('${mypython} '+os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(os.path.realpath(__file__)))))+"/hifive/bin/find_quasar_replicate_score"+' '+quasar_transform1+' '+quasar_transform2+' '+outpath+'\n')
script_comparison.write('${mypython} '+os.path.abspath(os.path.dirname(os.path.realpath(__file__)))+"/plot_quasar_scatter.py"+' '+quasar_transform1+' '+quasar_transform2+' '+outpath+'\n')
#split the scores by chromosomes
script_comparison.write('${mypython} '+os.path.abspath(os.path.dirname(os.path.realpath(__file__)))+"/quasar_split_by_chromosomes.py"+' '+outpath+'\n')
script_comparison.close()
run_script(script_comparison_file,running_mode)
def HiCSpector_wrapper(outdir,parameters,concise_analysis,samplename1,samplename2,chromo,running_mode,f1,f2,nodefile):
script_comparison_file=outdir+'/scripts/HiC-spector/'+samplename1+'.'+samplename2+'/'+chromo+'.'+samplename1+'.'+samplename2+'.sh'
subp.check_output(['bash','-c','mkdir -p '+os.path.dirname(script_comparison_file)])
script_comparison=open(script_comparison_file,'w')
script_comparison.write("#!/bin/sh"+'\n')
script_comparison.write('. '+bashrc_file+'\n')
if os.path.isfile(f1) and os.path.getsize(f1)>20:
if os.path.isfile(f2) and os.path.getsize(f2)>20:
outpath=outdir+'/results/reproducibility/'+samplename1+'.vs.'+samplename2+'/HiC-Spector/'+chromo+'.'+samplename1+'.vs.'+samplename2+'.scores.txt'
subp.check_output(['bash','-c','mkdir -p '+os.path.dirname(outpath)])
script_comparison.write("$mypython -W ignore "+os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))+"/reproducibility_analysis/hic-spector_wrapper.py --m1 "+f1+" --m2 "+f2+" --out "+outpath+".printout --node_file "+nodefile+" --num_evec "+parameters['HiC-Spector']['n']+"\n")
script_comparison.write("cat "+outpath+".printout | tail -n1 | cut -f2 | awk '{print \""+samplename1+"\\t"+samplename2+"\\t\"$3}' > "+outpath+'\n')
script_comparison.write("rm "+outpath+".printout"+'\n')
script_comparison.close()
run_script(script_comparison_file,running_mode)
def GenomeDISCO_wrapper(outdir,parameters,concise_analysis,samplename1,samplename2,chromo,running_mode,f1,f2,nodefile):
script_comparison_file=outdir+'/scripts/GenomeDISCO/'+samplename1+'.'+samplename2+'/'+chromo+'.'+samplename1+'.'+samplename2+'.sh'
subp.check_output(['bash','-c','mkdir -p '+os.path.dirname(script_comparison_file)])
script_comparison=open(script_comparison_file,'w')
script_comparison.write("#!/bin/sh"+'\n')
script_comparison.write('. '+bashrc_file+'\n')
if os.path.isfile(f1) and os.path.getsize(f1)>20:
if os.path.isfile(f2) and os.path.getsize(f2)>20:
concise_analysis_text=''
if concise_analysis:
concise_analysis_text=' --concise_analysis'
#get the sample that goes for subsampling
subsampling=parameters['GenomeDISCO']['subsampling']
if parameters['GenomeDISCO']['subsampling']!='NA' and parameters['GenomeDISCO']['subsampling']!='lowest':
subsampling_sample=parameters['GenomeDISCO']['subsampling']
subsampling=outdir+'/data/edges/'+subsampling_sample+'/'+subsampling_sample+'.'+chromo+'.gz'
outpath=outdir+'/results/reproducibility/'+samplename1+'.vs.'+samplename2+'/GenomeDISCO/'
subp.check_output(['bash','-c','mkdir -p '+outpath])
script_comparison.write("$mypython -W ignore "+os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))+"/genomedisco/compute_reproducibility.py")+" --m1 "+f1+" --m2 "+f2+" --m1name "+samplename1+" --m2name "+samplename2+" --node_file "+nodefile+" --outdir "+outpath+" --outpref "+chromo+" --m_subsample "+subsampling+" --approximation 10000000 --norm "+parameters['GenomeDISCO']['norm']+" --method RandomWalks "+" --tmin "+parameters['GenomeDISCO']['tmin']+" --tmax "+parameters['GenomeDISCO']['tmax']+concise_analysis_text+'\n')
script_comparison.close()
run_script(script_comparison_file,running_mode)
def construct_csr_matrix_from_data_and_nodes(f,nodes,blacklisted_nodes,remove_diag=True):
print "GenomeDISCO | "+strftime("%c")+" | processing: Loading interaction data from "+f
total_nodes=len(nodes.keys())
i=[]
j=[]
v=[]
#print strftime("%c")
c=0
for line in gzip.open(f):
items=line.strip().split('\t')
n1,n2,val=nodes[items[0]]['idx'],nodes[items[1]]['idx'],float(items[2])
i.append(n1)
j.append(n2)
v.append(val)
c+=1
csr_m=csr_matrix( (v,(i,j)), shape=(total_nodes,total_nodes),dtype=float)
if remove_diag:
csr_m.setdiag(0)
return filter_nodes(csr_m,blacklisted_nodes)
def dump_to_csv(self, output_csv, input_fields, write_header=True, top_level=False, mode='a', encoding='utf-8', compression=None):
if compression == 'bz2':
mode = binary_mode(mode)
filehandle = bz2.open(output_csv, mode)
elif compression == 'gzip':
mode = binary_mode(mode)
filehandle = gzip.open(output_csv, mode)
else:
filehandle = open(output_csv, mode)
writer = csv.writer(filehandle)
if write_header:
writer.writerow(input_fields)
tweet_parser = TweetParser()
for tweet in self.get_iterator():
if top_level:
ret = list(zip(input_fields, [tweet.get(field) for field in input_fields]))
else:
ret = tweet_parser.parse_columns_from_tweet(tweet,input_fields)
ret_values = [col_val[1] for col_val in ret]
writer.writerow(ret_values)
filehandle.close()
def get_iterator(self):
tweet_parser = TweetParser()
if self.compression == 'bz2':
self.mode = binary_mode(self.mode)
csv_handle = bz2.open(self.filepath, self.mode, encoding=self.encoding)
elif self.compression == 'gzip':
self.mode = binary_mode(self.mode)
csv_handle = gzip.open(self.filepath, self.mode, encoding=self.encoding)
else:
csv_handle = open(self.filepath, self.mode, encoding=self.encoding)
for count, tweet in enumerate(csv.DictReader(csv_handle)):
if self.limit < count+1 and self.limit != 0:
csv_handle.close()
return
elif tweet_parser.tweet_passes_filter(self.filter, tweet) \
and tweet_parser.tweet_passes_custom_filter_list(self.custom_filters, tweet):
if self.should_strip:
yield tweet_parser.strip_tweet(self.keep_fields, tweet)
else:
yield dict(tweet)
csv_handle.close()
def setup(self, config):
"""
Load name model (word list) and compile regexes for stop characters.
:param config: Configuration object.
:type config: ``dict``
"""
reference_model = os.path.join(
config[helper.CODE_ROOT], config[helper.NAME_MODEL])
self.stopper = regex.compile(('(%s)' % '|'.join([
'and', 'or', 'og', 'eller', r'\?', '&', '<', '>', '@', ':', ';', '/',
r'\(', r'\)', 'i', 'of', 'from', 'to', r'\n', '!'])),
regex.I | regex.MULTILINE)
self.semistop = regex.compile(
('(%s)' % '|'.join([','])), regex.I | regex.MULTILINE)
self.size_probability = [0.000, 0.000, 0.435, 0.489, 0.472, 0.004, 0.000]
self.threshold = 0.25
self.candidates = defaultdict(int)
with gzip.open(reference_model, 'rb') as inp:
self.model = json.loads(inp.read().decode('utf-8'))
self.tokenizer = regex.compile(r'\w{2,20}')
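The reference model loaded in setup is a gzipped JSON file. A sketch of how such a file could be produced, with made-up token frequencies and a hypothetical filename:

import gzip
import json

model = {'anna': 0.012, 'bergen': 0.004}  # illustrative frequencies only
with gzip.open('name_model.json.gz', 'wb') as out:
    out.write(json.dumps(model).encode('utf-8'))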
Source: luna_preprocessed_load_data.py (project: lung-cancer-detector, author: YichenGong)
def next_batch(self, batch_size):
assert self.train_mode or self.validation_mode, "Please set mode, train, validation or test. e.g. DataLoad.train()"
    idx_next_batch = [(self.current_idx + i) % len(self.p_imgs) for i in range(batch_size)]  # use the requested batch_size, matching the index advance below
patient_img_next_batch = [ self.p_imgs[idx] for idx in idx_next_batch]
batch_image = []
batch_mask = []
for image in patient_img_next_batch:
fi = gzip.open(self.data_path + image, 'rb')
img = pickle.load(fi)
img = np.expand_dims(img, axis=2)
batch_image.append(img)
fi.close()
fm = gzip.open(self.mask_path + image, 'rb')
mask = pickle.load(fm)
fm.close()
mask_binary_class = np.zeros([mask.shape[0],mask.shape[1],2])
mask_binary_class[:,:,0][mask == 0] = 1
mask_binary_class[:,:,1][mask == 1] = 1
batch_mask.append(mask_binary_class)
self.current_idx = (self.current_idx + batch_size) % len(self.p_imgs)
batched_image = np.stack(batch_image)
batched_mask = np.stack(batch_mask)
return batched_image, batched_mask
def read_fakelc(fakelcfile):
    '''
    This just reads a pickled fake LC.
    '''
    try:
        with open(fakelcfile, 'rb') as infd:
            lcdict = pickle.load(infd)
    except UnicodeDecodeError:
        # fall back to latin-1 so pickles written by Python 2 can be read under Python 3
        with open(fakelcfile, 'rb') as infd:
            lcdict = pickle.load(infd, encoding='latin1')
    return lcdict
#######################
## UTILITY FUNCTIONS ##
#######################
def read_pklc(lcfile):
'''
This just reads a pickle.
'''
try:
with open(lcfile,'rb') as infd:
lcdict = pickle.load(infd)
except UnicodeDecodeError:
with open(lcfile,'rb') as infd:
lcdict = pickle.load(infd, encoding='latin1')
return lcdict
def writetoHTML(html_file, defaultInfo):
html_handle = open(html_file, 'w')
current_dir = os.path.dirname(__file__)
with open(current_dir + '/lib/template.html') as report:
for line in report:
line = line.strip()
print(line, file=html_handle)
try:
start_index = line.index("^^")
stop_index = line.index("$$")
if (line[start_index+2: stop_index] == 'defaultInfo'):
print(defaultInfo, file=html_handle)
else:
file_path = current_dir + '/lib' + line[start_index+2: stop_index]
with open(file_path) as fh:
for subline in fh:
subline = subline.strip()
print(subline, file=html_handle)
except ValueError:
pass
html_handle.close()
print("HTML report successfully saved to " + html_file)
def get_targetids(filter_seq_ids, target_seq_ids):
target_ids = univset()
if filter_seq_ids:
target_ids = univset()
filter_ids = []
with open(filter_seq_ids) as fh:
for line in fh:
line = line.strip()
line = line.lstrip('>')
filter_ids.append(line)
target_ids = target_ids - set(filter_ids)
elif target_seq_ids:
target_ids = []
with open(target_seq_ids) as fh:
for line in fh:
line = line.strip()
line = line.lstrip('>')
target_ids.append(line)
target_ids = set(target_ids)
return target_ids
def copy(self):
"Copy the file to the local directory"
fpi= open(self.filename, "rb")
fpo_filename= os.path.join(
self.destination, os.path.basename(self.filename))
try:
fpo= open(fpo_filename, "r+b")
    except IOError as exc:
if exc.errno == errno.ENOENT:
fpo= open(fpo_filename, "wb")
else:
raise
try:
self.phase_copy(fpi, fpo, self.phase1, self.phase2)
self.phase_copy(fpi, fpo, self.phase2, self.phase3)
finally:
self.record_state()
def load(batch_size, test_batch_size, n_labelled=None):
filepath = '/tmp/mnist.pkl.gz'
url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
if not os.path.isfile(filepath):
print "Couldn't find MNIST dataset in /tmp, downloading..."
urllib.urlretrieve(url, filepath)
    with gzip.open(filepath, 'rb') as f:
train_data, dev_data, test_data = pickle.load(f)
return (
mnist_generator(train_data, batch_size, n_labelled),
mnist_generator(dev_data, test_batch_size, n_labelled),
mnist_generator(test_data, test_batch_size, n_labelled)
)