Source code for qa4sm_autoreports.series

# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: Copyright (c) 2026 TU Wien & AWST
# SPDX-FileCopyrightText: For a full list of authors, see the AUTHORS file.

import warnings
import numpy as np
from pathlib import Path
from typing import Union
import os
import pandas as pd
import matplotlib.pyplot as plt

from qa4sm_autoreports.data import Data
from qa4sm_autoreports.report import AutoReportCreator
from qa4sm_api.client_api import Connection


[docs] class AutoReportSeries: def __init__(self, series_root, reports=None, connection=None): """ Cross-report collection with the same validation settings, datasets, report template etc. Parameters ---------- series_root: str or Path Root directory of local series run results reports: list[str] or int Subset of local validation report names (folders in series_root) to load only. If an int is passed, we load last n reports only connection: Connection, optional Connection to use for all reports. If None, connections will be created based on the instance in each report's config file. """ self.series_root = Path(series_root) self.name = self.series_root.name self.connection = connection if not self.series_root.exists(): raise ValueError(f"series_root {self.series_root} does not exist.") self.reports = self._load_local_reports(reports) def _load_local_reports(self, subset=None): """ Load report series from all report folders in the series_root directory. Parameters ---------- subset: list[str, ...] or int Subset of local validation report names (folders in series_root) to load only. If subset is an int, we load the last n reports only. Returns ------- reports: dict[str, AutoReportCreator] Report names and their corresponding AutoReportCreator objects. """ if isinstance(subset, int): dirs = [] for f in sorted(self.series_root.iterdir()): if f.is_dir(): dirs.append(f) subset = dirs[::-1][:subset] reports = {} for f in sorted(self.series_root.iterdir()): if f.is_dir(): if subset is not None: if f.name not in subset: continue r = AutoReportCreator.from_results( report_root=self.series_root / f.name, connection=self.connection) name = r.name reports[name] = r return reports def __len__(self) -> int: return len(self.reports) def __repr__(self): # List reports in this series and their status (not loaded, staged, s = ('AutoReportSeries\n' '----------------\n') i = 0 for name, r in self.reports.items(): # type: AutoReportCreator if isinstance(r, str): status = "DUMMY REPORT" else: status = r._STATUS_LUT[r.status] name = r.name s += f"Report {i} [{status}]: {str(name)}\n" i += 1 if i == 0: s += 'no reports in this series found\n...\n' asd = f"Local --> {self.series_root}>" s += '_' * len(asd) + "\n" s += asd return s def __getitem__(self, item: Union[int, str]) -> AutoReportCreator: # Load and return one report from the series by name or id name = self._name(item) self._load_by_name(name) return self.reports[name] def _name(self, r: Union[int, str]) -> str: # Get report name from id or name if isinstance(r, int): name = list(self.reports.keys())[r] elif isinstance(r, str): if r not in list(self.reports.keys()): raise KeyError(f"The report '{r}' is not part of " f"the collection") else: name = r # pass else: raise ValueError(f"Pass either report ID or a " f"name from {list(self.reports.keys())}") return name def _load_by_name(self, name): # (Re)load a single report by name from the list r = AutoReportCreator.from_results( report_root=self.series_root / name, connection=self.connection) self.reports[r.name] = r
[docs] def reports_complete(self) -> bool: """ Check whether all reports in the collection are complete i.e, collected. Returns ------- status: bool, True if all are done -> Series up-to-date """ s = [] for name, report in self.reports.items(): if report.status >= 2: s.append(True) else: s.append(False) return bool(np.all(s))
[docs] def new_report(self, report_name, config_template_path, override_params=None, instance="qa4sm.eu", token=None): """ Start a new validation report from config templates on the chosen instance, download and collect all results. Parameters ---------- report_name: str Name of the report (will be added to the list) config_template_path: Path or str Path where the .json templates are stored override_params: dict, optional Params to override settings in config file instance: str, optional Instance to use for the report token: str, optional API token for authentication. If None, uses the connection from the series if available, otherwise creates a new connection without token. """ if report_name in self.reports: raise KeyError(f"Report {report_name} already exists") if token is None and self.connection is not None: connection = self.connection else: connection = Connection(instance=instance, token=token) report = AutoReportCreator.from_scratch( self.series_root / report_name, config_template_path, connection=connection) if override_params is not None: report.override_params(**override_params) for run in report.runs.values(): run.config.dump(run.local_root / f'config-{instance}.json') self.reports[report.name] = report return report
[docs] def delete_report(self, report_name, remote=True): """ Delete report from series. By default, also deletes the online runs and local copies of the validation runs. Parameters ---------- report_name: str, int Name of the report to delete from the series remote: bool, optional Remove the online version of the respective validation runs """ self.reports[report_name].delete(remote=remote) self.reports = self._load_local_reports(list(self.reports.keys()))
@staticmethod def _select_epochs(epochs: list, ref_epoch: int, n_epoch: int) -> list: """ Select a subset of epochs relative to a reference epoch. Parameters ---------- epochs: list Sorted list of epoch strings, ordered from earliest to latest. ref_epoch: int Reference epoch index. Supports negative indexing (e.g. -2 selects the second-to-last epoch). n_epoch: int Total number of epochs to return, counting backwards from and including the reference epoch (e.g., n_epoch=3 returns the reference plus the 2 epochs preceding it). Returns ------- epochs: list Subset of ``epochs`` of length ``min(n_epoch, ref_idx + 1)``, ending at and including the reference epoch. """ ref_idx = ref_epoch if ref_epoch >= 0 else len(epochs) + ref_epoch start_idx = max(0, ref_idx - (n_epoch - 1)) # -1 because ref counts as one return epochs[start_idx:ref_idx + 1]
[docs] def track_metric(self, metric, ref_epoch=-1, n_epochs=10, run=None, path_out=None, pretty_name='ubRMSD', unit='m³m⁻³', p_mask_var=None, p_mask_thres=0.05, tsw='bulk', preprocess=None): """ Create metric tracking data and plot Parameters ---------- metric: str, optional Metric to track across the epochs. e.g. R_between_0-ISMN_and_1-C3S_combined ref_epoch: int or str, optional Reference epoch, i.e. latest one. -1 uses the last report (ordered by name). A number refers to the repoch index, a string to the report name n_epochs: int, optional Number of epochs BEFORE the reference epochs to include (includes the reference). run: str, optional If the metric should be used from a certain run (from all reports), indicate the run name here. None means we search the metric in all runs, and use the first one if it's contained in multiple runs for a single report. path_out: str or path Where the stored files are stored. None will store all results in the folder of the reference epoch. pretty_name: str, optional Display name of the metric, e.g. ubRMSD unit: str, optional Pretty unit, no brackets, e.g m³m⁻³ p_mask_var: str, optional To mask data points where p>thres, pass the p variable name here. The same can be achieved via the preprocess function. p_mask_thres: float, optional The p value threshold used for masking, only used when p_mask_var is passed. tsw: str, optional Temporal sub-window to use (netcdf dimension). Default is "bulk" preprocess: Callable, optional Apply to dataset after loading, can be used for e.g. p value masking. Must take and return a dataset. Example:: lambda ds: ds """ if isinstance(ref_epoch, str): ref_epoch = list(self.reports.keys()).index(ref_epoch) reports = self._select_epochs( list(self.reports.keys()), ref_epoch, n_epochs) path_out = path_out or self.series_root / reports[-1] / "tracking" os.makedirs(path_out, exist_ok=True) fname = path_out / f"tracking_{metric}.yml" sd = Data().from_yml(fname) if os.path.isfile(fname) else Data() all_stats = {} for report in reports: ds = self.reports[report].open_datasets() dat = None if run is not None: dat = ds[run] else: for n, d in ds.items(): if metric in d: dat = d.sel(tsw=tsw) break if preprocess is not None: dat = preprocess(dat) if dat is None: raise KeyError(f"Metric {metric} not found in any run nc.") stats = { 'q5': np.nan, 'q25': np.nan, 'q50': np.nan, 'q75': np.nan, 'q95': np.nan, 'mean': np.nan, 'std': np.nan, 'n': np.nan } ser = dat.to_pandas() if p_mask_var is not None: ser = ser.loc[ser[p_mask_var] <= p_mask_thres, :] ser = ser[metric].dropna() stats['n'] = len(ser.values) for q in ['q25', 'q50', 'q75']: try: quant = float(ser.quantile(float(q[1:]) / 100)) stats[q] = quant except Exception: warnings.warn(f"Quantile {q} could not be computed.") try: stats['mean'] = float(ser.mean()) stats['std'] = float(ser.std()) except Exception: warnings.warn("Mean could not be computed.") all_stats[report] = stats ## other_stats = {'tracking_status': 'green'} sd.add(other_stats, section='results') sd.add(all_stats, section='tracking') sd.dump(fname, overwrite=True) df = pd.DataFrame.from_dict(sd.data['tracking']) bxpstats = [] names = list(df.columns.values) if len(names) < n_epochs: names = [None] * (n_epochs - len(names)) + names for name in names: bxpstats.append({ 'label': name, 'whislo': df.loc["q5", name] if name is not None else np.nan, 'whishi': df.loc["q95", name] if name is not None else np.nan, 'med': df.loc["q50", name] if name is not None else np.nan, 'q1': df.loc["q25", name] if name is not None else np.nan, 'q3': df.loc["q75", name] if name is not None else np.nan, 'fliers': [], }) fig, ax = plt.subplots(figsize=(6, 4)) positions = np.arange(len(names)) + 1 # 1-based positions ax.bxp(bxpstats, positions=positions, showfliers=False) ax.set_xticks(positions) ax.set_xticklabels(names, rotation=90, ha='center') ax.set_title(f"{pretty_name} tracking") ax.set_ylabel(f"{pretty_name} [{unit}]") fig.savefig(path_out / f"tracking_{metric}.png", bbox_inches='tight')