def calculate_assignments(self, assignment_weights):
    """Return the index of the largest weight along axis 1 of the input.

    Each row of ``assignment_weights`` is hard-assigned to the column that
    holds its maximum value, via ``np.argmax(..., axis=1)``. When several
    columns tie for the maximum, ``np.argmax`` returns the lowest index.

    Args:
        assignment_weights: Array-like with at least two dimensions.
            NOTE(review): presumably shaped (n_samples, n_clusters), with
            one row per sample and one column per cluster. Confirm this
            against the callers.

    Returns:
        numpy.ndarray of integer indices with axis 1 removed. For 2-D
        input this is one index per row.

    Raises:
        numpy.exceptions.AxisError (a ValueError/IndexError subclass): if
            the input has fewer than two dimensions.
        ValueError: if axis 1 has length zero.
    """
    # ``self`` is not used; the computation depends only on the weights.
    clusters = np.argmax(assignment_weights, axis=1)
    return clusters