Splitting Strategies Gallery


Choosing the right cross-validation strategy is a crucial part of benchmarking a model properly. There are many ways to split data into training, validation, and test sets, for example to estimate how well a model generalizes and detect overfitting, or to keep the number of groups in each test set consistent.

This example visualizes the behavior of several common splitting strategies for comparison.
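For orientation, the simplest three-way split can be written by hand: shuffle the sample indices and cut them into training, validation, and test blocks. This is a minimal NumPy sketch, not MolFlux code; the helper name `shuffle_train_val_test` and its 80/10/10 default fractions are illustrative assumptions.

```python
import numpy as np


def shuffle_train_val_test(n_samples, train_fraction=0.8, validation_fraction=0.1, seed=0):
    """Randomly partition sample indices into train/validation/test index arrays.

    Hypothetical helper for illustration only; it is not part of the MolFlux API.
    """
    rng = np.random.default_rng(seed)
    permuted = rng.permutation(n_samples)  # every index exactly once, in random order
    n_train = int(round(train_fraction * n_samples))
    n_validation = int(round(validation_fraction * n_samples))
    train = permuted[:n_train]
    validation = permuted[n_train:n_train + n_validation]
    test = permuted[n_train + n_validation:]  # whatever is left over
    return train, validation, test


train, validation, test = shuffle_train_val_test(100)
print(len(train), len(validation), len(test))  # 80 10 10
```

A purely random split like this ignores both class labels and groups, which is exactly the behavior the strategies below either keep or deliberately change.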

See also

The original scikit-learn example this page is based on, “Visualizing cross-validation behavior in scikit-learn”.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

from molflux.splits import load_splitting_strategy

rng = np.random.RandomState(1338)
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
n_splits = 4
figsize = (8, 4)

Visualise our data

First, we must understand the structure of our data. It has 100 randomly generated datapoints with 10 features each, 3 classes split unevenly across datapoints (10%, 30%, and 60% of the samples), and 10 “groups” that are also split unevenly across datapoints.

As we’ll see, some splitting strategies make use of the class labels, others make use of the group labels, and others ignore both.

To begin, we’ll visualize our data:

# Generate the class/group data
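n_points = 100
X = rng.randn(n_points, 10)  # n_points samples x 10 random features

# Uneven classes: 10%, 30% and 60% of the samples, in contiguous blocks
percentiles_classes = [0.1, 0.3, 0.6]
y = np.hstack([[ii] * int(n_points * perc) for ii, perc in enumerate(percentiles_classes)])

# Generate uneven groups: Dirichlet-distributed group sizes that sum to n_points
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(n_points, group_prior))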


def visualize_groups(classes, groups, name):
    # Visualize dataset groups
    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(
        range(len(groups)),
        [0.5] * len(groups),
        c=groups,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.scatter(
        range(len(groups)),
        [3.5] * len(groups),
        c=classes,
        marker="_",
        lw=50,
        cmap=cmap_data,
    )
    ax.set(
        ylim=[-1, 5],
        yticks=[0.5, 3.5],
        yticklabels=["Data\ngroup", "Data\nclass"],
        xlabel="Sample index",
    )


visualize_groups(y, groups, "no groups")
[Figure: group and class label of each of the 100 samples]
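To check the structure described above numerically rather than visually, the sketch below regenerates the same arrays with the same seed and draw order, then counts samples per class and per group. It repeats the generation code from the cell above so that it runs on its own; only the counting lines are new.

```python
import numpy as np

# Same seed and draw order as the generation cell on this page
rng = np.random.RandomState(1338)
n_points = 100
X = rng.randn(n_points, 10)
y = np.hstack([[ii] * int(n_points * perc) for ii, perc in enumerate([0.1, 0.3, 0.6])])
group_prior = rng.dirichlet([2] * 10)
groups = np.repeat(np.arange(10), rng.multinomial(n_points, group_prior))

class_counts = np.bincount(y)                      # samples per class
group_counts = np.bincount(groups, minlength=10)   # samples per group (uneven sizes)
print("class counts:", class_counts)
print("group counts:", group_counts)
```

The class counts are 10, 30 and 60 by construction, while the group counts vary with the Dirichlet draw; both labels are stored in sorted, contiguous blocks, which is why they appear as solid bands in the plot.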

Define a function to visualize splitting behavior

We’ll define a function that lets us visualize the behavior of each splitting strategy. We’ll perform 4 splits of the data. On each split, we’ll visualize the indices chosen for the training set (in blue), the validation set (in grey), and the test set (in red).
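The plotting function relies on a simple encoding: every sample gets 0 (train), 1 (validation) or 2 (test), and any sample a strategy does not assign stays NaN, so it is not drawn in any of the three colors. With `vmin=-0.2` and `vmax=2.2`, those three values land at fixed positions on the `coolwarm` colormap. The toy indices in this sketch are made up for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# Hypothetical toy split of 10 samples; sample 7 is deliberately left unassigned
train_idx = np.array([0, 1, 2, 3, 4, 5])
validation_idx = np.array([6])
test_idx = np.array([8, 9])

indices = np.full(10, np.nan)  # NaN = not used by this split
indices[train_idx] = 0
indices[validation_idx] = 1
indices[test_idx] = 2

# Same scaling as the scatter calls in the plotting function (vmin=-0.2, vmax=2.2)
norm = Normalize(vmin=-0.2, vmax=2.2)
positions = np.asarray(norm(np.array([0.0, 1.0, 2.0])))  # colormap position of train/validation/test
colours = plt.cm.coolwarm(positions)  # one RGBA row per role: bluish, greyish, reddish
print(indices)
print(positions.round(3))
```

These positions are where the training, validation and test bars fall on the colormap, so legend patches drawn at the same positions match the bars exactly.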

def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""

    # Generate the train/validation/test visualization for each CV split
    for ii, (tr, tv, tt) in enumerate(cv.split(dataset=X, y=y, groups=group, n_splits=n_splits)):
        # Encode each sample's role: 0 = train, 1 = validation, 2 = test (unassigned stays NaN)
        indices = np.array([np.nan] * len(X))
        indices[tr] = 0
        indices[tv] = 1
        indices[tt] = 2

        # Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=2.2,
        )

    # Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )

    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    # Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
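        xlim=[0, len(X)],
    )
    ax.set_title("{}".format(cv.tag), fontsize=15)

    # Legend patches use the same colormap positions as the plotted values 0, 1, 2
    # after normalisation with vmin=-0.2, vmax=2.2
    ax.legend(
        [Patch(color=cmap_cv(0.2 / 2.4)), Patch(color=cmap_cv(0.5)), Patch(color=cmap_cv(2.2 / 2.4))],
        ["Training set", "Validation set", "Testing set"],
        loc=(1.02, 0.8),
    )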

    return ax

Let’s see how it looks for the k_fold cross-validation object:

fig, ax = plt.subplots(figsize=figsize)
strategy = load_splitting_strategy("k_fold")
plot_cv_indices(strategy, X, y, groups, ax, n_splits)
/home/runner/work/molflux/molflux/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/sklearn/model_selection/_split.py:91: UserWarning: The groups parameter is ignored by KFold
  warnings.warn(
<Axes: title={'center': 'k_fold'}, xlabel='Sample index', ylabel='CV iteration'>
[Figure: k_fold split indices per CV iteration, with the class and group rows]

As you can see, by default the k_fold cross-validation iterator does not take either datapoint class or group into consideration. We can change this by using either:

  • stratified_k_fold to preserve the percentage of samples for each class.

  • group_k_fold to ensure that the same group will not appear in two different folds.
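To make "preserve the percentage of samples for each class" concrete, the sketch below assigns fold ids class by class, dealing the shuffled members of each class round-robin across the folds. It is a simplified, hypothetical stand-in written in plain NumPy, not the algorithm MolFlux or scikit-learn actually uses, and it reuses the 10/30/60 class sizes of the data on this page.

```python
import numpy as np


def stratified_fold_ids(y, n_splits, seed=0):
    """Assign each sample a fold id so that every class is spread evenly across folds.

    Simplified, hypothetical sketch: the members of each class are shuffled and
    then dealt round-robin to folds 0, 1, ..., n_splits - 1.
    """
    rng = np.random.default_rng(seed)
    fold_ids = np.empty(len(y), dtype=int)
    for label in np.unique(y):
        members = rng.permutation(np.flatnonzero(y == label))
        fold_ids[members] = np.arange(len(members)) % n_splits
    return fold_ids


y = np.repeat([0, 1, 2], [10, 30, 60])  # same class sizes as the data above
fold_ids = stratified_fold_ids(y, n_splits=4)
for fold in range(4):
    test = np.flatnonzero(fold_ids == fold)  # samples held out in this fold
    print(fold, np.bincount(y[test], minlength=3))
```

Each test fold ends up with 2 or 3 samples of class 0, 7 or 8 of class 1, and 15 of class 2, roughly the 10/30/60 proportions of the full dataset. A plain k-fold split over this sorted `y` would instead put most of class 0 into a single fold.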

strategies = ["stratified_k_fold", "group_k_fold"]

for name in strategies:
    fig, ax = plt.subplots(figsize=figsize)
    strategy = load_splitting_strategy(name)
    plot_cv_indices(strategy, X, y, groups, ax, n_splits)
/home/runner/work/molflux/molflux/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/sklearn/model_selection/_split.py:848: UserWarning: The groups parameter is ignored by StratifiedKFold
  warnings.warn(
[Figure: stratified_k_fold split indices per CV iteration]
[Figure: group_k_fold split indices per CV iteration]
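The group guarantee can be sketched in the same style: whole groups, never individual samples, are assigned to folds. This hypothetical helper places the largest groups first, each into the fold that currently holds the fewest samples, which is one common way to keep folds balanced. The group sizes are invented for the example and are not the ones drawn on this page.

```python
import numpy as np


def group_fold_ids(groups, n_splits):
    """Assign whole groups to folds, balancing fold sizes greedily.

    Hypothetical sketch: largest groups are placed first, each into the
    fold that currently contains the fewest samples.
    """
    unique_groups, sizes = np.unique(groups, return_counts=True)
    fold_sizes = np.zeros(n_splits, dtype=int)
    group_to_fold = {}
    for g_index in np.argsort(sizes)[::-1]:  # largest group first
        fold = int(np.argmin(fold_sizes))  # currently smallest fold
        group_to_fold[unique_groups[g_index]] = fold
        fold_sizes[fold] += sizes[g_index]
    return np.array([group_to_fold[g] for g in groups])


# Invented, uneven group sizes that sum to 100 samples
groups = np.repeat(np.arange(10), [5, 15, 8, 12, 10, 9, 11, 7, 13, 10])
fold_ids = group_fold_ids(groups, n_splits=4)
for fold in range(4):
    test_groups = set(groups[fold_ids == fold].tolist())
    train_groups = set(groups[fold_ids != fold].tolist())
    assert not (test_groups & train_groups)  # no group on both sides of the split
print(np.bincount(fold_ids, minlength=4))  # samples per test fold
```

Because groups are indivisible, the test folds cannot all have exactly 25 samples; the greedy placement only keeps them roughly equal.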

Next we’ll visualize this behavior for a number of splitting iterators.

Visualize splitting behaviour for many splitting strategies

Let’s visually compare the splitting and cross-validation behavior of many of our splitting strategies. Below, we loop through several common strategies and visualize the behavior of each.

Note how some use the group/class information while others do not:

strategies = ["group_k_fold", "group_shuffle_split", "k_fold", "linear_split", "shuffle_split", "stratified_k_fold", "stratified_shuffle_split", "time_series_split"]

for name in strategies:
    fig, ax = plt.subplots(figsize=figsize)
    strategy = load_splitting_strategy(name)
    plot_cv_indices(strategy, X, y, groups, ax, n_splits)
/home/runner/work/molflux/molflux/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/sklearn/model_selection/_split.py:91: UserWarning: The groups parameter is ignored by KFold
  warnings.warn(
/home/runner/work/molflux/molflux/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/sklearn/model_selection/_split.py:848: UserWarning: The groups parameter is ignored by StratifiedKFold
  warnings.warn(
/home/runner/work/molflux/molflux/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/sklearn/model_selection/_split.py:1213: UserWarning: The groups parameter is ignored by TimeSeriesSplit
  warnings.warn(
[Figure: group_k_fold split indices per CV iteration]
[Figure: group_shuffle_split split indices per CV iteration]
[Figure: k_fold split indices per CV iteration]
[Figure: linear_split split indices per CV iteration]
[Figure: shuffle_split split indices per CV iteration]
[Figure: stratified_k_fold split indices per CV iteration]
[Figure: stratified_shuffle_split split indices per CV iteration]
[Figure: time_series_split split indices per CV iteration]
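The printed warning shows that `time_series_split` hands the data to scikit-learn's `TimeSeriesSplit`. Assuming its default settings (no gap, no maximum training size), the split can be sketched as an expanding window in which every test block comes after all of its training indices. The generator name below is hypothetical, and this two-way sketch omits the validation set that MolFlux strategies also yield.

```python
import numpy as np


def expanding_window_splits(n_samples, n_splits):
    """Yield (train, test) index arrays for an expanding-window time-series split.

    Hypothetical sketch modelled on scikit-learn's TimeSeriesSplit defaults:
    equal-sized test blocks at the end of the data, and a training set made of
    every index before the current test block.
    """
    indices = np.arange(n_samples)
    test_size = n_samples // (n_splits + 1)
    first_test_start = n_samples - n_splits * test_size
    for test_start in range(first_test_start, n_samples, test_size):
        yield indices[:test_start], indices[test_start:test_start + test_size]


for train, test in expanding_window_splits(100, n_splits=4):
    print(len(train), test.min(), test.max())
```

With 100 samples and `n_splits=4` the test blocks are samples 20 to 39, 40 to 59, 60 to 79 and 80 to 99, with training sets of 20, 40, 60 and 80 samples. Class and group labels play no part, which matches the warning that the groups parameter is ignored.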


By Exscientia