Finding Label Issues in Image Classification Datasets

04/21/2022
  • Wei Jing Lok
  • Jonas Mueller

Supervised machine learning often assumes that the labels we train our model on are correct. ML models have progressed remarkably within this paradigm, exceeding 99% accuracy for predicting the given labels in the classic MNIST image classification dataset of handwritten digits. If our data are labeled correctly, this is excellent news, but recent studies have discovered that even highly curated ML benchmark datasets are full of label errors. Furthermore, the labels in datasets from real-world applications can be of far lower quality. As the adage “garbage in, garbage out” cautions, it seems problematic to teach our ML models to predict fundamentally flawed labels. Even worse, we might train and evaluate these models with flawed labels and deploy the resulting models in systems that affect our daily lives.

Luckily, there are tools to help prevent this! The open-source cleanlab library helps you find the label issues lurking in your data. In this hands-on blog post, we’ll use cleanlab to find label issues in the MNIST dataset, which has been cited over 40,000 times. This dataset contains 70,000 images of handwritten digits from 0 to 9 (these are the labels to predict). Here are some of the MNIST label issues found by cleanlab:

Examples of label errors in the MNIST dataset

This post will show how to run cleanlab to find these issues in the MNIST dataset. You can use the same cleanlab workflow demonstrated here to easily find bad labels in your dataset. You can run this workflow yourself in under 5 minutes using the accompanying Colab notebook.


Overview of the steps to find label issues with cleanlab

In this post, we will walk through the following steps:

  1. Build a simple PyTorch neural network model to classify digits.

  2. Use this model to compute out-of-sample predicted probabilities, pred_probs, via cross-validation.

  3. Identify potential label errors in one line of code with cleanlab’s find_label_issues() function.

The rest of this blog dives into the code implementing this workflow.

Show me the code

We'll start by installing and importing some required packages.

You can use pip to install the dependencies for this workflow:

pip install cleanlab pandas matplotlib torch torchvision skorch

Our first few Python commands will import some of the required packages, suppress warnings for cleaner output, and set random seeds for reproducibility.

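import numpy as np
import torch
import warnings

# Fix random seeds so repeated runs give (approximately) the same results.
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored if no GPU is available
# Trade some speed for more deterministic cuDNN convolutions on GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
warnings.filterwarnings("ignore")  # hide library warnings for cleaner output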

Prepare the dataset

The MNIST dataset can be fetched directly from OpenML.

from sklearn.datasets import fetch_openml

mnist = fetch_openml("mnist_784", version=1, as_frame=True)  # pandas data and target

Once we have the data, let’s assign the image features to a variable X, converting the 2D table (a pandas DataFrame) that OpenML gives us into a NumPy array. Each row represents one image, with its 28 × 28 pixels flattened into a 1D row vector of 784 features. Each entry in this vector has a value from 0 to 255, representing the intensity of a particular pixel in the image (MNIST images are grayscale).

X = mnist.data.astype("float32").to_numpy()

To apply convolutional neural networks to our features X, we’ll scale their values to the [0, 1] interval and then reshape each row of X into a 1 × 28 × 28 array: a channel dimension (which is 1 since MNIST images are grayscale rather than RGB), followed by the image height and width. This (channels, height, width) layout is what PyTorch’s nn.Conv2d layers expect for each image.

X /= 255.0  # Scale the features to the [0, 1] range
X = X.reshape(len(X), 1, 28, 28)

print(X[50])  # depict an arbitrary datapoint, change 50 to another value to see other datapoints

Note that if we were to model this data with a fully-connected neural network, this reshape would be unnecessary since fully-connected networks typically operate on 1D vectors.
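The reshape relies on NumPy's default row-major order: pixel k of a flattened 784-long row lands at row k // 28 and column k % 28 of the image. The sketch below checks this on a small synthetic array (not MNIST itself) and shows that flattening each image again recovers the original rows, which is the 1D form a fully-connected network, or the nn.Flatten() layer in our model, works with.

```python
import numpy as np

# Synthetic stand-in for the flattened data: 3 "images" of 784 pixels each,
# where every pixel value simply equals its position in the row.
flat = np.tile(np.arange(784, dtype=np.float32), (3, 1))

# Same reshape as above: (num_images, channels, height, width).
images = flat.reshape(len(flat), 1, 28, 28)

# Row-major order: flat pixel k sits at row k // 28, column k % 28.
k = 100
row, col = divmod(k, 28)
assert images[0, 0, row, col] == flat[0, k]

# Flattening each image again recovers the original 1D rows.
restored = images.reshape(len(images), -1)
assert np.array_equal(restored, flat)
print(images.shape, restored.shape)  # (3, 1, 28, 28) (3, 784)
```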

Now that we have prepared the image features, let’s also assign the image labels to a variable y, stored as a 1D NumPy array. The labels are integers in {0, …, 9} corresponding to the digit depicted in each image.

y = mnist.target.astype("int64").to_numpy()

print(y[:50])  # show first 50 labels
[5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9 4 0 9 1 1 2 4 3 2 7 3 8 6 9 0 5 6 0 7 6 1 8 7 9 3 9 8 5 9 3]

To find label issues in your image dataset instead of MNIST, you simply need to assign your data’s features to variable X and its labels to variable y instead. Then you can apply the rest of the workflow from this post directly as is!
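For your own data, cleanlab expects the labels as integers 0, …, K-1 (one per class), and the ConvNet in this post expects float images shaped (N, 1, H, W). Below is a minimal sketch of that conversion, assuming grayscale images already loaded into a uint8 NumPy array of shape (N, H, W) and string class names; the helper prepare_image_dataset and the example class names are hypothetical.

```python
import numpy as np


def prepare_image_dataset(images, class_names):
    """Hypothetical helper: convert raw grayscale images and string labels
    into the X (float32, N x 1 x H x W) and y (int64 codes) used in this post."""
    images = np.asarray(images)
    # Scale 0-255 pixel intensities to [0, 1] and add a channel axis of size 1.
    X = images.astype("float32")[:, None, :, :] / 255.0
    # Map each distinct class name to an integer in 0..K-1 (sorted order).
    classes, y = np.unique(np.asarray(class_names), return_inverse=True)
    return X, y.astype("int64"), classes


# Tiny synthetic example: four 5x5 images with two classes.
raw = np.random.default_rng(0).integers(0, 256, size=(4, 5, 5), dtype=np.uint8)
X_demo, y_demo, classes = prepare_image_dataset(raw, ["cat", "dog", "dog", "cat"])
print(X_demo.shape, y_demo, classes)  # (4, 1, 5, 5) [0 1 1 0] ['cat' 'dog']
print(np.bincount(y_demo))            # examples per class: [2 2]
```

If your dataset has a number of classes K other than 10, the final nn.Linear(128, 10) layer of the model defined next would also need K outputs; the nn.LazyLinear layer adapts to other image sizes on its own.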

Define an image classification model

Our supervised learning task is to classify which digit is depicted in an image. We’ll use a simple convolutional neural network (ConvNet) built with PyTorch.

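from torch import nn

class ClassifierModule(nn.Module):
    def __init__(self):
        super().__init__()

        # Feature extractor: two blocks of conv -> ReLU -> batch norm -> max-pool.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 6, 3),  # 1 grayscale input channel -> 6 feature maps, 3x3 kernel
            nn.ReLU(),
            nn.BatchNorm2d(6),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 3),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head: flatten, then two fully connected layers.
        self.out = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(128),  # input size is inferred on the first forward pass
            nn.ReLU(),
            nn.Linear(128, 10),  # one output per digit class
            nn.Softmax(dim=-1),  # class probabilities; skorch's default NLLLoss takes their log
        )

    def forward(self, X):
        X = self.cnn(X)
        X = self.out(X)
        return X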

Our model has just two convolutional layers and two fully connected layers.
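To see what nn.LazyLinear(128) ends up inferring, we can trace the side length of a 28 × 28 input through the layers using the standard output-size formula floor((n + 2p - k) / s) + 1 for convolution and pooling layers without dilation. ReLU and batch normalization leave the shape unchanged. This is a small sketch of that arithmetic; the helper name conv_out is our own.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output side length of a conv or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 28
size = conv_out(size, kernel=3)            # Conv2d(1, 6, 3)   -> 26
size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(2, 2)   -> 13
size = conv_out(size, kernel=3)            # Conv2d(6, 16, 3)  -> 11
size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(2, 2)   -> 5
flat_features = 16 * size * size           # 16 channels after the second conv
print(size, flat_features)  # 5 400
```

So the flattened feature vector has 16 × 5 × 5 = 400 entries, and the first fully connected layer maps 400 inputs to 128 outputs.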

Feel free to replace the above network with any PyTorch model: cleanlab can actually be used with any classifier!

Because some cleanlab features rely on scikit-learn compatibility, we’ll wrap the above PyTorch neural net using skorch, which instantly makes it scikit-learn compatible. Below, we’ve set max_epochs=50, while the accompanying Colab notebook leaves it at skorch’s default of 10 to speed up execution. Feel free to lower this value to get results faster or increase it to get better results.

from skorch import NeuralNetClassifier

model_skorch = NeuralNetClassifier(ClassifierModule, max_epochs=50)

skorch also offers many other quality-of-life improvements that simplify deep learning with PyTorch. Alternatively, you can make an arbitrary model scikit-learn compatible yourself by wrapping it in a class that implements fit() and predict_proba() methods.
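As a rough sketch of such a wrapper: scikit-learn's cross-validation utilities mainly need fit() and predict_proba() methods, a classes_ attribute, and constructor arguments stored under the same names so the estimator can be cloned. The CentroidClassifier below is a hypothetical stand-in for "an arbitrary model" (it scores each class by the distance to that class's mean feature vector); you would replace its internals with calls to your own model.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin


class CentroidClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical model exposing the interface scikit-learn tools expect."""

    def __init__(self, temperature=1.0):
        # Stored under the same name so get_params() and clone() work.
        self.temperature = temperature

    def fit(self, X, y):
        X = np.asarray(X, dtype=np.float64).reshape(len(X), -1)  # flatten images
        y = np.asarray(y)
        self.classes_ = np.unique(y)  # cross_val_predict reads this attribute
        # "Training": remember the mean feature vector of each class.
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict_proba(self, X):
        X = np.asarray(X, dtype=np.float64).reshape(len(X), -1)
        # Squared distances to each centroid via |x|^2 - 2 x.c + |c|^2.
        d2 = ((X ** 2).sum(axis=1, keepdims=True)
              - 2 * X @ self.centroids_.T
              + (self.centroids_ ** 2).sum(axis=1))
        logits = -d2 / self.temperature
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        return probs / probs.sum(axis=1, keepdims=True)  # each row sums to 1

    def predict(self, X):
        return self.classes_[self.predict_proba(X).argmax(axis=1)]
```

With a wrapper like this, cross_val_predict(CentroidClassifier(), X, y, cv=5, method="predict_proba") runs the same way it does for model_skorch in the next section.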

Compute out-of-sample predicted probabilities

Generally speaking, cleanlab uses predictions from a trained classifier to identify label issues in the data. More specifically, these predictions should be the classifier’s estimated probability of each class given a specific example (cf. the predict_proba() method of scikit-learn classifiers). Check out this blog post on Confident Learning to learn more about the algorithm cleanlab uses to find label issues.
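To give a flavor of how predicted probabilities can expose label errors, here is a simplified sketch of the core idea behind Confident Learning. It is not cleanlab's actual implementation, which offers several filtering strategies through the filter_by argument of find_label_issues(). Each class j gets a threshold equal to the average predicted probability of class j among the examples labeled j; an example is flagged when the most probable class that clears its threshold differs from the given label. The function name confident_learning_issues and the toy arrays are hypothetical.

```python
import numpy as np


def confident_learning_issues(labels, pred_probs):
    """Simplified Confident Learning filter (illustration only).

    Assumes labels are integers 0..K-1 and every class appears at least once.
    Returns a boolean mask marking likely label issues.
    """
    labels = np.asarray(labels)
    num_classes = pred_probs.shape[1]
    # Per-class threshold: average self-confidence of examples given that label.
    thresholds = np.array(
        [pred_probs[labels == j, j].mean() for j in range(num_classes)]
    )
    # Classes each example belongs to "confidently" (broadcasts over rows).
    confident = pred_probs >= thresholds
    # Among those classes, pick the one with the highest predicted probability.
    best_confident_class = np.where(confident, pred_probs, -np.inf).argmax(axis=1)
    # Flag examples that confidently belong to a class other than their label.
    return confident.any(axis=1) & (best_confident_class != labels)


labels = np.array([0, 0, 0, 1, 1, 1])
pred_probs = np.array([
    [0.9, 0.1],
    [0.8, 0.2],
    [0.1, 0.9],  # labeled 0, but the model is confident it is class 1
    [0.2, 0.8],
    [0.4, 0.6],  # uncertain, but not confidently another class
    [0.1, 0.9],
])
print(confident_learning_issues(labels, pred_probs))
# [False False  True False False False]
```

Using per-class thresholds rather than a fixed cutoff such as 0.5 accounts for the model being systematically more or less confident about some classes than others.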

Typically we’ll want cleanlab to find label issues in all of our data. However, if we train a classifier on some of this data, its predictions on that same data become untrustworthy due to overfitting: the model may simply reproduce the labels it has memorized, including the wrong ones. To avoid this, we train our classifier with K-fold cross-validation, which gives us out-of-sample predicted probabilities for each and every example in the dataset. Each example’s predictions come from a copy of the classifier that was trained on the other K - 1 folds and never saw that example. The cross_val_predict function used below generates such out-of-sample predicted probabilities from any scikit-learn-compatible model.

from sklearn.model_selection import cross_val_predict

num_crossval_folds = 5
pred_probs = cross_val_predict(model_skorch, X, y,
                               cv=num_crossval_folds,
                               method='predict_proba')
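For intuition, the loop below is a rough sketch of what cross_val_predict(..., method='predict_proba') does for a classifier when cv is an integer: split the data into stratified folds, fit a fresh clone of the model on all but one fold, and predict on the held-out fold. It uses a small scikit-learn LogisticRegression on the Iris dataset purely so it runs in seconds; the helper name out_of_sample_pred_probs is hypothetical.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict


def out_of_sample_pred_probs(model, X, y, n_splits=5):
    """Hypothetical helper mirroring cross_val_predict(..., method="predict_proba").

    Assumes y holds integers 0..K-1 and every training fold contains every class,
    so each fold model's predict_proba columns line up with the class indices.
    """
    pred_probs = np.zeros((len(y), len(np.unique(y))))
    for train_idx, holdout_idx in StratifiedKFold(n_splits=n_splits).split(X, y):
        fold_model = clone(model)  # fresh, unfitted copy for every fold
        fold_model.fit(X[train_idx], y[train_idx])
        # Each held-out example is scored by a model that never saw its label.
        pred_probs[holdout_idx] = fold_model.predict_proba(X[holdout_idx])
    return pred_probs


X_iris, y_iris = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=1000)
manual = out_of_sample_pred_probs(clf, X_iris, y_iris)
builtin = cross_val_predict(clf, X_iris, y_iris, cv=5, method="predict_proba")
print(np.allclose(manual, builtin))  # True
```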

Another benefit of cross-validation is that it facilitates more reliable evaluation of our model than a single training/validation split. Here’s how to estimate the accuracy of the model trained via cross-validation.

from sklearn.metrics import accuracy_score

predicted_labels = pred_probs.argmax(axis=1)
acc = accuracy_score(y, predicted_labels)

print(f"Cross-validated estimate of accuracy on held-out data: {acc}")
Cross-validated estimate of accuracy on held-out data: 0.9846285714285714

Models with higher accuracy tend to find label errors better when used with cleanlab. Thus we should always try to ensure that our model is reasonably performant.

Use cleanlab to find label issues

Based on the given labels and out-of-sample predicted probabilities, cleanlab can identify label issues in our dataset in one line of code.

from cleanlab.filter import find_label_issues

ranked_label_issues = find_label_issues(y, pred_probs, return_indices_ranked_by="self_confidence")

print(f"Cleanlab found {len(ranked_label_issues)} label issues.")
print("Here are the indices of the top 15 most likely label errors:\n"
      f"{ranked_label_issues[:15]}")
Cleanlab found 149 label issues.
Here are the indices of the top 15 most likely label errors:
[59915 24798 26405  8729 30049 34404 32342 43454 53216 15434 33388 10994 34920 39184  2720]

ranked_label_issues is a list of indices corresponding to examples whose labels warrant a closer inspection. Above, we requested these indices to be sorted by cleanlab’s self-confidence label quality score, which measures the quality of each given label via the probability assigned to it in our model’s prediction.
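The self-confidence score itself is simple: for example i with given label y_i, it is pred_probs[i, y_i], and a lower score means a more suspect label. The sketch below computes these scores with NumPy and sorts a set of flagged indices from least to most confident. Note that find_label_issues() first decides which examples to flag and only then ranks them, so this reproduces the ranking step only; recent cleanlab versions also expose such scores via cleanlab.rank.get_label_quality_scores(). The toy arrays and flagged indices here are hypothetical.

```python
import numpy as np


def self_confidence(labels, pred_probs):
    """Probability the model assigns to each example's given label."""
    labels = np.asarray(labels)
    return pred_probs[np.arange(len(labels)), labels]


labels = np.array([0, 1, 2, 1])
pred_probs = np.array([
    [0.70, 0.20, 0.10],
    [0.60, 0.30, 0.10],
    [0.05, 0.05, 0.90],
    [0.80, 0.15, 0.05],
])
scores = self_confidence(labels, pred_probs)
print(scores.tolist())  # [0.7, 0.3, 0.9, 0.15]

# Hypothetical flagged indices, ordered from least to most confident label.
flagged = np.array([0, 1, 3])
ranked = flagged[np.argsort(scores[flagged])]
print(ranked)  # [3 1 0]
```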

To help visualize specific examples, we define a function, plot_examples():
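import matplotlib.pyplot as plt

def plot_examples(id_iter, nrows=1, ncols=1):
    """Plot the images at the given indices in an nrows x ncols grid."""
    for count, idx in enumerate(id_iter):
        plt.subplot(nrows, ncols, count + 1)
        plt.imshow(X[idx].reshape(28, 28), cmap="gray")  # drop the channel axis
        plt.title(f"id: {idx} \n label: {y[idx]}")
        plt.axis("off")

    plt.tight_layout(h_pad=2.0)
    plt.show()  # render this figure so the next call starts a fresh one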

Let’s look at the top 15 examples that cleanlab thinks are most likely to be incorrectly labeled. We can see a few label errors and odd edge cases.

plot_examples(ranked_label_issues[:15], 3, 5)

Let’s zoom into some specific examples from the above set:

Given label is 4 but is actually a 7:

plot_examples([59915])

Given label is 4 but also looks like a 9:

plot_examples([24798])

Given label is 5 but is actually a 3:

plot_examples([43454])

Given label is 3 but is actually a 9:

plot_examples([10994])

Using mislabeled examples like these to train/evaluate ML models may be a questionable idea!

Conclusion

cleanlab has shortlisted the most likely label errors to speed up your data cleaning process. With this list, you can decide whether to manually correct the labels or remove some of these examples from your dataset.
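If you choose to drop the flagged examples rather than relabel them, a boolean mask is one simple way to do it, and correcting a verified label is a plain assignment. A sketch on toy arrays, assuming that in this post's workflow X, y, and ranked_label_issues would take the place of X_toy, y_toy, and issue_indices; the corrected label below is hypothetical.

```python
import numpy as np

# Toy stand-ins for X, y, and the indices returned by find_label_issues().
X_toy = np.arange(12, dtype=np.float32).reshape(6, 2)
y_toy = np.array([0, 1, 1, 0, 1, 0])
issue_indices = np.array([4, 1])  # ranking order does not matter here

# Option 1: remove flagged examples with a boolean mask.
keep = np.ones(len(y_toy), dtype=bool)
keep[issue_indices] = False
X_clean, y_clean = X_toy[keep], y_toy[keep]
print(y_clean)  # [0 1 0 0]

# Option 2: fix a label after manually checking the image (hypothetical correction).
y_fixed = y_toy.copy()
y_fixed[4] = 0
```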

You can see that even widely-used datasets like MNIST contain problematic labels. Never blindly trust your data! You should always check it for potential issues, many of which can be easily identified by cleanlab. While this post studied MNIST image classification with PyTorch neural networks, cleanlab can be easily used for any dataset (image, text, tabular, etc.) and classification model.

While cleanlab helps you automatically find data issues, an interface is needed to efficiently fix these issues in your dataset. Cleanlab Studio finds and fixes errors automatically in a (very cool) no-code platform. Export your corrected dataset in a single click to train better ML models on better data. Try Cleanlab Studio at https://app.cleanlab.ai/.

cleanlab is undergoing active development, and we’re always interested in more open-source contributors!


Additional References