M cleanlab/datalab/internal/issue_manager/__init__.py => cleanlab/datalab/internal/issue_manager/__init__.py +1 -0
@@ 3,3 3,4 @@ from .duplicate import NearDuplicateIssueManager
from .label import LabelIssueManager
from .outlier import OutlierIssueManager
from .noniid import NonIIDIssueManager
+from .imbalance import ClassImbalanceIssueManager
A cleanlab/datalab/internal/issue_manager/imbalance.py => cleanlab/datalab/internal/issue_manager/imbalance.py +80 -0
@@ 0,0 1,80 @@
+# Copyright (C) 2017-2023 Cleanlab Inc.
+# This file is part of cleanlab.
+#
+# cleanlab is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published
+# by the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# cleanlab is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+import numpy as np
+import pandas as pd
+from cleanlab.datalab.internal.issue_manager import IssueManager
+
+if TYPE_CHECKING: # pragma: no cover
+ from cleanlab.datalab.datalab import Datalab
+
+
+class ClassImbalanceIssueManager(IssueManager):
+ """Manages issues related to imbalance class examples.
+
+ Parameters
+ ----------
+ datalab:
+ The Datalab instance that this issue manager searches for issues in.
+
+ threshold:
+ Minimum fraction of samples of each class that are present in a dataset without class imbalance.
+
+ """
+
+ description: ClassVar[
+ str
+ ] = """Examples belonging to the most under-represented class in the dataset."""
+
+ issue_name: ClassVar[str] = "class_imbalance"
+ verbosity_levels = {
+ 0: [],
+ 1: [],
+ 2: [],
+ }
+
+ def __init__(self, datalab: Datalab, threshold: float = 0.1):
+ super().__init__(datalab)
+ self.threshold = threshold
+
+ def find_issues(
+ self,
+ **kwargs,
+ ) -> None:
+ labels = self.datalab.labels
+ K = len(self.datalab.class_names)
+ class_probs = np.bincount(labels) / len(labels)
+ imbalance_exists = class_probs.min() < self.threshold * (1 / K)
+ rarest_class = int(np.argmin(class_probs)) if imbalance_exists else -1
+ is_issue_column = labels == rarest_class
+ scores = np.where(is_issue_column, class_probs[rarest_class], 1)
+
+ self.issues = pd.DataFrame(
+ {
+ f"is_{self.issue_name}_issue": is_issue_column,
+ self.issue_score_key: scores,
+ },
+ )
+ self.summary = self.make_summary(score=scores.mean())
+ self.info = self.collect_info()
+
+ def collect_info(self) -> dict:
+ params_dict = {"threshold": self.threshold}
+ info_dict = {**params_dict}
+ return info_dict
M cleanlab/datalab/internal/issue_manager_factory.py => cleanlab/datalab/internal/issue_manager_factory.py +2 -0
@@ 46,6 46,7 @@ from cleanlab.datalab.internal.issue_manager import (
NearDuplicateIssueManager,
OutlierIssueManager,
NonIIDIssueManager,
+ ClassImbalanceIssueManager,
)
REGISTRY: Dict[str, Type[IssueManager]] = {
@@ 53,6 54,7 @@ REGISTRY: Dict[str, Type[IssueManager]] = {
"label": LabelIssueManager,
"near_duplicate": NearDuplicateIssueManager,
"non_iid": NonIIDIssueManager,
+ "class_imbalance": ClassImbalanceIssueManager,
}
"""Registry of issue managers that can be constructed from a string
and used in the Datalab class.
M docs/source/cleanlab/datalab/guide/issue_type_description.rst => docs/source/cleanlab/datalab/guide/issue_type_description.rst +25 -3
@@ 23,7 23,7 @@ Datalab produces three estimates for **each** type of issue (called say `<ISSUE_
.. code-block:: python
issue_name = "outlier" # how to reference the outlier issue type in code
- issue_score = "outlier_score" # name of column with quality scores for the outlier issue type, atypical datapoints receive lower scores
+ issue_score = "outlier_score" # name of column with quality scores for the outlier issue type, atypical datapoints receive lower scores
is_issue = "is_outlier_issue" # name of Boolean column flagging which datapoints are considered outliers in the dataset
Datalab estimates various issues based on the four inputs below.
@@ 85,7 85,7 @@ Near duplicated examples may record the same information with different:
- Minor variations which naturally occur in many types of data (e.g. translated versions of an image).
Near Duplicate issues are calculated based on provided `features` or `knn_graph`.
-If you do not provide one of these arguments, this type of issue will not be considered.
+If you do not provide one of these arguments, this type of issue will not be considered.
Datalab defines near duplicates as those examples whose distance to their nearest neighbor (in the space of provided `features`) in the dataset is less than `c * D`, where `0 < c < 1` is a small constant, and `D` is the median (over the full dataset) of such distances between each example and its nearest neighbor.
Scoring the numeric quality of an example in terms of the near duplicate issue type is done proportionally to its distance to its nearest neighbor.
@@ 113,6 113,14 @@ The assumption that examples in a dataset are Independent and Identically Distri
For datasets with low non-IID score, you should consider why your data are not IID and act accordingly. For example, if the data distribution is drifting over time, consider employing a time-based train/test split instead of a random partition. Note that shuffling the data ahead of time will ensure a good non-IID score, but this is not always a fix to the underlying problem (e.g. future deployment data may stem from a different distribution, or you may overlook the fact that examples influence each other). We thus recommend **not** shuffling your data to be able to diagnose this issue if it exists.
+Class-Imbalance Issue
+---------------------
+
+Class imbalance is diagnosed just using the `labels` provided as part of the dataset. The overall class imbalance quality score of a dataset is the proportion of examples belonging to the rarest class `q`. If this proportion `q` falls below a threshold, then we say this dataset suffers from the class imbalance issue.
+
+In a dataset identified as having class imbalance, the class imbalance quality score for each example is set equal to `q` if it is labeled as the rarest class, and is equal to 1 for all other examples.
+
+Class imbalance in a dataset can lead to subpar model performance for the under-represented class. Consider collecting more data from the under-represented class, or at least take special care while modeling via techniques like over/under-sampling, SMOTE, asymmetric class weighting, etc.
Image-specific Issues
---------------------
@@ 132,7 140,7 @@ Appropriate defaults are used for any parameters you do not specify, so no need
.. code-block:: python
possible_issue_types = {
- "label": label_kwargs, "outlier": outlier_kwargs,
+ "label": label_kwargs, "outlier": outlier_kwargs,
"near_duplicate": near_duplicate_kwargs, "non_iid": non_iid_kwargs
}
@@ 225,3 233,17 @@ Non-IID Issue Parameters
.. note::
For more information, view the source code of: :py:class:`datalab.internal.issue_manager.noniid.NonIIDIssueManager <cleanlab.datalab.internal.issue_manager.noniid.NonIIDIssueManager>`.
+
+
+Imbalance Issue Parameters
+--------------------------
+
+.. code-block:: python
+
+ class_imbalance_kwargs = {
+ "threshold": # `threshold` argument to constructor of `ClassImbalanceIssueManager()`. Non-negative floating value between 0 and 1 indicating the minimum fraction of samples of each class that are present in a dataset without class imbalance.
+ }
+
+.. note::
+
+ For more information, view the source code of: :py:class:`datalab.internal.issue_manager.imbalance.ClassImbalanceIssueManager <cleanlab.datalab.internal.issue_manager.imbalance.ClassImbalanceIssueManager>`.
A docs/source/cleanlab/datalab/internal/issue_manager/imbalance.rst => docs/source/cleanlab/datalab/internal/issue_manager/imbalance.rst +9 -0
@@ 0,0 1,9 @@
+imbalance
+=========
+
+
+.. automodule:: cleanlab.datalab.internal.issue_manager.imbalance
+ :autosummary:
+ :members:
+ :undoc-members:
+ :show-inheritance:
M docs/source/cleanlab/datalab/internal/issue_manager/index.rst => docs/source/cleanlab/datalab/internal/issue_manager/index.rst +2 -1
@@ 10,4 10,5 @@ issue_manager
label
outlier
duplicate
- noniid>
\ No newline at end of file
+ noniid
+ imbalance
A tests/datalab/issue_manager/test_imbalance.py => tests/datalab/issue_manager/test_imbalance.py +89 -0
@@ 0,0 1,89 @@
+import numpy as np
+import pytest
+
+from cleanlab.datalab.internal.issue_manager.imbalance import ClassImbalanceIssueManager
+
+SEED = 42
+
+
+class TestClassImbalanceIssueManager:
+ @pytest.fixture
+ def labels(self, lab):
+ K = lab.get_info("statistics")["num_classes"]
+ N = lab.get_info("statistics")["num_examples"] * 20
+ labels = np.random.choice(np.arange(K - 1), size=N, p=[0.5] * (K - 1))
+ labels[0] = K - 1 # Rare class
+ return labels
+
+ @pytest.fixture
+ def create_issue_manager(self, lab, labels, monkeypatch):
+ def manager(labels=labels):
+ monkeypatch.setattr(lab._labels, "labels", labels)
+ return ClassImbalanceIssueManager(datalab=lab, threshold=0.1)
+
+ return manager
+
+ def test_find_issues(self, create_issue_manager, labels):
+ N = len(labels)
+ issue_manager = create_issue_manager()
+ issue_manager.find_issues()
+ issues, summary = issue_manager.issues, issue_manager.summary
+ assert np.sum(issues["is_class_imbalance_issue"]) == 1
+ expected_issue_mask = np.array([True] + [False] * (N - 1))
+ assert np.all(
+ issues["is_class_imbalance_issue"] == expected_issue_mask
+ ), "Issue mask should be correct"
+ expected_scores = np.array([0.01] + [1.0] * (N - 1))
+ np.testing.assert_allclose(
+ issues["class_imbalance_score"], expected_scores, err_msg="Scores should be correct"
+ )
+ assert summary["issue_type"][0] == "class_imbalance"
+ assert summary["score"][0] == pytest.approx(expected=0.9900999, abs=1e-7)
+
+ def test_find_issues_no_imbalance(self, labels, create_issue_manager):
+ N = len(labels)
+ labels[0] = 0
+ issue_manager = create_issue_manager(labels)
+ issue_manager.find_issues()
+ issues, summary = issue_manager.issues, issue_manager.summary
+ assert np.sum(issues["is_class_imbalance_issue"]) == 0
+ assert np.all(
+ issues["is_class_imbalance_issue"] == np.full(N, False)
+ ), "Issue mask should be correct"
+ assert np.all(issues["class_imbalance_score"] == np.ones(N)), "Scores should be correct"
+ assert summary["issue_type"][0] == "class_imbalance"
+ assert summary["score"][0] == pytest.approx(expected=1.0, abs=1e-7)
+
+ def test_find_issues_more_imbalance(self, lab, labels, create_issue_manager):
+ K = lab.get_info("statistics")["num_classes"]
+ N = len(labels)
+ labels[labels == K - 2] = 0
+ labels[1:3] = K - 2
+ issue_manager = create_issue_manager(labels)
+ issue_manager.find_issues()
+ issues, summary = issue_manager.issues, issue_manager.summary
+ assert np.sum(issues["is_class_imbalance_issue"]) == 1
+ expected_issue_mask = np.array([True] + [False] * (N - 1))
+ assert np.all(
+ issues["is_class_imbalance_issue"] == expected_issue_mask
+ ), "Issue mask should be correct"
+ expected_scores = np.array([0.01] + [1.0] * (N - 1))
+ np.testing.assert_allclose(
+ issues["class_imbalance_score"], expected_scores, err_msg="Scores should be correct"
+ )
+ assert summary["issue_type"][0] == "class_imbalance"
+ assert summary["score"][0] == pytest.approx(expected=0.9900999, abs=1e-7)
+
+ def test_report(self, create_issue_manager):
+ issue_manager = create_issue_manager()
+ issue_manager.find_issues()
+ report = issue_manager.report(
+ issues=issue_manager.issues,
+ summary=issue_manager.summary,
+ info=issue_manager.info,
+ )
+ assert isinstance(report, str)
+ assert (
+ "------------------ class_imbalance issues ------------------\n\n"
+ "Number of examples with this issue:"
+ ) in report
M tests/datalab/test_datalab.py => tests/datalab/test_datalab.py +48 -0
@@ 1055,3 1055,51 @@ class TestDatalabWithoutLabels:
# issues_with_labels should have two additional columns about label issues
assert len(issues_without_labels.columns) + 2 == len(issues_with_labels.columns)
pd.testing.assert_frame_equal(issues_without_labels, issues_without_label_name)
+
+
+class TestDataLabClassImbalanceIssues:
+ K = 3
+ N = 100
+ num_features = 2
+
+ @pytest.fixture
+ def random_embeddings(self):
+ np.random.seed(SEED)
+ return np.random.rand(self.N, self.num_features)
+
+ @pytest.fixture
+ def imbalance_labels(self):
+ np.random.seed(SEED)
+ labels = np.random.choice(np.arange(self.K - 1), 100, p=[0.5] * (self.K - 1))
+ labels[0] = 2
+ return labels
+
+ @pytest.fixture
+ def pred_probs(self):
+ np.random.seed(SEED)
+ pred_probs_array = np.random.rand(self.N, self.K)
+ return pred_probs_array / pred_probs_array.sum(axis=1, keepdims=True)
+
+ def test_incremental_search(self, pred_probs, random_embeddings, imbalance_labels):
+ data = {"labels": imbalance_labels}
+ lab = Datalab(data=data, label_name="labels")
+ lab.find_issues(pred_probs=pred_probs, issue_types={"label": {}})
+ summary = lab.get_issue_summary()
+ assert len(summary) == 1
+ assert "class_imbalance" not in summary["issue_type"].values
+ lab.find_issues(features=random_embeddings, issue_types={"class_imbalance": {}})
+ summary = lab.get_issue_summary()
+ assert len(summary) == 2
+ assert "class_imbalance" in summary["issue_type"].values
+ class_imbalance_summary = lab.get_issue_summary("class_imbalance")
+ assert class_imbalance_summary["num_issues"].values[0] > 0
+
+ def test_find_imbalance_issues_no_args(self, imbalance_labels):
+ data = {"labels": imbalance_labels}
+ lab = Datalab(data=data, label_name="labels")
+ lab.find_issues(issue_types={"class_imbalance": {}})
+ summary = lab.get_issue_summary()
+ assert len(summary) == 1
+ assert "class_imbalance" in summary["issue_type"].values
+ class_imbalance_summary = lab.get_issue_summary("class_imbalance")
+ assert class_imbalance_summary["num_issues"].values[0] > 0
M tests/datalab/test_factory.py => tests/datalab/test_factory.py +3 -3
@@ 13,8 13,8 @@ def registry():
def test_list_possible_issue_types(registry):
issue_types = Datalab.list_possible_issue_types()
assert isinstance(issue_types, list)
- defaults = ["label", "outlier", "near_duplicate", "non_iid"]
- assert set(issue_types) == set(defaults)
+ possible_issues = ["label", "outlier", "near_duplicate", "non_iid", "class_imbalance"]
+ assert set(issue_types) == set(possible_issues)
test_key = "test_for_list_possible_issue_types"
@@ 24,7 24,7 @@ def test_list_possible_issue_types(registry):
issue_types = Datalab.list_possible_issue_types()
assert set(issue_types) == set(
- defaults + [test_key]
+ possible_issues + [test_key]
), "New issue type should be added to the list"
# Clean up