"""Metadata and data loading model classes."""
from __future__ import annotations
import datetime as dt
import math
from copy import copy
from dataclasses import astuple, dataclass, field, replace
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple
from odc.geo import CRS, Geometry, MaybeCRS
from odc.geo.geobox import GeoBox
from odc.geo.types import Unset
from odc.loader.types import (
BandIdentifier,
BandKey,
BandQuery,
FixedCoord,
RasterBandMetadata,
RasterGroupMetadata,
RasterSource,
norm_band_metadata,
norm_key,
)
@dataclass(eq=True, frozen=True)
class ParsedItem(Mapping[BandIdentifier, RasterSource]):
    """
    Captures essential parts for data loading from a STAC Item.

    Only includes raster bands of interest.

    Acts as a read-only mapping from band identifier to ``RasterSource``.
    String identifiers are resolved via ``collection.band_key`` before lookup,
    tuple keys are looked up in :py:attr:`bands` directly.

    Equality compares all fields (dataclass default), hashing only uses the
    item id and collection name.
    """

    # pylint: disable=too-many-instance-attributes

    id: str
    """Item id copied from STAC."""

    collection: RasterCollectionMetadata
    """Collection this Item is part of."""

    bands: Dict[BandKey, RasterSource]
    """Raster bands."""

    geometry: Optional[Geometry] = None
    """Footprint of the dataset."""

    datetime: Optional[dt.datetime] = None
    """Nominal timestamp."""

    datetime_range: Tuple[Optional[dt.datetime], Optional[dt.datetime]] = (None, None)
    """Time period covered."""

    href: Optional[str] = None
    """Self link from stac item."""

    accessories: dict[str, Any] = field(default_factory=dict)
    """Additional assets"""

    def geoboxes(self, bands: BandQuery = None) -> Tuple[GeoBox, ...]:
        """
        Unique ``GeoBox`` s, highest resolution first.

        :param bands: which bands to consider, default is all
        :return:
           De-duplicated geoboxes of the selected bands, ordered by their
           smallest absolute pixel size (ascending). Bands that are missing
           from this item or have no geobox are skipped.
        """
        bands = self.collection.normalize_band_query(bands)

        def _resolution(g: GeoBox) -> float:
            # smaller of the two absolute pixel sizes is the sort key
            return min(g.resolution.map(abs).xy)  # type: ignore

        # TODO: support other geobox types?
        gbx: Set[GeoBox] = set()
        for name in bands:
            b = self.bands.get(self.collection.band_key(name), None)
            if b is not None and b.geobox is not None:
                assert isinstance(b.geobox, GeoBox)
                gbx.add(b.geobox)

        return tuple(sorted(gbx, key=_resolution))

    def crs(self, bands: BandQuery = None) -> Optional[CRS]:
        """
        First non-null CRS across assets.

        :param bands: which bands to consider, default is all
        :return: ``None`` when no selected band has a geobox with a CRS.
        """
        for gbox in self.geoboxes(bands):
            if gbox.crs is not None:
                return gbox.crs

        return None

    def image_geometry(
        self,
        crs: MaybeCRS = Unset(),
        bands: BandQuery = None,
    ) -> Optional[Geometry]:
        """
        Extract footprint of a given band(s) from proj metadata in a given projection.

        :param crs: desired projection, native projection of the image when not set
        :param bands: which bands to consider, default is all
        :return:
           Footprint of the first geobox that has a CRS, ``None`` when there
           is no such geobox.
        """
        if isinstance(crs, Unset):
            crs = None

        for gbox in self.geoboxes(bands):
            if gbox.crs is not None:
                # no re-projection needed: native extent is the answer
                if crs is None or crs == gbox.crs:
                    return gbox.extent
                return gbox.footprint(crs)

        return None

    def safe_geometry(
        self,
        crs: MaybeCRS = Unset(),
        bands: BandQuery = None,
    ) -> Optional[Geometry]:
        """
        Get item geometry footprint in desired projection or native.

        1. Use full-image footprint if proj data is available
        2. Fallback to item geometry if not

        :param crs: desired projection, native when not set or ``None``
        :param bands: which bands to consider, default is all
        :return: ``None`` when neither proj data nor item geometry is available.
        """
        img_geom = self.image_geometry(crs, bands=bands)
        if img_geom is not None:
            return img_geom

        if self.geometry is None:
            return None

        if crs is None or isinstance(crs, Unset):
            return self.geometry

        N = 100  # minimum number of points along perimeter we desire
        # for a square footprint ``sqrt(area) * 4`` is the perimeter length
        min_sample_distance = math.sqrt(self.geometry.area) * 4 / N
        return self.geometry.to_crs(
            crs,
            min_sample_distance,
            check_and_fix=True,
        ).dropna()

    def resolve_bands(
        self, bands: BandQuery = None
    ) -> Dict[str, Optional[RasterSource]]:
        """
        Query bands taking care of aliases.

        :param bands: which bands to look up, default is all
        :return:
           Mapping from the requested name to its ``RasterSource``, or to
           ``None`` for bands not present in this item.
        """
        bands = self.collection.normalize_band_query(bands)
        canon = self.collection.band_key
        return {k: self.bands.get(canon(k), None) for k in bands}

    def __getitem__(self, band: BandIdentifier) -> RasterSource:
        """
        Query band taking care of aliases.

        :raises: :py:class:`KeyError`
        """
        if isinstance(band, str):
            band = self.collection.band_key(band)
        return self.bands[band]

    def __len__(self) -> int:
        return len(self.bands)

    def __iter__(self) -> Iterator[BandKey]:
        yield from self.bands

    def __contains__(self, k: object) -> bool:
        if isinstance(k, str):
            try:
                return self.collection.band_key(k) in self.bands
            except ValueError:
                # name could not be resolved to a band key
                return False
        if isinstance(k, tuple):
            return k in self.bands
        return False

    @property
    def nominal_datetime(self) -> dt.datetime:
        """
        Resolve timestamp to a single value.

        - datetime if set
        - start_datetime if set
        - end_datetime if set
        - ``raise ValueError`` otherwise
        """
        for ts in [self.datetime, *self.datetime_range]:
            if ts is not None:
                return ts
        raise ValueError("Timestamp was not populated.")

    @property
    def mid_longitude(self) -> Optional[float]:
        """
        Return longitude of the center point.

        used for "solar day" computation. ``None`` when item has no geometry.
        """
        if self.geometry is None:
            return None
        # centroid re-projected to lon/lat is expected to be a single point
        ((lon, _),) = self.geometry.centroid.to_crs("epsg:4326").points
        return lon

    @property
    def solar_date(self) -> dt.datetime:
        """
        Nominal datetime adjusted by longitude.

        Falls back to the unadjusted nominal datetime when item has no geometry.
        """
        lon = self.mid_longitude
        if lon is None:
            return self.nominal_datetime
        return _convert_to_solar_time(self.nominal_datetime, lon)

    def solar_date_at(self, lon: float) -> dt.datetime:
        """
        Nominal datetime adjusted by longitude.

        :param lon: longitude to use for the adjustment
        """
        return _convert_to_solar_time(self.nominal_datetime, lon)

    def strip(self) -> "ParsedItem":
        """
        Copy of self but with stripped bands.

        Every band is replaced by the result of its own ``.strip()``, all
        other fields are carried over unchanged.
        """
        return replace(self, bands={k: band.strip() for k, band in self.bands.items()})

    def assets(self) -> Dict[str, List[RasterSource]]:
        """
        Extract bands grouped by asset they belong to.

        :return:
           Mapping from asset name to the sources of that asset, ordered by
           band index within the asset.
        """
        assets: Dict[str, List[Tuple[int, RasterSource]]] = {}
        for (asset, idx), src in self.bands.items():
            assets.setdefault(asset, []).append((idx, src))

        return {
            k: [src for _, src in sorted(srcs, key=lambda x: x[0])]
            for k, srcs in assets.items()
        }

    def __hash__(self) -> int:
        # explicit hash: ``bands`` and ``accessories`` are dicts, so the
        # field-based dataclass hash would not work
        return hash((self.id, self.collection.name))

    def __dask_tokenize__(self):
        return (
            self.id,
            self.collection,
            self.bands,
            self.href,
            self.datetime,
            self.datetime_range,
        )
@dataclass(frozen=True)
class MDParseConfig:
    """Item parsing config."""

    # Fallback band metadata, built from the ``"*"`` entry of ``assets``
    band_defaults: RasterBandMetadata = field(
        default_factory=lambda: norm_band_metadata({})
    )
    # Band metadata per ``assets`` entry (everything except ``"*"``)
    band_cfg: Dict[str, RasterBandMetadata] = field(default_factory=dict)
    # Alias name -> band key
    aliases: Dict[str, BandKey] = field(default_factory=dict)
    # Value of the ``ignore_proj`` config key
    # NOTE(review): presumably disables use of projection metadata -- confirm
    ignore_proj: bool = False
    # Value of the ``dims`` config key
    extra_dims: Dict[str, int] = field(default_factory=dict)
    # One ``FixedCoord`` per entry of the ``coords`` config key
    extra_coords: Sequence[FixedCoord] = ()

    @staticmethod
    def from_dict(
        cfg: Dict[str, Any], collection_id: str | None = None
    ) -> "MDParseConfig":
        """
        Build parsing config from a plain dictionary.

        Recognized keys, all optional: ``assets``, ``aliases``,
        ``ignore_proj``, ``dims`` and ``coords``.

        :param cfg:
           Either a config for a single collection, or a mapping from
           collection id to such config with ``"*"`` holding defaults.
        :param collection_id:
           When supplied and ``cfg`` has no top-level ``assets`` key, ``cfg``
           is treated as a per-collection mapping: the ``"*"`` entry is
           shallow-merged with the entry for this collection, keys of the
           latter taking precedence.
        :raises TypeError:
           When ``coords`` is not a dictionary or one of its values is not a
           list.
        """
        if collection_id is None or "assets" in cfg:
            # Single collection config, use as is
            _cfg = copy(cfg)
        else:
            _cfg = copy(cfg.get("*", {}))
            _cfg.update(cfg.get(collection_id, {}))

        band_defaults, band_cfg = _norm_band_cfg(_cfg.get("assets", {}))

        aliases: Dict[str, Any] = {}
        for alias, band in _cfg.get("aliases", {}).items():
            if isinstance(band, str):
                # plain asset name refers to the first band of that asset
                band = (band, 1)
            elif isinstance(band, list):
                # JSON/YAML have no tuples, band keys must be hashable
                band = tuple(band)
            aliases[alias] = band

        ignore_proj: bool = _cfg.get("ignore_proj", False)
        # copy so that the frozen config does not share state with the input
        extra_dims: Dict[str, int] = copy(_cfg.get("dims", {}))

        # Explicit checks rather than ``assert``: those vanish under ``-O``
        cc: dict[str, list[Any]] = _cfg.get("coords", {})
        if not isinstance(cc, dict):
            raise TypeError(
                f"'coords' must be a dictionary, got {type(cc).__name__}"
            )

        extra_coords: list[FixedCoord] = []
        for name, val in cc.items():
            if not isinstance(val, list):
                raise TypeError(
                    f"'coords' entry {name!r} must be a list, got {type(val).__name__}"
                )
            extra_coords.append(FixedCoord(name, val))

        return MDParseConfig(
            band_defaults=band_defaults,
            band_cfg=band_cfg,
            ignore_proj=ignore_proj,
            aliases=aliases,
            extra_dims=extra_dims,
            extra_coords=tuple(extra_coords),
        )
def _norm_band_cfg(
    cfg: Dict[str, Any]
) -> Tuple[RasterBandMetadata, Dict[str, RasterBandMetadata]]:
    """
    Split band config into defaults and per-band metadata.

    :param cfg:
       Mapping from band name to raw band config, the optional ``"*"`` entry
       supplies the defaults.
    :return:
       ``(fallback, bands)`` where ``fallback`` is the normalized ``"*"``
       entry (empty config when absent) and ``bands`` holds every other entry
       normalized with ``fallback`` passed as the second argument of
       ``norm_band_metadata``.
    """
    fallback = norm_band_metadata(cfg.get("*", {}))
    return fallback, {
        k: norm_band_metadata(v, fallback) for k, v in cfg.items() if k != "*"
    }
def _convert_to_solar_time(utc: dt.datetime, longitude: float) -> dt.datetime:
    """
    Shift timestamp by a whole number of hours based on longitude.

    :param utc: timestamp to adjust
    :param longitude: longitude in degrees, positive values shift forward in time
    :return:
       ``utc`` offset by ``int(longitude / 15)`` hours, the hour count is
       truncated towards zero, so longitudes within (-15, 15) give no shift.
    """
    # offset_seconds snapped to 1 hour increments
    #    1/15 == 24/360 (hours per degree of longitude)
    offset_seconds = int(longitude / 15) * 3600
    return utc + dt.timedelta(seconds=offset_seconds)