"""Metadata and data loading model classes."""
from __future__ import annotations
import datetime as dt
import math
from copy import copy
from dataclasses import astuple, dataclass, field, replace
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import numpy as np
from odc.geo import CRS, Geometry, MaybeCRS
from odc.geo.geobox import GeoBox
from odc.geo.types import Unset
from odc.loader.types import (
AuxBandMetadata,
AuxDataSource,
BandIdentifier,
BandKey,
BandQuery,
FixedCoord,
RasterBandMetadata,
RasterGroupMetadata,
RasterSource,
norm_band_metadata,
norm_key,
)
from typing_extensions import override
@dataclass(eq=True, frozen=True)
class ParsedItem(Mapping[BandIdentifier, RasterSource | AuxDataSource]):
    """
    Captures the essential parts of a STAC Item needed for data loading.

    Only includes bands of interest. Acts as a read-only mapping from band
    identifier to its data source; string identifiers are resolved to band
    keys via ``collection.band_key``.
    """

    # pylint: disable=too-many-instance-attributes

    id: str
    """Item id copied from STAC."""

    collection: RasterCollectionMetadata
    """Collection this Item is part of."""

    bands: Mapping[BandKey, RasterSource | AuxDataSource]
    """Raster bands."""

    geometry: Optional[Geometry] = None
    """Footprint of the dataset."""

    datetime: Optional[dt.datetime] = None
    """Nominal timestamp."""

    datetime_range: Tuple[Optional[dt.datetime], Optional[dt.datetime]] = (None, None)
    """Time period covered."""

    href: Optional[str] = None
    """Self link from stac item."""

    accessories: dict[str, Any] = field(default_factory=dict)
    """Additional assets"""

    def geoboxes(self, bands: BandQuery = None) -> Tuple[GeoBox, ...]:
        """
        Unique ``GeoBox`` s, highest resolution first.

        Only bands backed by a ``RasterSource`` with a non-null geobox
        contribute to the result.

        :param bands: which bands to consider, default is all
        """
        names = self.collection.normalize_band_query(bands)
        band_key = self.collection.band_key

        # TODO: support other geobox types?
        unique: Set[GeoBox] = set()
        for name in names:
            src = self.bands.get(band_key(name), None)
            if not isinstance(src, RasterSource):
                continue
            if src.geobox is None:
                continue
            assert isinstance(src.geobox, GeoBox)
            unique.add(src.geobox)

        # Smallest absolute pixel size sorts first.
        return tuple(
            sorted(unique, key=lambda gbox: min(gbox.resolution.map(abs).xy))
        )

    def crs(self, bands: BandQuery = None) -> Optional[CRS]:
        """
        First non-null CRS across the geoboxes of the requested bands.

        :param bands: which bands to consider, default is all
        :returns: ``None`` when no geobox carries a CRS
        """
        return next(
            (gbox.crs for gbox in self.geoboxes(bands) if gbox.crs is not None),
            None,
        )

    def image_geometry(
        self,
        crs: MaybeCRS = Unset(),
        bands: BandQuery = None,
    ) -> Optional[Geometry]:
        """
        Extract footprint of a given band(s) from proj metadata in a given projection.

        The first geobox with a non-null CRS decides the result: its extent
        is returned when no projection was requested or when the requested
        projection equals the geobox CRS, otherwise its footprint in the
        requested projection.

        :param crs: desired projection, unset or ``None`` means native
        :param bands: which bands to consider, default is all
        :returns: ``None`` when no geobox with a CRS is available
        """
        target = None if isinstance(crs, Unset) else crs

        for gbox in self.geoboxes(bands):
            if gbox.crs is None:
                continue
            if target is None or target == gbox.crs:
                return gbox.extent
            return gbox.footprint(target)

        return None

    def safe_geometry(
        self,
        crs: MaybeCRS = Unset(),
        bands: BandQuery = None,
    ) -> Optional[Geometry]:
        """
        Get item geometry footprint in desired projection or native.

        1. Use full-image footprint if proj data is available
        2. Fallback to item geometry if not

        :param crs: desired projection, unset or ``None`` means native
        :param bands: which bands to consider, default is all
        :returns: ``None`` when neither source of geometry is available
        """
        footprint = self.image_geometry(crs, bands=bands)
        if footprint is not None:
            return footprint

        geom = self.geometry
        if geom is None:
            return None
        if crs is None or isinstance(crs, Unset):
            return geom

        # Minimum number of points along the perimeter we desire; the
        # perimeter is approximated by that of a square with the same area.
        n_points = 100
        min_sample_distance = math.sqrt(geom.area) * 4 / n_points
        reprojected = geom.to_crs(
            crs,
            min_sample_distance,
            check_and_fix=True,
        )
        return reprojected.dropna()

    def resolve_bands(
        self, bands: BandQuery = None
    ) -> dict[str, RasterSource | AuxDataSource | None]:
        """
        Query bands taking care of aliases.

        :param bands: which bands to look up, default is all
        :returns: mapping from requested name to its source, ``None`` for
                  names that have no matching band in this item
        """
        names = self.collection.normalize_band_query(bands)
        band_key = self.collection.band_key

        resolved: dict[str, RasterSource | AuxDataSource | None] = {}
        for name in names:
            resolved[name] = self.bands.get(band_key(name), None)
        return resolved

    @override
    def __getitem__(self, band: BandIdentifier) -> RasterSource | AuxDataSource:
        """
        Query band taking care of aliases.

        :raises: :py:class:`KeyError`
        """
        key = self.collection.band_key(band) if isinstance(band, str) else band
        return self.bands[key]

    @override
    def __len__(self) -> int:
        return len(self.bands)

    @override
    def __iter__(self) -> Iterator[BandKey]:
        for key in self.bands:
            yield key

    @override
    def __contains__(self, k: object) -> bool:
        if isinstance(k, tuple):
            return k in self.bands
        if not isinstance(k, str):
            return False
        # Names the collection can not resolve are simply "not present".
        try:
            return self.collection.band_key(k) in self.bands
        except ValueError:
            return False

    @property
    def nominal_datetime(self) -> dt.datetime:
        """
        Resolve timestamp to a single value.

        - datetime if set
        - start_datetime if set
        - end_datetime if set
        - ``raise ValueError`` otherwise
        """
        candidates = (self.datetime, *self.datetime_range)
        found = next((ts for ts in candidates if ts is not None), None)
        if found is None:
            raise ValueError("Timestamp was not populated.")
        return found

    @property
    def mid_longitude(self) -> Optional[float]:
        """
        Return longitude of the center point.

        Used for "solar day" computation. ``None`` when the item has no
        geometry.
        """
        if self.geometry is None:
            return None

        center = self.geometry.centroid.to_crs("epsg:4326")
        # Exactly one point is expected here, anything else fails to unpack.
        ((lon, _),) = center.points
        return lon

    @property
    def solar_date(self) -> dt.datetime:
        """
        Nominal datetime adjusted by longitude.

        Falls back to the unadjusted nominal datetime when the item has no
        geometry.
        """
        lon = self.mid_longitude
        timestamp = self.nominal_datetime
        return timestamp if lon is None else _convert_to_solar_time(timestamp, lon)

    def solar_date_at(self, lon: float) -> dt.datetime:
        """
        Nominal datetime adjusted by the supplied longitude.
        """
        return _convert_to_solar_time(self.nominal_datetime, lon)

    def strip(self) -> "ParsedItem":
        """
        Copy of self but with stripped bands and no accessories.
        """
        stripped = {}
        for key, src in self.bands.items():
            stripped[key] = src.strip()
        return replace(self, bands=stripped, accessories={})

    def assets(self) -> Dict[str, List[RasterSource]]:
        """
        Extract bands grouped by asset they belong to.

        Only ``RasterSource`` bands are included; within each asset the
        sources are ordered by band index.
        """
        grouped: Dict[str, List[Tuple[int, RasterSource]]] = {}
        for (asset, idx), src in self.bands.items():
            if not isinstance(src, RasterSource):
                continue
            if asset not in grouped:
                grouped[asset] = []
            grouped[asset].append((idx, src))

        by_asset: Dict[str, List[RasterSource]] = {}
        for asset, entries in grouped.items():
            entries.sort(key=lambda entry: entry[0])
            by_asset[asset] = [src for _, src in entries]
        return by_asset

    @override
    def __hash__(self) -> int:
        # Identity of an item is its id within the collection.
        return hash((self.id, self.collection.name))

    def __dask_tokenize__(self):
        """Fields that take part in the Dask token of this item."""
        return (
            self.id,
            self.collection,
            self.bands,
            self.href,
            self.datetime,
            self.datetime_range,
        )
def _default_props_fuser(xx: Sequence[Any]) -> Any:
    """
    Reduce a sequence of property values to a single value.

    - empty input gives ``None``
    - a single value is returned unchanged
    - when the first value is a string, all values are converted with
      ``str`` and joined with ``","``
    - otherwise the mean of the finite ``int``/``float`` values is
      returned; a lone finite value is returned unchanged and ``None``
      is returned when there are no finite values
    """
    count = len(xx)
    if count <= 1:
        return xx[0] if count else None

    if isinstance(xx[0], str):
        return ",".join(map(str, xx))

    finite = [
        value
        for value in xx
        if isinstance(value, (int, float)) and math.isfinite(value)
    ]
    kept = len(finite)
    if kept == 0:
        return None
    if kept == 1:
        # Keep the original type instead of promoting to float.
        return finite[0]
    return sum(finite) / kept
@dataclass(frozen=True)
class PropertyLoadRequest:
    """
    Request to load a property from STAC item as xarray DataArray.

    Attributes:
        key: The key of the property to load from STAC item
        name: Name to use for output DataArray, if None a name is derived
            from the property key (see :py:attr:`output_name`)
        dtype: Data type to use for loaded data, defaults to float32
        nodata: Explicit fill value; when None the fill value is derived
            from ``dtype`` (see :py:attr:`fill_value`)
        units: Units label, defaults to ``"1"``
        fuser: Callable that reduces a sequence of values to a single
            value, defaults to ``_default_props_fuser``
    """

    key: str
    name: str | None = None
    dtype: str = "float32"
    nodata: float | None = None
    units: str = "1"
    fuser: Callable[[Sequence[Any]], Any] = _default_props_fuser

    @staticmethod
    def from_user_input(
        inputs: Sequence[str | Mapping[str, Any]],
    ) -> list["PropertyLoadRequest"]:
        """
        Create a list of PropertyLoadRequest objects from user input.

        Args:
            inputs: Sequence of either strings (property keys) or mappings with configuration.
                   Mappings must have 'key' defined, every other entry is passed on
                   to the constructor as a keyword argument.

        Returns:
            List of PropertyLoadRequest objects

        Raises:
            ValueError: If a mapping input is missing the required 'key' field,
                or an input is neither a string nor a mapping
            TypeError: If a mapping contains an entry that is not a field of
                PropertyLoadRequest (raised by the constructor)
        """

        def _norm(what: str | Mapping[str, Any]) -> "PropertyLoadRequest":
            if isinstance(what, str):
                return PropertyLoadRequest(key=what)
            # Accept any Mapping, as the signature promises, not only ``dict``.
            if isinstance(what, Mapping):
                if "key" not in what:
                    raise ValueError("Dictionary input must contain 'key' field")
                return PropertyLoadRequest(**what)
            raise ValueError(f"Input must be string or dict, got {type(what)}")

        return [_norm(what) for what in inputs]

    @property
    def output_name(self) -> str:
        """
        Name of the output: ``name`` when set, otherwise ``key`` with
        ``.``, ``:`` and ``-`` replaced by ``_``.
        """
        if self.name is not None:
            return self.name
        return self.key.replace(".", "_").replace(":", "_").replace("-", "_")

    @property
    def fill_value(self) -> Any:
        """
        Fill value as a scalar of the requested ``dtype``.

        ``nodata`` when set, otherwise NaN for floating point dtypes and 0
        for everything else.
        """
        dtype = np.dtype(self.dtype)
        if self.nodata is not None:
            return dtype.type(self.nodata)
        if dtype.kind == "f":
            return dtype.type(float("nan"))
        return dtype.type(0)
@dataclass(frozen=True)
class MDParseConfig:
    """Item parsing config."""

    band_defaults: RasterBandMetadata = field(
        default_factory=lambda: norm_band_metadata({})
    )
    band_cfg: Dict[str, RasterBandMetadata] = field(default_factory=dict)
    aliases: Dict[str, BandKey] = field(default_factory=dict)
    ignore_proj: bool = False
    extra_dims: Dict[str, int] = field(default_factory=dict)
    extra_coords: Sequence[FixedCoord] = ()
    with_props: Sequence[PropertyLoadRequest] = field(default_factory=list)

    @staticmethod
    def from_dict(
        cfg: Dict[str, Any], collection_id: str | None = None
    ) -> "MDParseConfig":
        """
        Build parsing config from a plain dictionary.

        Without ``collection_id``, or when ``cfg`` has an ``"assets"`` entry,
        ``cfg`` is used as the config of a single collection. Otherwise
        the ``"*"`` entry of ``cfg`` is taken as a base and updated with the
        entry stored under ``collection_id``.

        Entries consulted: ``assets``, ``aliases``, ``ignore_proj``,
        ``dims``, ``coords`` and ``with_properties``.

        :param cfg: configuration dictionary
        :param collection_id: collection to select the configuration for
        """
        if collection_id is None or "assets" in cfg:
            # Whole dictionary describes a single collection
            settings = copy(cfg)
        else:
            settings = copy(cfg.get("*", {}))
            settings.update(cfg.get(collection_id, {}))

        band_defaults, band_cfg = _norm_band_cfg(settings.get("assets", {}))

        # An alias given as a plain asset name points at band 1 of that asset
        aliases = {}
        for alias, band in settings.get("aliases", {}).items():
            aliases[alias] = (band, 1) if isinstance(band, str) else band

        ignore_proj: bool = settings.get("ignore_proj", False)
        extra_dims: Dict[str, int] = settings.get("dims", {})

        coords_cfg: dict[str, list[Any]] = settings.get("coords", {})
        assert isinstance(coords_cfg, dict)

        props_cfg = settings.get("with_properties", [])
        assert isinstance(props_cfg, list)

        coords: list[FixedCoord] = []
        for name, values in coords_cfg.items():
            assert isinstance(values, list)
            coords.append(FixedCoord(name, values))

        return MDParseConfig(
            band_defaults=band_defaults,
            band_cfg=band_cfg,
            aliases=aliases,
            ignore_proj=ignore_proj,
            extra_dims=extra_dims,
            extra_coords=tuple(coords),
            with_props=PropertyLoadRequest.from_user_input(props_cfg),
        )
def _norm_band_cfg(
    cfg: Dict[str, Any],
) -> Tuple[RasterBandMetadata, Dict[str, RasterBandMetadata]]:
    """
    Normalize per-band configuration.

    The ``"*"`` entry of ``cfg`` (empty when missing) is normalized first
    and then passed as the fallback when normalizing every other entry.

    :returns: ``(fallback, {name: metadata})``, the ``"*"`` entry is not
              part of the dictionary
    """
    defaults = norm_band_metadata(cfg.get("*", {}))

    per_band: Dict[str, RasterBandMetadata] = {}
    for name, band in cfg.items():
        if name == "*":
            continue
        per_band[name] = norm_band_metadata(band, defaults)

    return defaults, per_band
def _convert_to_solar_time(utc: dt.datetime, longitude: float) -> dt.datetime:
    """
    Shift ``utc`` by a whole number of hours derived from ``longitude``.

    There are 24/360 == 1/15 hours per degree of longitude; the offset is
    ``longitude / 15`` truncated toward zero, so it always comes in 1 hour
    increments.
    """
    whole_hours = int(longitude / 15)
    return utc + dt.timedelta(hours=whole_hours)