Multi-Dataset Training

Use this page when you want to combine multiple Rasteret-backed TorchGeo datasets in one training workflow.

Each collection can become a standard TorchGeo GeoDataset:

s2 = s2_collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
)

mask = mask_collection.to_torchgeo_dataset(
    bands=["mask"],
    chip_size=256,
    is_image=False,
)

TorchGeo composes datasets with two operators: & builds an intersection and | builds a union.

Intersection With &

Use & when a sample should come from areas where both datasets have coverage. This is common for imagery plus masks, imagery plus embeddings, or imagery plus another aligned raster source.

from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

training = s2 & mask

sampler = RandomGeoSampler(training, size=256, length=100)
loader = DataLoader(
    training,
    sampler=sampler,
    batch_size=4,
    collate_fn=stack_samples,
)

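for batch in loader:
    image = batch["image"]  # chips from the Sentinel-2 dataset
    target = batch["mask"]  # chips from the dataset created with is_image=False
    break  # inspect a single batch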

TorchGeo's IntersectionDataset computes the spatial and temporal overlap of the two datasets, so samples are drawn only where both have coverage. By default, when both datasets return the same key, such as image, TorchGeo concatenates the arrays along the channel dimension. If you want separate keys, create one of the datasets with is_image=False so that its data is returned under sample["mask"] instead of sample["image"].
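To make the overlap and key behavior concrete, here is a minimal pure-Python sketch. It is not TorchGeo's implementation: Extent, overlap, and combine are hypothetical names used only for illustration, and nested lists stand in for tensors, with the outermost list playing the role of the channel dimension.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Extent:
    """Toy space-time extent (hypothetical; not a TorchGeo class)."""
    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float


def overlap(a: Extent, b: Extent) -> Optional[Extent]:
    """Return the extent shared by a and b, or None if they do not overlap."""
    out = Extent(
        max(a.minx, b.minx), min(a.maxx, b.maxx),
        max(a.miny, b.miny), min(a.maxy, b.maxy),
        max(a.mint, b.mint), min(a.maxt, b.maxt),
    )
    if out.minx > out.maxx or out.miny > out.maxy or out.mint > out.maxt:
        return None
    return out


def combine(sample_a: dict, sample_b: dict) -> dict:
    """Join two samples; values under a shared key are concatenated channel-wise."""
    merged = dict(sample_a)
    for key, channels in sample_b.items():
        if key in merged:
            merged[key] = merged[key] + channels  # list concat adds channels
        else:
            merged[key] = channels
    return merged


# Made-up extents: x/y in arbitrary map units, time as fractional years.
s2_extent = Extent(0, 100, 0, 100, 2020.0, 2021.0)
mask_extent = Extent(50, 150, 20, 80, 2020.5, 2022.0)
shared = overlap(s2_extent, mask_extent)

rgb = {"image": [[[1]], [[2]], [[3]]]}  # 3 channels of 1x1 pixels
extra = {"image": [[[4]]]}              # same key: channels get stacked
label = {"mask": [[[0]]]}               # different key: kept separate

stacked = combine(rgb, extra)
separate = combine(rgb, label)
```

In this toy, stacked["image"] ends up with four channels, while separate keeps "image" and "mask" apart, which mirrors why the mask dataset above is created with is_image=False.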

Union With |

Use | when a sample can come from either dataset's coverage area:

s2 = s2_collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
)
landsat = landsat_collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
)

training = s2 | landsat

TorchGeo's UnionDataset concatenates the spatial index and tries each dataset for a requested sample. When multiple datasets can satisfy the same sample, its default collation merges the returned sample dictionaries.
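The "try each dataset" behavior can be sketched in plain Python. This is a toy under stated assumptions, not TorchGeo code: ToyDataset and union_read are hypothetical, and the merge rule here (the first dataset that answers wins for a shared key) is a deliberate simplification. TorchGeo's own merging of overlapping arrays may behave differently.

```python
from typing import Dict, List, Tuple

Box = Tuple[float, float, float, float]  # (minx, maxx, miny, maxy)


def intersects(a: Box, b: Box) -> bool:
    """True when two boxes share any area or edge."""
    return a[0] <= b[1] and b[0] <= a[1] and a[2] <= b[3] and b[2] <= a[3]


class ToyDataset:
    """Hypothetical stand-in for a GeoDataset: one extent, one constant sample."""

    def __init__(self, name: str, bounds: Box, sample: Dict[str, list]):
        self.name = name
        self.bounds = bounds
        self.sample = sample

    def read(self, query: Box) -> Dict[str, list]:
        if not intersects(self.bounds, query):
            raise IndexError(f"{query} is outside {self.name}")
        return dict(self.sample)


def union_read(datasets: List[ToyDataset], query: Box) -> Dict[str, list]:
    """Try every dataset and merge the samples of those that cover the query."""
    merged: Dict[str, list] = {}
    for ds in datasets:
        try:
            sample = ds.read(query)
        except IndexError:
            continue  # no data here; try the next dataset
        for key, value in sample.items():
            merged.setdefault(key, value)  # simplification: first hit wins
    if not merged:
        raise IndexError(f"{query} is outside every dataset")
    return merged


s2_toy = ToyDataset("s2", (0, 100, 0, 100), {"image": ["s2-pixels"]})
landsat_toy = ToyDataset("landsat", (80, 200, 0, 100), {"image": ["landsat-pixels"]})

only_landsat = union_read([s2_toy, landsat_toy], (150, 160, 10, 20))
both = union_read([s2_toy, landsat_toy], (85, 95, 10, 20))
```

The first query falls only inside the Landsat toy extent, so it is answered by that dataset alone. The second falls inside both extents, which is the case where the merge rule matters.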

CRS And Resolution

TorchGeo composition aligns the second dataset to the first dataset's CRS and resolution metadata. If your Rasteret collections span multiple CRS zones (for example, several UTM zones), pass the same target_crs=... when creating each dataset so that all of them share one CRS:

s2 = s2_collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    target_crs=32610,
)

aef = aef_collection.to_torchgeo_dataset(
    bands=["A00", "A01"],
    chip_size=256,
    target_crs=32610,
)
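The value 32610 used above is the EPSG code for WGS 84 / UTM zone 10N. If you need to choose a target_crs for a different area, one possible approach is to derive the UTM code from a representative point. The helper below is a hypothetical utility, not part of Rasteret or TorchGeo, and it ignores the special UTM zones around Norway and Svalbard as well as the polar regions.

```python
import math


def utm_epsg(lon: float, lat: float) -> int:
    """EPSG code of the standard WGS 84 / UTM zone containing (lon, lat)."""
    if not -180.0 <= lon <= 180.0 or not -90.0 <= lat <= 90.0:
        raise ValueError("lon must be in [-180, 180] and lat in [-90, 90]")
    # Zones are 6 degrees wide and numbered 1..60 starting at 180 degrees west.
    zone = min(int(math.floor((lon + 180.0) / 6.0)) + 1, 60)
    # 326xx codes are northern-hemisphere zones, 327xx are southern.
    base = 32600 if lat >= 0 else 32700
    return base + zone


aoi_center = (-122.4, 37.8)  # hypothetical AOI center as (lon, lat)
target_crs = utm_epsg(*aoi_center)
```

The resulting integer can then be passed as target_crs to every dataset that takes part in the composition.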

xarray Path

For analysis workflows, read each collection separately and combine with xarray:

import xarray as xr

ds_s2 = s2_collection.get_xarray(geometries=aoi, bands=["B04", "B08"])
ds_aef = aef_collection.get_xarray(geometries=aoi, bands=["A00"])
combined = xr.merge([ds_s2, ds_aef])

Use target_crs=... on the read calls when the collections are in different CRS zones.