teststrategy.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:ReinforcementL_trading 作者: zhangbppku8663 项目源码 文件源码
def query_model(sym=['AAPL'], sd=dt.datetime(2017,5,1), ed=dt.datetime(2017,7,1),
                holdings=None, inds_list=['bbp','ATR'], div_file='Dividers.csv', QT_file='Q_Table.csv',
                lookback_days=35):
    """Query a saved Q-table for daily trading suggestions on ``sym[0]``.

    Data is loaded from ``sd - lookback_days`` through ``ed`` so the indicators
    added by ``add_bband``/``add_ATR`` have history to warm up on. The indicator
    columns are turned into states by ``StrategyLearner._get_state`` using the
    dividers stored in ``div_file``. Each state is then looked up in the Q-table
    stored in ``QT_file`` with a ``QLearner`` whose random-action rate is 0.

    Args:
        sym: Symbols passed to ``get_data``; ``sym[0]`` names the output column.
            The list default is never mutated by this function.
        sd: Start date. If it is not present in the indicator index, the first
            indicator date after it is used instead.
        ed: End date of the ``pd.date_range`` passed to ``get_data``.
        holdings: Optional per-day values added to each state. The first
            ``pass_day`` entries are skipped, where ``pass_day`` is the number of
            calendar days between ``sd`` and the effective start date. When
            ``None``, 100 is added to every state instead.
        inds_list: Indicator column names used to build states and to select
            the divider columns from ``div_file``.
        div_file: CSV of divider values, one column per indicator (first column
            is read as the index).
        QT_file: CSV of the saved Q-table (first column is read as the index).
        lookback_days: Calendar days of extra data loaded before ``sd`` for
            indicator warm-up (default 35).

    Returns:
        A DataFrame indexed by indicator dates from the effective start date on,
        with a single column named ``str(sym[0])`` containing 'SELL', 'NOTHING'
        or 'BUY'.

    Raises:
        IOError: If ``div_file`` or ``QT_file`` cannot be read. The error is
            printed and then re-raised.
        ValueError: If ``sd`` is after the last date with indicator data.
    """
    data = get_data(sym, dates=pd.date_range(sd - dt.timedelta(days=lookback_days), ed))
    data = add_bband(data)
    data = add_ATR(data)

    # Build {indicator: [divider values]} from the saved CSV. A read error is
    # re-raised so it surfaces here rather than as an unbound-name error below.
    try:
        divider = pd.read_csv(div_file, index_col=0)
    except IOError as e:
        print(e)
        raise
    div_dict = {ind: divider[ind].tolist() for ind in inds_list}

    # A StrategyLearner is created only to reuse its state computation.
    s_learner = sl.StrategyLearner()
    s_learner.div_dict = div_dict
    indicators = data[inds_list].dropna(how='all')
    states = s_learner._get_state(indicators)

    # Find the first indicator date on or after sd (sd may be a non-trading day).
    # NOTE(review): searchsorted assumes the index is sorted ascending, which
    # presumably holds for get_data output built from pd.date_range; confirm.
    index = indicators.index
    start_ts = pd.Timestamp(sd)
    start_index = index.searchsorted(start_ts)
    if start_index >= len(index):
        last_date = index[-1] if len(index) else None
        raise ValueError('something wrong with the start date: {} is after the last '
                         'available date {}'.format(sd, last_date))
    # Calendar days skipped to reach the first available date; used to align holdings.
    pass_day = (index[start_index] - start_ts).days
    states = states[start_index:]

    if holdings is None:
        states = states + 100  # in this two indicator case, assume no holdings
    else:
        try:
            new_holdings = holdings[pass_day:]
            if len(new_holdings) != len(states):
                # zip() below silently truncates to the shorter sequence.
                print('holdings length ({}) differs from number of states ({}); '
                      'extra entries are ignored'.format(len(new_holdings), len(states)))
            for i, hold in zip(range(len(states)), new_holdings):
                states[i] = states[i] + hold
        except Exception as e:
            # Best effort: report the problem and continue with the states as they are.
            print('may have different length of holding information in this case')
            print(e)

    try:
        Q_table = pd.read_csv(QT_file, index_col=0)
    except IOError as e:
        print(e)
        raise
    # NOTE(review): np.matrix is deprecated in NumPy. It is kept because
    # QLearner.querysetstate may rely on matrix indexing semantics; confirm
    # before switching to np.asarray.
    Q_table = np.matrix(Q_table)

    q_learner = ql.QLearner(rar=0)  # rar=0: no random action choice
    q_learner.Q_table = Q_table
    look_up = {0: 'SELL', 1: 'NOTHING', 2: 'BUY'}
    suggestions = [look_up[q_learner.querysetstate(state)] for state in states]
    effect_dates = index[start_index:]
    return pd.DataFrame(suggestions, index=effect_dates, columns=[str(sym[0])])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号