stock_DQN.py source code

python

Project: pred_finance · Author: jjasonn0717
# Relies on the pre-1.0 TensorFlow RNN API (e.g. `import tensorflow as tf` plus
# `from tensorflow.python.ops import rnn, rnn_cell`) and on module-level constants
# INPUT_DIM, DAYS_RANGE and ACTIONS defined elsewhere in stock_DQN.py.
def createMultiRNN(self, n_layer, n_hidden):

        with self.sess.graph.as_default():
            self.prob = tf.placeholder("float", name="keep_prob")
            # input #
            with tf.name_scope('input'):
                self.s = tf.placeholder('float', shape=[None, INPUT_DIM, DAYS_RANGE], name='input_state')
                input_trans = tf.transpose(self.s, [2, 0, 1]) # [DAYS_RANGE, None, INPUT_DIM]
                input_reshape = tf.reshape(input_trans, [-1, INPUT_DIM])
                input_list = tf.split(0, DAYS_RANGE, input_reshape) # split into DAYS_RANGE tensors, one per day
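                # Shape walk-through (illustrative batch of 32, assuming e.g. INPUT_DIM=5, DAYS_RANGE=10):
                # self.s [32, 5, 10] -> input_trans [10, 32, 5] -> input_reshape [320, 5]
                # -> input_list: 10 tensors of shape [32, 5], one per day, the list format rnn.rnn() expects.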

            with tf.name_scope('tg_input'):
                self.target_s = tf.placeholder('float', shape=[None, INPUT_DIM, DAYS_RANGE], name='input_state') # reuses the name 'input_state'; TF uniquifies it to 'input_state_1'
                tg_input_trans = tf.transpose(self.target_s, [2, 0, 1]) # [DAYS_RANGE, None, INPUT_DIM]
                tg_input_reshape = tf.reshape(tg_input_trans, [-1, INPUT_DIM])
                tg_input_list = tf.split(0, DAYS_RANGE, tg_input_reshape) # split into DAYS_RANGE tensors, one per day

            # multi LSTM #
            lstm_cell = rnn_cell.LSTMCell(n_hidden, use_peepholes=True, forget_bias=1.0, state_is_tuple=True)
            lstm_drop = rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.prob)
            lstm_stack = rnn_cell.MultiRNNCell([lstm_drop] * n_layer, state_is_tuple=True)

            tg_lstm_cell = rnn_cell.LSTMCell(n_hidden, use_peepholes=True, forget_bias=1.0, state_is_tuple=True)
            tg_lstm_drop = rnn_cell.DropoutWrapper(tg_lstm_cell, output_keep_prob=self.prob)
            tg_lstm_stack = rnn_cell.MultiRNNCell([tg_lstm_drop] * n_layer, state_is_tuple=True)
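            # [lstm_drop] * n_layer stacks n_layer dropout-wrapped LSTM layers; in this pre-1.0 API
            # the variables are created lazily when rnn.rnn() runs, so the target cells get their own
            # copies under the 'tg_LSTMStack' scope, separate from the online 'LSTMStack' variables.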

            lstm_output, hidden_states = rnn.rnn(lstm_stack,
                                                 input_list,
                                                 dtype='float',
                                                 scope='LSTMStack') # outputs: list of DAYS_RANGE tensors [batch, n_hidden]; state: per-layer (c, h) pairs of [batch, n_hidden]
            tg_lstm_output, tg_hidden_states = rnn.rnn(tg_lstm_stack, tg_input_list, dtype='float', scope='tg_LSTMStack')

            for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="LSTMStack"):
                tf.add_to_collection("L2_VARIABLES", var)
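            # The collected online LSTM weights are presumably fed into an L2 weight-decay term
            # elsewhere in stock_DQN.py (the "L2_VARIABLES" collection is project-defined, not a TF builtin).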

            h_fc1 = self.FC_layer(lstm_output[-1], tg_lstm_output[-1], [n_hidden, 1024], name='h_fc1', activate=True)
            h_fc2 = self.FC_layer(h_fc1[0], h_fc1[1], [1024, ACTIONS], name='h_fc2', activate=False)
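            # Two dense layers map the last LSTM output (n_hidden) -> 1024 -> ACTIONS, producing one
            # Q-value per action for both the online branch and the target branch.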

            key = tf.GraphKeys.TRAINABLE_VARIABLES
            update_pair = zip(tf.get_collection(key, scope="LSTMStack"), tf.get_collection(key, scope="tg_LSTMStack"))
            for var, tg_var in update_pair:
                self.update_list.append(tg_var.assign(var))
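            # Running these assign ops (e.g. self.sess.run(self.update_list) elsewhere in the training
            # loop) copies the online LSTM weights into the target LSTM -- the periodic target-network
            # sync used by DQN.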

            # readout layer
            self.readout = h_fc2[0]
            self.target_readout = h_fc2[1]
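FC_layer is defined elsewhere in stock_DQN.py and is not shown on this page. Judging from the call sites above (it takes the online and target activations together, and only the LSTM variables appear in the target-sync list), it most likely applies a single set of dense weights to both branches and returns the pair of outputs; self.readout / self.target_readout then hold the ACTIONS-dimensional Q-values. A minimal sketch under that assumption, with hypothetical initializers rather than the project's actual ones:

def FC_layer(self, x, tg_x, shape, name, activate):
    # Hypothetical reconstruction: one weight matrix/bias shared by the online (x)
    # and target (tg_x) inputs, so the target readout reuses the same dense parameters.
    with self.sess.graph.as_default():
        with tf.name_scope(name):
            W = tf.Variable(tf.truncated_normal(shape, stddev=0.01), name='W')
            b = tf.Variable(tf.constant(0.01, shape=[shape[1]]), name='b')
            tf.add_to_collection("L2_VARIABLES", W)  # match the L2 collection used above
            y = tf.matmul(x, W) + b
            tg_y = tf.matmul(tg_x, W) + b
            if activate:
                y, tg_y = tf.nn.relu(y), tf.nn.relu(tg_y)
            return y, tg_y

If the real FC_layer builds separate target-side weights instead, those weights would also need assign ops like the LSTM ones above for the target network to stay in sync.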