
ML Training with Splits and Labels

Note

Requires the torchgeo extra: pip install "rasteret[torchgeo]".
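
To fail early when the optional dependencies are missing, a minimal sketch can check that the top-level modules rasteret and torchgeo are importable, using only the standard library's importlib.util.find_spec. The helper name missing_modules is hypothetical, and the check does not verify versions.

```python
import importlib.util


def missing_modules(names):
    """Return the names from `names` that cannot be found on the import path."""
    # find_spec returns None for a top-level module that is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


# An empty list means both packages are installed
print(missing_modules(["rasteret", "torchgeo"]))
```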

Many TorchGeo benchmark datasets ship with predefined train/val/test splits baked into the dataset class. In Rasteret, splits and labels are ordinary columns in the Parquet index. You add them yourself, which gives you full control over the partitioning strategy and makes the assignments reproducible and shareable: anyone with the Parquet file gets the same splits.
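
Because a split is just a column value, it can encode any strategy, not only random assignment. As one illustration, here is a minimal sketch of a temporal split that keeps later acquisitions out of training. It assumes you can obtain each scene's acquisition date as a datetime.date; which column holds that date depends on your collection's schema, and the function name and cutoff dates are illustrative, not part of Rasteret.

```python
from datetime import date


def temporal_split(acquired, val_start, test_start):
    """Assign "train", "val", or "test" by acquisition date.

    Scenes before val_start go to train, scenes from val_start up to (but not
    including) test_start go to val, and scenes from test_start on go to test.
    """
    if not val_start < test_start:
        raise ValueError("val_start must be earlier than test_start")
    if acquired < val_start:
        return "train"
    if acquired < test_start:
        return "val"
    return "test"


# Example dates inside the January to March 2024 window used in this guide
dates = [date(2024, 1, 5), date(2024, 2, 20), date(2024, 3, 10), date(2024, 3, 28)]
print([temporal_split(d, date(2024, 3, 1), date(2024, 3, 16)) for d in dates])
```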

1. Build a collection

from pathlib import Path
import rasteret

bbox = (77.55, 13.01, 77.58, 13.08)  # (min_lon, min_lat, max_lon, max_lat) in WGS84 degrees

collection = rasteret.build_from_stac(
    name="bangalore",
    stac_api="https://earth-search.aws.element84.com/v1",
    collection="sentinel-2-l2a",
    bbox=bbox,
    date_range=("2024-01-01", "2024-03-31"),
    workspace_dir=Path.home() / "rasteret_workspace",
)
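
The bbox tuple follows the STAC ordering of minimum longitude, minimum latitude, maximum longitude, maximum latitude. A minimal sketch, using a hypothetical bbox_extent_km helper, that checks this ordering and estimates the box size with an equirectangular approximation on a spherical Earth, which is adequate for small areas like this one:

```python
import math

EARTH_RADIUS_KM = 6371.0088  # mean Earth radius


def bbox_extent_km(bbox):
    """Validate a (min_lon, min_lat, max_lon, max_lat) bbox and return its
    approximate (width_km, height_km) via an equirectangular approximation."""
    min_lon, min_lat, max_lon, max_lat = bbox
    if not (-180 <= min_lon < max_lon <= 180):
        raise ValueError("expected min_lon < max_lon within [-180, 180]")
    if not (-90 <= min_lat < max_lat <= 90):
        raise ValueError("expected min_lat < max_lat within [-90, 90]")
    # Longitude degrees shrink with the cosine of latitude; use the box midpoint
    mid_lat = math.radians((min_lat + max_lat) / 2)
    width = EARTH_RADIUS_KM * math.radians(max_lon - min_lon) * math.cos(mid_lat)
    height = EARTH_RADIUS_KM * math.radians(max_lat - min_lat)
    return width, height


width_km, height_km = bbox_extent_km((77.55, 13.01, 77.58, 13.08))
print(f"~{width_km:.1f} km x {height_km:.1f} km")
```

This sketch rejects boxes that cross the antimeridian (min_lon > max_lon), which STAC itself allows; that case does not arise for this area.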

2. Assign splits

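Rasteret does not assign splits for you, so before you can filter by split the collection needs a split column. The example below assigns each scene at random with a fixed seed: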

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

# Materialize the collection's Parquet index as an Arrow table (one row per scene)
table = collection.dataset.to_table()
n = table.num_rows

# Seeded RNG, so the same assignment is produced on every run
rng = np.random.default_rng(42)
splits = rng.choice(["train", "val", "test"], size=n, p=[0.7, 0.15, 0.15])
table = table.append_column("split", pa.array(splits))

# Optional: add a label column (e.g. land-cover class per scene).
# These random integers are placeholders; substitute your real per-scene labels.
labels = rng.integers(0, 5, size=n)
table = table.append_column("label", pa.array(labels, type=pa.int32()))

# Save the enriched table and reload it as a Collection
pq.write_table(table, "./with_splits.parquet")
collection = rasteret.load("./with_splits.parquet")
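
Sampling with p=[0.7, 0.15, 0.15] produces those proportions only on average. With the small number of scenes a small bbox and a three-month window can return, a split can come out noticeably larger or smaller than intended, or even empty. If you want exact counts, one alternative (a sketch, not a Rasteret feature; exact_split is a hypothetical helper) is to build the list of split names with largest-remainder rounding and shuffle it with a seeded RNG:

```python
import random


def exact_split(n, fractions=(0.7, 0.15, 0.15), names=("train", "val", "test"), seed=42):
    """Return n split names whose counts match `fractions` as closely as integer
    counts allow, in a reproducible shuffled order."""
    if len(fractions) != len(names) or abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must match names and sum to 1")
    raw = [n * f for f in fractions]
    counts = [int(r) for r in raw]
    # Give leftover rows to the splits with the largest fractional remainders
    order = sorted(range(len(raw)), key=lambda i: raw[i] - counts[i], reverse=True)
    for i in order[: n - sum(counts)]:
        counts[i] += 1
    labels = [name for name, count in zip(names, counts) for _ in range(count)]
    random.Random(seed).shuffle(labels)  # seeded, so the order is reproducible
    return labels


print(exact_split(10))
```

The resulting list can be passed to pa.array(...) and appended as the split column in place of the rng.choice call.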

The split column travels with the Parquet file. Reload later with rasteret.load("./with_splits.parquet") and the splits are preserved.
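
The seeded RNG makes the assignment reproducible only for the same rows in the same order: if you rebuild the collection and the STAC search returns one extra scene, later rows can change split. A common alternative is to derive the split from a stable per-scene identifier. A minimal sketch, assuming each row has such an identifier (for example a STAC item id; which column holds it is an assumption about your schema). It uses hashlib because Python's built-in hash() is randomized per process for strings, and split_for_id is a hypothetical helper:

```python
import hashlib


def split_for_id(scene_id, train=0.7, val=0.15):
    """Deterministically map a scene id to "train", "val", or "test".

    The id is hashed with SHA-256 and the first 8 bytes are scaled to [0, 1),
    so the same id always lands in the same split, regardless of row order.
    """
    digest = hashlib.sha256(scene_id.encode("utf-8")).digest()
    u = int.from_bytes(digest[:8], "big") / 2**64
    if u < train:
        return "train"
    if u < train + val:
        return "val"
    return "test"


ids = ["scene-a", "scene-b", "scene-c"]  # hypothetical scene ids
print({i: split_for_id(i) for i in ids})
```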

3. Create TorchGeo datasets per split

Passing split="train" filters the Parquet index on the split column before the dataset is created. Passing label_field="label" includes the label column (added in step 2) in each sample as sample["label"].

train_dataset = collection.to_torchgeo_dataset(
    bands=["B02", "B03", "B04", "B08"],
    geometries=bbox,
    split="train",
    label_field="label",
    chip_size=256,
)

val_dataset = collection.to_torchgeo_dataset(
    bands=["B02", "B03", "B04", "B08"],
    geometries=bbox,
    split="val",
    label_field="label",  # validation samples need labels too
    chip_size=256,
)

See the to_torchgeo_dataset() API reference for the full list of parameters.
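
Because each scene row carries exactly one split value, the train, val, and test datasets draw from disjoint sets of scenes. That does not rule out spatial overlap: all three datasets here are restricted to the same bbox, so train and val chips can show the same ground on different dates. A quick sanity check on the assignments, as a standard-library sketch over (scene_id, split) pairs (how you extract those pairs from the table is up to you, and split_summary is a hypothetical helper):

```python
from collections import Counter, defaultdict


def split_summary(pairs):
    """Count scenes per split and list ids assigned to more than one split.

    `pairs` is an iterable of (scene_id, split) tuples.
    """
    splits_by_id = defaultdict(set)
    for scene_id, split in pairs:
        splits_by_id[scene_id].add(split)
    conflicts = sorted(i for i, s in splits_by_id.items() if len(s) > 1)
    counts = Counter(next(iter(s)) for s in splits_by_id.values() if len(s) == 1)
    return dict(counts), conflicts


pairs = [("s1", "train"), ("s2", "train"), ("s3", "val"), ("s4", "test")]
counts, conflicts = split_summary(pairs)
print(counts, conflicts)
```

A split that received no scenes simply does not appear in counts, which is worth checking before training.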

4. Train

Everything below is standard TorchGeo:

from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

# Draw 32 random 256-pixel chips per epoch from the training split
sampler = RandomGeoSampler(train_dataset, size=256, length=32)
loader = DataLoader(
    train_dataset,
    sampler=sampler,
    batch_size=4,
    num_workers=0,  # load in the main process
    collate_fn=stack_samples,  # merge per-sample dicts into one batch dict
)

# Inspect a single batch
for batch in loader:
    print(f"image: {batch['image'].shape}")
    if "label" in batch:
        print(f"label: {batch['label']}")
    break

The full runnable script is at examples/ml_training_with_splits.py.
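
With length=32 and batch_size=4, the sampler yields 32 chips per epoch, which the DataLoader (drop_last defaults to False) groups into 8 batches. The collate step, stack_samples, roughly transposes a list of per-sample dicts into one dict of lists and then stacks the tensor values. A torch-free sketch of the batch count and of that transpose step (both helper names are hypothetical, and the real function additionally calls torch.stack on tensors):

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a sampler of known length."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def list_dict_to_dict_list(samples):
    """Transpose [{"image": a, "label": 1}, ...] into {"image": [a, ...], "label": [1, ...]}."""
    batch = {}
    for sample in samples:
        for key, value in sample.items():
            batch.setdefault(key, []).append(value)
    return batch


print(num_batches(32, 4))  # → 8
batch = list_dict_to_dict_list([{"image": "chip0", "label": 2}, {"image": "chip1", "label": 0}])
print(batch)
```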