rnn_department.py source code


Project: instacart-basket-prediction, author: sjvasquez
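```python
import tensorflow as tf

# Method excerpted from the model class. lstm_layer and
# time_distributed_dense_layer are project helpers not shown here.
def calculate_outputs(self, x):
    # LSTM over each user's history; history_length is passed as the per-sequence length.
    h = lstm_layer(x, self.history_length, self.lstm_size, scope='lstm1')
    # Append the raw input features to the LSTM outputs along the feature axis.
    h = tf.concat([h, x], axis=2)

    # Per-timestep 50-unit ReLU layer, then a 1-unit sigmoid layer;
    # squeeze drops the trailing size-1 axis so y_hat is [batch, time].
    self.h_final = time_distributed_dense_layer(h, 50, activation=tf.nn.relu, scope='dense1')
    y_hat = tf.squeeze(time_distributed_dense_layer(self.h_final, 1, activation=tf.nn.sigmoid, scope='dense2'), 2)

    # (batch_index, last_valid_timestep) pairs select each sequence's final step.
    final_temporal_idx = tf.stack([tf.range(tf.shape(self.history_length)[0]), self.history_length - 1], axis=1)
    self.final_states = tf.gather_nd(self.h_final, final_temporal_idx)
    self.final_predictions = tf.gather_nd(y_hat, final_temporal_idx)

    # Tensors fetched at prediction time, keyed by name.
    self.prediction_tensors = {
        'user_ids': self.user_id,
        'department_ids': self.department_id,
        'final_states': self.final_states,
        'predictions': self.final_predictions
    }

    return y_hat
```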
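The NumPy sketch below traces the shapes through everything that happens after the LSTM. It rests on two assumptions. First, `time_distributed_dense_layer` applies one shared dense layer at every timestep. Second, `history_length` holds each sequence's number of valid steps. The LSTM output and the dense weights are random placeholders, not the project's trained model. The names `time_distributed_dense` and `final_step_outputs` are hypothetical helpers for illustration only.

```python
import numpy as np


def time_distributed_dense(h, weights, bias, activation):
    """Apply one shared dense layer at every timestep (hypothetical helper).

    h: [batch, time, in_dim]; weights: [in_dim, out_dim]; bias: [out_dim].
    The matmul broadcasts the 2-D weight matrix over the batch and time axes.
    """
    return activation(h @ weights + bias)


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def final_step_outputs(lstm_out, x, history_length, rng):
    """Mirror the post-LSTM shape logic of calculate_outputs (hypothetical).

    lstm_out stands in for the LSTM output [batch, time, lstm_size];
    the dense weights are random placeholders, not trained parameters.
    """
    h = np.concatenate([lstm_out, x], axis=2)          # append raw features
    w1, b1 = rng.normal(size=(h.shape[2], 50)), np.zeros(50)
    h_final = time_distributed_dense(h, w1, b1, relu)  # [batch, time, 50]
    w2, b2 = rng.normal(size=(50, 1)), np.zeros(1)
    y_hat = np.squeeze(time_distributed_dense(h_final, w2, b2, sigmoid), axis=2)  # [batch, time]

    # Same selection as tf.stack([tf.range(batch), history_length - 1], axis=1)
    # followed by tf.gather_nd: one (row, last valid step) pair per sequence.
    batch_idx = np.arange(history_length.shape[0])
    final_states = h_final[batch_idx, history_length - 1]     # [batch, 50]
    final_predictions = y_hat[batch_idx, history_length - 1]  # [batch]
    return y_hat, final_states, final_predictions


rng = np.random.default_rng(0)
batch, time, n_features, lstm_size = 3, 5, 4, 8
x = rng.normal(size=(batch, time, n_features))
lstm_out = rng.normal(size=(batch, time, lstm_size))
history_length = np.array([5, 2, 3])  # valid steps per sequence; the rest is padding

y_hat, final_states, final_predictions = final_step_outputs(lstm_out, x, history_length, rng)
print(y_hat.shape, final_states.shape, final_predictions.shape)  # (3, 5) (3, 50) (3,)
```

The paired integer-array index `h_final[batch_idx, history_length - 1]` does the same job as the `tf.stack` index matrix passed to `tf.gather_nd`. It picks exactly one timestep per batch row. As a result, padded steps beyond a sequence's length never reach `final_states` or `final_predictions`.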