def __init__(self, cas, app_name, api_host, api_base, shiro_cas_base):
    """Constructor.

    Args:
        cas (CAS): A valid CAS instance
        app_name (str): The application name
        api_host (str): The host to make API requests to
        api_base (str): The API base of the application
        shiro_cas_base (str): The shiro cas base of the application

    Raises:
        TypeError: If *cas* is not a CAS instance.
        ValueError: If any string argument is empty or falsy.
    """
    # Raise real exceptions instead of `assert`: asserts are stripped
    # under `python -O`, which would silently disable this validation.
    if not isinstance(cas, CAS):
        raise TypeError("cas must be a valid CAS instance")
    self._cas = cas
    if not app_name:
        raise ValueError("app_name must be a non-empty string")
    self._app_name = app_name
    if not api_host:
        raise ValueError("api_host must be a non-empty string")
    self._api_host = api_host
    if not api_base:
        raise ValueError("api_base must be a non-empty string")
    self._api_base = api_base
    if not shiro_cas_base:
        raise ValueError("shiro_cas_base must be a non-empty string")
    self._shiro_cas_base = shiro_cas_base
    self._st = None  # service ticket, not requested yet
    self._logger = get_logger()
# Python get_logger() usage examples (collected snippets)
def task(self):
    """Feed intraday average prices from the database to the Router socket.

    For each table in ``self.intraday_tables``, selects (symbol, price) rows
    in datetime order and forwards them as ``PRICE <symbol> <price>``
    messages; tables that don't exist are logged and skipped.  Sends an
    ``END`` message when done.
    """
    get_logger().info("Feeding Router intraday average_prices from database")
    for table in self.intraday_tables:
        try:
            # Table names cannot be bound as SQL parameters; `table` comes
            # from internal configuration, not from untrusted input.
            query = f"SELECT symbol, price FROM `{table}` ORDER BY datetime"
            cursor = database.connect('yquant_intraday').cursor(SSCursor)
            try:
                cursor.execute(query)
                get_logger().info(f"Sending average_prices on {table} to Router")
                for row in cursor:
                    # get_logger().debug(f"{row[0]} {row[1]}")
                    self.sockets['Router'].send_string(f'PRICE {row[0]} {row[1]}')
            finally:
                # A streaming (unbuffered) SSCursor must be closed explicitly,
                # otherwise the server-side result set leaks.
                cursor.close()
        except pymysql.err.ProgrammingError as e:
            if "doesn't exist" in str(e):
                get_logger().error(f"Table [{table}] doesn't exist")
    self.sockets['Router'].send_string('END')
def __init__(self):
    """Initialize lazy resources, row counters and the handler table."""
    self._loop = None       # created lazily (asyncio event loop)
    self._logger = None     # created lazily
    self._config = None     # loaded lazily from config.json
    self.inserted_rows = 0  # incremented after each insert
    self.updated_rows = 0   # incremented after each update
    self._sockets = {}      # name -> zmq socket
    self._symbols = defaultdict(lambda: 'UNKNOWN')  # symbol -> name
    get_logger().debug(f"Starting {self.__class__.__name__} process <{os.getpid()}> ---------------------")
    # Map commands to their handlers up front instead of using getattr()
    # on every message (performance).
    self.handlers = {'END': self.handler_end}
    if self.is_backtest:
        get_logger().debug("Backtest mode is activated")
async def task(self):
    """Receive-and-dispatch loop: read messages and invoke their handlers.

    Declared ``async``: the original ``def`` with ``await`` in the body is a
    SyntaxError, and the comment below says it is scheduled as a coroutine.
    """
    # This will be converted as a coroutine by loop.create_task()
    # then executed by loop.run_forever() automatically in ModuleBase.run()
    # until self.loop.stop() is called.
    get_logger().debug(f"Waiting a message from {len(self.sockets)-1} socket(s)")
    while True:
        message = self.sockets['listen'].recv_string()  # e.g. "PRICE 015760 45000"
        get_logger().debug(f"Received message [{message}]")
        parts = message.split()       # split once instead of twice
        command = parts[0]            # command: PRICE
        args = parts[1:]              # args: 015760 45000
        try:
            handler = self.handlers[command]
        except KeyError:  # may occur when a handler is not prepared
            get_logger().error(f"No handler for command [{command}]")
            continue
        # `is False` differs from `if not ...`: only an explicit False
        # return value terminates the loop (None keeps it running).
        if await handler(*args) is False:
            break
def test_logging(self):
    """Verify logging initialization, level methods, and size-based rollover."""
    self.assertTrue(not os.path.exists('test.log'))
    logger.initialize_logging('test.log', max_len=1024, interactive=True)
    log = logger.get_logger()  # renamed from `l` (ambiguous single-letter name)
    self.assertTrue(os.path.exists('test.log'))
    self.assertTrue(not os.path.exists('test.log.1'))
    log.debug('debug msg')
    log.info('info msg')
    # Logger.warn() is a deprecated alias of warning().
    log.warning('warn')
    log.error('error')
    log.info('d' * 1024)  # force a rollover
    log.info('new file')
    self.assertTrue(os.path.exists('test.log'))
    self.assertTrue(os.path.exists('test.log.1'))
def test_logging(self):
    """Verify logging initialization, level methods, and size-based rollover."""
    self.assertTrue(not os.path.exists('test.log'))
    logger.initialize_logging('test.log', max_len=1024, interactive=True)
    log = logger.get_logger()  # renamed from `l` (ambiguous single-letter name)
    self.assertTrue(os.path.exists('test.log'))
    self.assertTrue(not os.path.exists('test.log.1'))
    log.debug('debug msg')
    log.info('info msg')
    # Logger.warn() is a deprecated alias of warning().
    log.warning('warn')
    log.error('error')
    log.info('d' * 1024)  # force a rollover
    log.info('new file')
    self.assertTrue(os.path.exists('test.log'))
    self.assertTrue(os.path.exists('test.log.1'))
def __init__(self, cas_group, cas_username, cas_password, cas_host=None, secure=True, loglevel=logging.INFO):
    """Constructor.

    Args:
        cas_group (str): The group to authenticate with CAS
        cas_username (str): The username to authenticate with CAS
        cas_password (str): The password to authenticate with CAS
        cas_host (str): The host/IP of the CAS server
        secure (bool): Enable the certificate check or not
        loglevel (int): Log level

    Raises:
        ValueError: If group, username or password is empty.
    """
    # Raise instead of assert: asserts vanish under `python -O`.
    if not cas_group:
        raise ValueError("cas_group must be a non-empty string")
    self._cas_group = cas_group
    if not cas_username:
        raise ValueError("cas_username must be a non-empty string")
    self._cas_username = cas_username
    if not cas_password:
        raise ValueError("cas_password must be a non-empty string")
    self._cas_password = cas_password
    # Fall back to the default SSO host when none is supplied.
    self._cas_host = cas_host if cas_host else self.__sso_cas_host
    self._tgt = None  # ticket-granting ticket, fetched lazily
    self._logger = get_logger(loglevel)
    # Init urllib2
    self._init_urllib(secure)
    # SECURITY: never log the password — credentials must not end up in logs.
    self._logger.debug("CAS object (%s,%s,%s) constructed" % (cas_host, cas_group, cas_username))
def __init__(self, log="simplescraper.log"):
    """Build a 2 GiB-capped file logger and hand it to the Connect base class."""
    from logger import get_logger
    scraper_logger = get_logger(log, maxbytes=2147483648)
    Connect.__init__(self, scraper_logger)
def __del__(self):
    """Finalizer: close any open ZMQ sockets and the asyncio event loop."""
    get_logger().debug(f"Finalizing {self.__class__.__name__} process <{os.getpid()}>")
    if self._sockets:
        # Close every socket we opened, logging each by name.
        for name, sock in self.sockets.items():
            get_logger().debug(f"Closing {name} socket")
            sock.close()
    if self._loop:
        get_logger().debug(f"Closing asyncio event loop")
        self._loop.close()
def run(self):
    """Drive the module: prepare, schedule tasks, and spin the event loop."""
    try:
        self.prepare()
        self.loop.create_task(self.task())
        # finish() stops the event loop once all scheduled tasks are done.
        self.loop.create_task(self.finish())
        self.loop.run_forever()
    except KeyboardInterrupt:
        get_logger().debug("Stopping the event loop by a keyboard interrupt")
        self.loop.stop()
    except Exception as err:
        get_logger().exception(err)
        raise
async def handler_end(self):
    """Handle an END message: broadcast END to all downstream sockets.

    Declared ``async`` because the dispatcher invokes handlers with
    ``await handler(*args)``; awaiting a plain function's bool return
    would raise TypeError.

    Returns:
        bool: False, signalling the caller's receive loop to terminate.
    """
    get_logger().info("Received [END] message")
    message = "END"
    for key, socket in self.sockets.items():
        if key != 'listen':
            socket.send_string(message)
            get_logger().debug(f"Sent [{message}] message to {key} socket")
    return False
@property
def loop(self):
    """Lazily created asyncio event loop shared by this module's tasks.

    NOTE(review): callers access this as ``self.loop`` (no call), so it is
    exposed as a property — the decorator appears to have been lost when
    the snippet was extracted; confirm against the original file.
    """
    if not self._loop:
        get_logger().debug("Starting asyncio event loop")
        self._loop = asyncio.get_event_loop()
    return self._loop
@property
def config(self):
    """Configuration dict, loaded once from config.json next to this file.

    NOTE(review): accessed elsewhere as ``self.config`` (no call), so it is
    exposed as a property — decorator seems to have been dropped during
    extraction; confirm against the original file.
    """
    if not self._config:
        # get_logger().debug("Loading configuration file")
        # Resolve relative to this module, not the current working directory.
        # (renamed from `dir`, which shadows the builtin)
        cfg_dir = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(cfg_dir, 'config.json')) as file:
            self._config = json.load(file)
    return self._config
@property
def sockets(self):
    """Dictionary of name -> ZMQ socket, built lazily from ``self.config``.

    Binds a SUB 'listen' socket when configured, then connects a PUB socket
    to each module listed under this module's 'next' key.

    NOTE(review): accessed elsewhere as ``self.sockets`` (no call), so it is
    exposed as a property — decorator appears lost during extraction.
    """
    if not self._sockets:
        try:
            address = self.config[self.__class__.__name__]['listen']
            get_logger().debug(f"Binding {self.__class__.__name__} socket <{address}>")
            self._sockets['listen'] = zmq.Context().socket(zmq.SUB)
            self._sockets['listen'].setsockopt_string(zmq.SUBSCRIBE, '')
            self._sockets['listen'].bind(address)
        except KeyError:  # some modules may not need to listen
            pass
        except zmq.error.ZMQError as e:
            if "Address in use" in str(e):
                get_logger().error(f"Binding failed because the address {address} in use")
                get_logger().error(f"Terminate the previous process using the address above")
                # Abort startup: a stale process already owns the address.
                self.loop.stop()
                return
        try:
            next_modules = self.config[self.__class__.__name__]['next']
            if isinstance(next_modules, str):
                next_modules = [next_modules, ]
            for key in next_modules:
                address = self.config[key]['listen']
                get_logger().debug(f"Connecting to {key} socket <{address}>")
                self._sockets[key] = zmq.Context().socket(zmq.PUB)
                self._sockets[key].connect(address)
                time.sleep(0.001)  # ensuring connection made correctly
        except KeyError:  # some modules may not need to send messages
            pass
    return self._sockets
def main(framework, train_main, generate_main):
    """CLI entry point for the character-embedding LSTM text-generation model.

    Builds an argument parser with ``train`` and ``generate`` subcommands,
    sets up file/console logging, then dispatches to the selected handler.

    Args:
        framework (str): Framework name inserted into the program description.
        train_main (callable): Handler invoked for the ``train`` subcommand.
        generate_main (callable): Handler invoked for the ``generate`` subcommand.
    """
    arg_parser = ArgumentParser(
        description="{} character embeddings LSTM text generation model.".format(framework))
    subparsers = arg_parser.add_subparsers(title="subcommands")
    # train args
    train_parser = subparsers.add_parser("train", help="train model on text file")
    train_parser.add_argument("--checkpoint-path", required=True,
                              help="path to save or load model checkpoints (required)")
    train_parser.add_argument("--text-path", required=True,
                              help="path of text file for training (required)")
    # --restore          -> True (const), --restore PATH -> PATH, absent -> False
    train_parser.add_argument("--restore", nargs="?", default=False, const=True,
                              help="whether to restore from checkpoint_path "
                                   "or from another path if specified")
    train_parser.add_argument("--seq-len", type=int, default=64,
                              help="sequence length of inputs and outputs (default: %(default)s)")
    train_parser.add_argument("--embedding-size", type=int, default=32,
                              help="character embedding size (default: %(default)s)")
    train_parser.add_argument("--rnn-size", type=int, default=128,
                              help="size of rnn cell (default: %(default)s)")
    train_parser.add_argument("--num-layers", type=int, default=2,
                              help="number of rnn layers (default: %(default)s)")
    train_parser.add_argument("--drop-rate", type=float, default=0.,
                              help="dropout rate for rnn layers (default: %(default)s)")
    train_parser.add_argument("--learning-rate", type=float, default=0.001,
                              help="learning rate (default: %(default)s)")
    train_parser.add_argument("--clip-norm", type=float, default=5.,
                              help="max norm to clip gradient (default: %(default)s)")
    train_parser.add_argument("--batch-size", type=int, default=64,
                              help="training batch size (default: %(default)s)")
    train_parser.add_argument("--num-epochs", type=int, default=32,
                              help="number of epochs for training (default: %(default)s)")
    train_parser.add_argument("--log-path", default=os.path.join(os.path.dirname(__file__), "main.log"),
                              help="path of log file (default: %(default)s)")
    train_parser.set_defaults(main=train_main)
    # generate args
    generate_parser = subparsers.add_parser("generate", help="generate text from trained model")
    generate_parser.add_argument("--checkpoint-path", required=True,
                                 help="path to load model checkpoints (required)")
    # The generation seed comes from exactly one of --text-path / --seed.
    group = generate_parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--text-path", help="path of text file to generate seed")
    group.add_argument("--seed", default=None, help="seed character sequence")
    generate_parser.add_argument("--length", type=int, default=1024,
                                 help="length of character sequence to generate (default: %(default)s)")
    generate_parser.add_argument("--top-n", type=int, default=3,
                                 help="number of top choices to sample (default: %(default)s)")
    generate_parser.add_argument("--log-path", default=os.path.join(os.path.dirname(__file__), "main.log"),
                                 help="path of log file (default: %(default)s)")
    generate_parser.set_defaults(main=generate_main)
    args = arg_parser.parse_args()
    # Configure the "__main__" logger first, then get this module's logger.
    get_logger("__main__", log_path=args.log_path, console=True)
    logger = get_logger(__name__, log_path=args.log_path, console=True)
    logger.debug("call: %s", " ".join(sys.argv))
    logger.debug("ArgumentParser: %s", args)
    try:
        # NOTE(review): if no subcommand is given, `args` has no `main`
        # attribute and this raises AttributeError — confirm intended.
        args.main(args)
    except Exception as e:
        logger.exception(e)