def __init__(self, symbols,
             start_date=dt.datetime(2008, 1, 1),  # TODO: change dates to the actual range
             end_date=dt.datetime(2009, 1, 1),
             init_cash=100000):
    """Set up a two-stock trading environment over a fixed date window.

    Args:
        symbols: Sequence of stock ticker symbols. The first two become the
            traded pair (``stock_A`` and ``stock_B``); all of them are loaded
            as price columns into ``self.prices``. The caller's sequence is
            not modified.
        start_date: First calendar day of the window (inclusive).
        end_date: Last calendar day of the window (inclusive).
        init_cash: Starting cash balance of the portfolio. Defaults to
            100000, the value that was previously hard-coded.

    Raises:
        IndexError: If ``symbols`` has fewer than two entries, or if the
            loaded price frame has no rows.
    """
    # Frame the simulated "world" as a calendar-day range.
    self.dates_range = pd.date_range(start_date, end_date)
    self.start_date = start_date
    self.end_date = end_date

    # Rows collected for visualization.
    self.data_out = []

    # Work on a private copy: the previous version appended the auxiliary
    # series names to the caller's own list, so reusing one list across
    # constructions accumulated duplicate 'interest_rates'/'vix' entries.
    stock_symbols = list(symbols)
    # Resolve the traded pair up front so a short list fails before any I/O.
    self.stock_A = stock_symbols[0]
    self.stock_B = stock_symbols[1]

    # Load stock prices together with the auxiliary market series.
    # NOTE(review): the third argument is presumably an "add SPY" flag, since
    # a 'spy' column is read below — confirm against util.get_data.
    query_symbols = stock_symbols + ['interest_rates', 'vix']
    prices_all = util.get_data(query_symbols, self.dates_range, True)

    # TODO: an unemployment-rate feature (monthly values read from
    # 'unemployment.csv') was prototyped here but is disabled. If revived,
    # note the prototype used dt.strptime (should be dt.datetime.strptime),
    # built dt.datetime(d.year, d.month) without the required day argument,
    # and iterated prices_all.keys(), which yields columns, not dates.

    # Start on the first row of the loaded price frame (presumably the first
    # trading day in the window).
    self.dateIdx = 0
    self.date = prices_all.index[0]

    self.prices = prices_all[stock_symbols]
    self.prices_SPY = prices_all['spy']
    self.prices_VIX = prices_all['vix']
    self.prices_interest_rate = prices_all['interest_rates']

    # Portfolio state: cash plus per-stock lists for A and B that start empty
    # (presumably trade volumes/prices — confirm), and a 'longA' value
    # starting at 0.
    self.portfolio = {'cash': init_cash,
                      'a_vol': [], 'a_price': [],
                      'b_vol': [], 'b_price': [],
                      'longA': 0}
    self.port_val = self.port_value_for_output()

    # TODO: SMA / BBP / RSI indicator features (each constructed from
    # self.dates_range) are disabled; enumerate features here once enabled.
environment.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录