~krystianch/cleanlab

3d50b9859d8f2e7a5790121af5817ae403d19b4e — Tata Ganesh 1 year, 1 month ago b0fb8e5
Add imbalance issue to datalab (#758)

Co-authored-by: Jonas Mueller <1390638+jwmueller@users.noreply.github.com>
Co-authored-by: Elías Snorrason <eliassno@gmail.com>
M cleanlab/datalab/internal/issue_manager/__init__.py => cleanlab/datalab/internal/issue_manager/__init__.py +1 -0
@@ 3,3 3,4 @@ from .duplicate import NearDuplicateIssueManager
from .label import LabelIssueManager
from .outlier import OutlierIssueManager
from .noniid import NonIIDIssueManager
from .imbalance import ClassImbalanceIssueManager

A cleanlab/datalab/internal/issue_manager/imbalance.py => cleanlab/datalab/internal/issue_manager/imbalance.py +80 -0
@@ 0,0 1,80 @@
# Copyright (C) 2017-2023  Cleanlab Inc.
# This file is part of cleanlab.
#
# cleanlab is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cleanlab is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

import numpy as np
import pandas as pd
from cleanlab.datalab.internal.issue_manager import IssueManager

if TYPE_CHECKING:  # pragma: no cover
    from cleanlab.datalab.datalab import Datalab


class ClassImbalanceIssueManager(IssueManager):
    """Manages issues related to class imbalance in a labeled dataset.

    When the rarest class is sufficiently under-represented, every example
    labeled with that class is flagged as an issue.

    Parameters
    ----------
    datalab:
        The Datalab instance that this issue manager searches for issues in.

    threshold:
        Factor applied to the uniform class frequency ``1 / K``, where ``K`` is the
        number of classes. The dataset is treated as imbalanced when the fraction of
        examples in the rarest class is below ``threshold * (1 / K)``.

    """

    description: ClassVar[
        str
    ] = """Examples belonging to the most under-represented class in the dataset."""

    issue_name: ClassVar[str] = "class_imbalance"
    verbosity_levels = {
        0: [],
        1: [],
        2: [],
    }

    def __init__(self, datalab: Datalab, threshold: float = 0.1):
        super().__init__(datalab)
        self.threshold = threshold

    def find_issues(
        self,
        **kwargs,
    ) -> None:
        """Flag examples of the rarest class and populate ``issues``, ``summary`` and ``info``.

        Each flagged example is scored with the fraction of the dataset that
        belongs to the rarest class; every other example is scored 1. The summary
        score is the mean of these per-example scores.
        """
        labels = self.datalab.labels
        num_classes = len(self.datalab.class_names)
        class_fractions = np.bincount(labels) / len(labels)

        smallest_fraction = class_fractions.min()
        cutoff = self.threshold * (1 / num_classes)
        if smallest_fraction < cutoff:
            flagged_class = int(np.argmin(class_fractions))
        else:
            # np.bincount only accepts non-negative labels, so -1 matches no
            # example and the mask below is all False.
            flagged_class = -1

        flagged = labels == flagged_class
        scores = np.where(flagged, class_fractions[flagged_class], 1)

        issue_columns = {
            f"is_{self.issue_name}_issue": flagged,
            self.issue_score_key: scores,
        }
        self.issues = pd.DataFrame(issue_columns)
        self.summary = self.make_summary(score=scores.mean())
        self.info = self.collect_info()

    def collect_info(self) -> dict:
        """Return the parameters this issue manager was configured with."""
        return {"threshold": self.threshold}

M cleanlab/datalab/internal/issue_manager_factory.py => cleanlab/datalab/internal/issue_manager_factory.py +2 -0
@@ 46,6 46,7 @@ from cleanlab.datalab.internal.issue_manager import (
    NearDuplicateIssueManager,
    OutlierIssueManager,
    NonIIDIssueManager,
    ClassImbalanceIssueManager,
)

REGISTRY: Dict[str, Type[IssueManager]] = {


@@ 53,6 54,7 @@ REGISTRY: Dict[str, Type[IssueManager]] = {
    "label": LabelIssueManager,
    "near_duplicate": NearDuplicateIssueManager,
    "non_iid": NonIIDIssueManager,
    "class_imbalance": ClassImbalanceIssueManager,
}
"""Registry of issue managers that can be constructed from a string
and used in the Datalab class.

M docs/source/cleanlab/datalab/guide/issue_type_description.rst => docs/source/cleanlab/datalab/guide/issue_type_description.rst +25 -3
@@ 23,7 23,7 @@ Datalab produces three estimates for **each** type of issue (called say `<ISSUE_
.. code-block:: python

    issue_name = "outlier"  # how to reference the outlier issue type in code
    issue_score = "outlier_score"  # name of column with quality scores for the outlier issue type, atypical datapoints receive lower scores 
    issue_score = "outlier_score"  # name of column with quality scores for the outlier issue type, atypical datapoints receive lower scores
    is_issue = "is_outlier_issue"  # name of Boolean column flagging which datapoints are considered outliers in the dataset

Datalab estimates various issues based on the four inputs below.


@@ 85,7 85,7 @@ Near duplicated examples may record the same information with different:
- Minor variations which naturally occur in many types of data (e.g. translated versions of an image).

Near Duplicate issues are calculated based on provided `features` or `knn_graph`.
If you do not provide one of these arguments, this type of issue will not be considered. 
If you do not provide one of these arguments, this type of issue will not be considered.

Datalab defines near duplicates as those examples whose distance to their nearest neighbor (in the space of provided `features`) in the dataset is less than `c * D`, where `0 < c < 1` is a small constant, and `D` is the median (over the full dataset) of such distances between each example and its nearest neighbor.
Scoring the numeric quality of an example in terms of the near duplicate issue type is done proportionally to its distance to its nearest neighbor.


@@ 113,6 113,14 @@ The assumption that examples in a dataset are Independent and Identically Distri

For datasets with low non-IID score, you should consider why your data are not IID and act accordingly. For example, if the data distribution is drifting over time, consider employing a time-based train/test split instead of a random partition.  Note that shuffling the data ahead of time will ensure a good non-IID score, but this is not always a fix to the underlying problem (e.g. future deployment data may stem from a different distribution, or you may overlook the fact that examples influence each other). We thus recommend **not** shuffling your data to be able to diagnose this issue if it exists.

Class-Imbalance Issue
---------------------

Class imbalance is diagnosed using only the `labels` provided as part of the dataset. The overall class imbalance quality score of a dataset is the proportion `q` of examples belonging to the rarest class. If this proportion `q` falls below a threshold (the `threshold` parameter divided by the number of classes), then we say this dataset suffers from the class imbalance issue.

In a dataset identified as having class imbalance, the class imbalance quality score for each example is set equal to `q` if it is labeled as the rarest class, and is equal to 1 for all other examples.

Class imbalance in a dataset can lead to subpar model performance for the under-represented class. Consider collecting more data from the under-represented class, or at least take special care while modeling via techniques like over/under-sampling, SMOTE, asymmetric class weighting, etc.

Image-specific Issues
---------------------


@@ 132,7 140,7 @@ Appropriate defaults are used for any parameters you do not specify, so no need 
.. code-block:: python

    possible_issue_types = {
        "label": label_kwargs, "outlier": outlier_kwargs, 
        "label": label_kwargs, "outlier": outlier_kwargs,
        "near_duplicate": near_duplicate_kwargs, "non_iid": non_iid_kwargs,
        "class_imbalance": class_imbalance_kwargs
    }



@@ 225,3 233,17 @@ Non-IID Issue Parameters
.. note::

    For more information, view the source code of:  :py:class:`datalab.internal.issue_manager.noniid.NonIIDIssueManager <cleanlab.datalab.internal.issue_manager.noniid.NonIIDIssueManager>`.


Imbalance Issue Parameters
--------------------------

.. code-block:: python

    class_imbalance_kwargs = {
        "threshold": # `threshold` argument to constructor of `ClassImbalanceIssueManager()`. Float between 0 and 1. The dataset is considered imbalanced when the fraction of examples in the rarest class falls below `threshold / K`, where `K` is the number of classes.
    }

.. note::

    For more information, view the source code of:  :py:class:`datalab.internal.issue_manager.imbalance.ClassImbalanceIssueManager <cleanlab.datalab.internal.issue_manager.imbalance.ClassImbalanceIssueManager>`.

A docs/source/cleanlab/datalab/internal/issue_manager/imbalance.rst => docs/source/cleanlab/datalab/internal/issue_manager/imbalance.rst +9 -0
@@ 0,0 1,9 @@
imbalance
=========


.. automodule:: cleanlab.datalab.internal.issue_manager.imbalance
    :autosummary:
    :members:
    :undoc-members:
    :show-inheritance:

M docs/source/cleanlab/datalab/internal/issue_manager/index.rst => docs/source/cleanlab/datalab/internal/issue_manager/index.rst +2 -1
@@ 10,4 10,5 @@ issue_manager
    label
    outlier
    duplicate
    noniid
\ No newline at end of file
    noniid
    imbalance

A tests/datalab/issue_manager/test_imbalance.py => tests/datalab/issue_manager/test_imbalance.py +89 -0
@@ 0,0 1,89 @@
import numpy as np
import pytest

from cleanlab.datalab.internal.issue_manager.imbalance import ClassImbalanceIssueManager

SEED = 42


class TestClassImbalanceIssueManager:
    """Unit tests for ``ClassImbalanceIssueManager`` run against synthetic labels."""

    @pytest.fixture
    def labels(self, lab):
        """Labels drawn uniformly from all but the last class, plus one example of the last class.

        The hard-coded expectations in the tests below (a rare-class score of
        0.01) presumably rely on the ``lab`` fixture yielding N == 100 here.
        """
        K = lab.get_info("statistics")["num_classes"]
        N = lab.get_info("statistics")["num_examples"] * 20
        # Seed the global RNG so every run draws the same labels.
        np.random.seed(SEED)
        # Uniform over the K - 1 common classes, for any K (not only K == 3).
        labels = np.random.choice(np.arange(K - 1), size=N, p=[1 / (K - 1)] * (K - 1))
        labels[0] = K - 1  # Rare class
        return labels

    @pytest.fixture
    def create_issue_manager(self, lab, labels, monkeypatch):
        """Factory returning an issue manager whose Datalab labels are patched to ``labels``."""

        def manager(labels=labels):
            monkeypatch.setattr(lab._labels, "labels", labels)
            return ClassImbalanceIssueManager(datalab=lab, threshold=0.1)

        return manager

    def test_find_issues(self, create_issue_manager, labels):
        """A single example of the rare class is the only one flagged."""
        N = len(labels)
        issue_manager = create_issue_manager()
        issue_manager.find_issues()
        issues, summary = issue_manager.issues, issue_manager.summary
        assert np.sum(issues["is_class_imbalance_issue"]) == 1
        expected_issue_mask = np.array([True] + [False] * (N - 1))
        assert np.all(
            issues["is_class_imbalance_issue"] == expected_issue_mask
        ), "Issue mask should be correct"
        expected_scores = np.array([0.01] + [1.0] * (N - 1))
        np.testing.assert_allclose(
            issues["class_imbalance_score"], expected_scores, err_msg="Scores should be correct"
        )
        assert summary["issue_type"][0] == "class_imbalance"
        # Mean of one score of 0.01 and 99 scores of 1.0.
        assert summary["score"][0] == pytest.approx(expected=0.9901, abs=1e-7)

    def test_find_issues_no_imbalance(self, labels, create_issue_manager):
        """With the rare example relabeled, nothing is flagged and all scores are 1."""
        N = len(labels)
        labels[0] = 0
        issue_manager = create_issue_manager(labels)
        issue_manager.find_issues()
        issues, summary = issue_manager.issues, issue_manager.summary
        assert np.sum(issues["is_class_imbalance_issue"]) == 0
        assert np.all(
            issues["is_class_imbalance_issue"] == np.full(N, False)
        ), "Issue mask should be correct"
        assert np.all(issues["class_imbalance_score"] == np.ones(N)), "Scores should be correct"
        assert summary["issue_type"][0] == "class_imbalance"
        assert summary["score"][0] == pytest.approx(expected=1.0, abs=1e-7)

    def test_find_issues_more_imbalance(self, lab, labels, create_issue_manager):
        """With two under-represented classes, only the rarest one is flagged."""
        K = lab.get_info("statistics")["num_classes"]
        N = len(labels)
        # Collapse class K - 2 into class 0, then give it exactly two examples so
        # that it is rare but still more frequent than class K - 1 (one example).
        labels[labels == K - 2] = 0
        labels[1:3] = K - 2
        issue_manager = create_issue_manager(labels)
        issue_manager.find_issues()
        issues, summary = issue_manager.issues, issue_manager.summary
        assert np.sum(issues["is_class_imbalance_issue"]) == 1
        expected_issue_mask = np.array([True] + [False] * (N - 1))
        assert np.all(
            issues["is_class_imbalance_issue"] == expected_issue_mask
        ), "Issue mask should be correct"
        expected_scores = np.array([0.01] + [1.0] * (N - 1))
        np.testing.assert_allclose(
            issues["class_imbalance_score"], expected_scores, err_msg="Scores should be correct"
        )
        assert summary["issue_type"][0] == "class_imbalance"
        # Mean of one score of 0.01 and 99 scores of 1.0.
        assert summary["score"][0] == pytest.approx(expected=0.9901, abs=1e-7)

    def test_report(self, create_issue_manager):
        """The textual report contains the class_imbalance section header."""
        issue_manager = create_issue_manager()
        issue_manager.find_issues()
        report = issue_manager.report(
            issues=issue_manager.issues,
            summary=issue_manager.summary,
            info=issue_manager.info,
        )
        assert isinstance(report, str)
        assert (
            "------------------ class_imbalance issues ------------------\n\n"
            "Number of examples with this issue:"
        ) in report

M tests/datalab/test_datalab.py => tests/datalab/test_datalab.py +48 -0
@@ 1055,3 1055,51 @@ class TestDatalabWithoutLabels:
        # issues_with_labels should have two additional columns about label issues
        assert len(issues_without_labels.columns) + 2 == len(issues_with_labels.columns)
        pd.testing.assert_frame_equal(issues_without_labels, issues_without_label_name)


class TestDataLabClassImbalanceIssues:
    """End-to-end tests of the class_imbalance issue type through ``Datalab.find_issues``."""

    K = 3
    N = 100
    num_features = 2

    @pytest.fixture
    def random_embeddings(self):
        """Random feature matrix with N rows and num_features columns."""
        np.random.seed(SEED)
        return np.random.rand(self.N, self.num_features)

    @pytest.fixture
    def imbalance_labels(self):
        """N labels drawn uniformly from the first K - 1 classes, plus one example of the last class."""
        np.random.seed(SEED)
        common_classes = self.K - 1
        # Sized from N and K so the labels stay aligned with pred_probs and the
        # embeddings if the class constants change.
        labels = np.random.choice(
            np.arange(common_classes), self.N, p=[1 / common_classes] * common_classes
        )
        labels[0] = self.K - 1  # Rare class
        return labels

    @pytest.fixture
    def pred_probs(self):
        """Random N x K matrix whose rows are normalized to sum to 1."""
        np.random.seed(SEED)
        pred_probs_array = np.random.rand(self.N, self.K)
        return pred_probs_array / pred_probs_array.sum(axis=1, keepdims=True)

    def test_incremental_search(self, pred_probs, random_embeddings, imbalance_labels):
        """class_imbalance can be added by a second find_issues call after a label-only run."""
        data = {"labels": imbalance_labels}
        lab = Datalab(data=data, label_name="labels")
        lab.find_issues(pred_probs=pred_probs, issue_types={"label": {}})
        summary = lab.get_issue_summary()
        assert len(summary) == 1
        assert "class_imbalance" not in summary["issue_type"].values
        lab.find_issues(features=random_embeddings, issue_types={"class_imbalance": {}})
        summary = lab.get_issue_summary()
        assert len(summary) == 2
        assert "class_imbalance" in summary["issue_type"].values
        class_imbalance_summary = lab.get_issue_summary("class_imbalance")
        assert class_imbalance_summary["num_issues"].values[0] > 0

    def test_find_imbalance_issues_no_args(self, imbalance_labels):
        """class_imbalance needs only labels: no pred_probs or features are passed."""
        data = {"labels": imbalance_labels}
        lab = Datalab(data=data, label_name="labels")
        lab.find_issues(issue_types={"class_imbalance": {}})
        summary = lab.get_issue_summary()
        assert len(summary) == 1
        assert "class_imbalance" in summary["issue_type"].values
        class_imbalance_summary = lab.get_issue_summary("class_imbalance")
        assert class_imbalance_summary["num_issues"].values[0] > 0

M tests/datalab/test_factory.py => tests/datalab/test_factory.py +3 -3
@@ 13,8 13,8 @@ def registry():
def test_list_possible_issue_types(registry):
    issue_types = Datalab.list_possible_issue_types()
    assert isinstance(issue_types, list)
    defaults = ["label", "outlier", "near_duplicate", "non_iid"]
    assert set(issue_types) == set(defaults)
    possible_issues = ["label", "outlier", "near_duplicate", "non_iid", "class_imbalance"]
    assert set(issue_types) == set(possible_issues)

    test_key = "test_for_list_possible_issue_types"



@@ 24,7 24,7 @@ def test_list_possible_issue_types(registry):

    issue_types = Datalab.list_possible_issue_types()
    assert set(issue_types) == set(
        defaults + [test_key]
        possible_issues + [test_key]
    ), "New issue type should be added to the list"

    # Clean up