environment.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Deep-Reinforcement-Learning-in-Stock-Trading 作者: shenyichen105 项目源码 文件源码
def __init__(self, symbols,
        start_date=dt.datetime(2008,1,1),                           #CHANGE DATES TO ACTUAL RANGE!!!
        end_date= dt.datetime(2009,1,1)):
        """Initialise the environment state for a pair of stocks.

        Loads price data for the requested symbols plus the
        'interest_rates' and 'vix' series over the given date range, and
        sets up the starting portfolio.

        Args:
            symbols: Sequence of symbol names. The first two entries become
                ``self.stock_A`` and ``self.stock_B``. The sequence is not
                modified.
            start_date: First calendar day of the simulated period.
            end_date: Last calendar day of the simulated period.

        Raises:
            IndexError: If fewer than two symbols are supplied.
        """
        # frame a time period as world
        self.dates_range = pd.date_range(start_date, end_date)

        # initialize cash holdings
        init_cash = 100000

        # for visualization
        self.data_out = []

        # Work on a private copy: the auxiliary series names are added to a
        # separate request list so the caller's `symbols` is never mutated
        # (repeated construction with the same list used to keep appending).
        stock_symbols = list(symbols)
        requested_symbols = stock_symbols + ['interest_rates', 'vix']

        # price data
        # NOTE(review): the 'spy' column read below is never requested
        # explicitly; presumably the trailing True asks util.get_data to
        # add it -- confirm against util.get_data.
        prices_all = util.get_data(requested_symbols, self.dates_range, True)

        self.stock_A = stock_symbols[0]
        self.stock_B = stock_symbols[1]

        # Disabled: unemployment-rate feature.
        # temp_unemp = {}
        # unemployment = {}
        # with open('unemployment.csv') as unemp_file:
        #     for line in csv.reader(unemp_file, delimiter=','):
        #         curr_date = dt.strptime(line[0], '%B-%y')
        #         temp_unemp[curr_date] = line[1]
        # for d in prices_all.keys():
        #     temp_date = dt.datetime(d.year, d.month)
        #     if temp_date in temp_unemp:
        #         unemployment[d] = temp_unemp[temp_date]

        # first trading day: the first row of the loaded price table
        self.dateIdx = 0
        self.date = prices_all.index[0]
        self.start_date = start_date
        self.end_date = end_date

        # split the loaded table into the traded stocks and the auxiliary series
        self.prices = prices_all[stock_symbols]
        self.prices_SPY = prices_all['spy']
        self.prices_VIX = prices_all['vix']
        self.prices_interest_rate = prices_all['interest_rates']

        # keep track of portfolio value as a series
        self.portfolio = {'cash': init_cash, 'a_vol': [], 'a_price': [], 'b_vol': [], 'b_price': [], 'longA': 0}
        # Called last: every attribute above is already set at this point.
        self.port_val = self.port_value_for_output()

        # hardcode enumerating of features
        # Disabled: technical-indicator features.
        # self.sma = SMA(self.dates_range)
        # self.bbp = BBP(self.dates_range)
        # self.rsi = RSI(self.dates_range)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号