Source code for pythermalcomfort.plots.matplotlib.summary

"""Class-based summary plotting for threshold regions."""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.legend import Legend
from matplotlib.patches import Patch

from pythermalcomfort.plots.matplotlib._base import BasePlot
from pythermalcomfort.plots.matplotlib._shared import (
    _PYTHERMALCOMFORT_RC,
    BasePlotResult,
    _configure_regions,
    _is_light_color,
    _PlotDefaults,
)

# ── result container ───────────────────────────────────────────────────────


[docs] @dataclass class SummaryPlotResult(BasePlotResult): """Container with handles returned by :meth:`SummaryPlot.plot`. Attributes ---------- fig : Figure Matplotlib figure containing the summary plot. ax : Axes Matplotlib axis containing the summary plot. percentages : Series Percentage share per region, indexed by region label. artists : list List of rendered bar and text artists. legend : Legend or None Legend artist if ``legend=True``, otherwise ``None``. """ percentages: pd.Series artists: list[Any] legend: Legend | None
# ── validation helpers ───────────────────────────────────────────────────── def _validate_dataframe(df: pd.DataFrame) -> None: """Validate input DataFrame for summary plotting.""" if not isinstance(df, pd.DataFrame): raise TypeError("df must be a pandas DataFrame.") if df.empty: raise ValueError("df must contain at least one row.") def _validate_output_column(df: pd.DataFrame, output: str) -> str: """Validate output column name and ensure it exists.""" if not isinstance(output, str): raise TypeError("output must be a string.") output_name = output.strip() if not output_name: raise ValueError("output must be a non-empty string.") if output_name not in df.columns: msg = f"output column '{output_name}' was not found in the DataFrame." raise ValueError(msg) return output_name def _validate_output_values(df: pd.DataFrame, output_column: str) -> None: """Ensure output column contains numeric finite values only. Raises rather than silently dropping rows so callers are aware of missing data and can decide how to handle it before plotting. """ numeric_values = pd.to_numeric(df[output_column], errors="coerce") invalid_mask = ~numeric_values.notna() | ~np.isfinite(numeric_values.to_numpy()) if invalid_mask.any(): invalid_count = int(invalid_mask.sum()) msg = ( f"output column '{output_column}' contains {invalid_count} non-numeric, " "non-finite, or missing value(s)." ) raise ValueError(msg) # ── categorization ───────────────────────────────────────────────────────── def _compute_region_percentages( df: pd.DataFrame, *, output_column: str, levels: Sequence[float], region_labels: Sequence[str], ) -> pd.Series: """Assign each row to a threshold region and return percentage per region. Uses integer indices internally for pd.cut so that duplicate or empty display labels (e.g. ``["", "", ""]``) are handled correctly. The returned Series carries the display labels as its index. """ bins = [-np.inf, *levels, np.inf] values = pd.to_numeric(df[output_column], errors="raise") n_regions = len(levels) + 1 int_labels = list(range(n_regions)) categorized = pd.cut(values, bins=bins, labels=int_labels, right=False) result = ( categorized.value_counts(normalize=True) .reindex(int_labels, fill_value=0.0) .mul(100) .round(1) ) result.index = pd.Index(region_labels) return result # ── axis preparation ─────────────────────────────────────────────────────── def _prepare_axis(ax: Axes) -> None: """Prepare a clean, spine-free axis for summary bar rendering.""" ax.clear() ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(False) def _should_show_region_labels(*, legend: bool, region_labels: Sequence[str]) -> bool: """Return whether region labels should be drawn next to the bar.""" return not legend and any(label.strip() for label in region_labels) def _summary_figsize( *, vertical: bool, legend: bool, show_region_labels: bool, ) -> tuple[float, float]: """Return compact default figure size for standalone summary plots.""" D = _PlotDefaults.Summary if vertical: if show_region_labels: return D.figsize_vertical_labels return D.figsize_vertical_legend if legend else D.figsize_vertical_plain if legend: return D.figsize_horizontal_legend return ( D.figsize_horizontal_labels if show_region_labels else D.figsize_horizontal_plain ) def _apply_compact_layout(fig: Figure) -> None: """Trim default Matplotlib margins for standalone summary figures.""" fig.tight_layout(pad=_PlotDefaults.Summary.layout_pad) def _ensure_title_legend_spacing( fig: Figure, ax: Axes, legend: Legend | None, *, adjust_layout: bool, ) -> None: """Keep the title above the legend with only the measured gap needed.""" if legend is None or not ax.get_title(): return D = _PlotDefaults.Summary min_gap_px = D.title_legend_gap * fig.dpi / 72 for _ in range(D.title_spacing_iterations): fig.canvas.draw() renderer = fig.canvas.get_renderer() title_bbox = ax.title.get_window_extent(renderer) legend_bbox = legend.get_window_extent(renderer) current_gap_px = title_bbox.y0 - legend_bbox.y1 if current_gap_px >= min_gap_px: return axes_height_px = ax.get_window_extent(renderer).height if axes_height_px <= 0: return _, title_y = ax.title.get_position() y_adjustment = (min_gap_px - current_gap_px) / axes_height_px ax.title.set_y(title_y + y_adjustment) if adjust_layout: _apply_compact_layout(fig) # ── annotation helper ───────────────────────────────────────────────────── def _add_center_text( ax: Axes, *, x: float, y: float, text: str, color: str, ) -> Any: """Add a bold, centred text annotation.""" return ax.text( x, y, text, ha="center", va="center", fontsize=_PlotDefaults.Summary.percentage_fontsize, fontweight="bold", color=color, ) # ── unified summary renderer ────────────────────────────────────────────── def _plot_summary( ax: Axes, *, vertical: bool, region_percentages: pd.Series, region_labels: Sequence[str], region_colors: Sequence[str], show_region_labels: bool, bar_kws: Mapping[str, Any], ) -> list[Any]: """Render a stacked summary bar (horizontal or vertical) with annotations.""" D = _PlotDefaults.Summary artists: list[Any] = [] bar_opts = { "edgecolor": D.bar_edgecolor, "linewidth": D.bar_linewidth, **dict(bar_kws), } if vertical: ax.set_xlim(*(D.v_xlim if show_region_labels else D.v_xlim_legend)) ax.set_ylim(*D.v_ylim) else: ax.set_xlim(*D.h_xlim) ax.set_ylim(*(D.h_ylim if show_region_labels else D.h_ylim_legend)) cumulative = 0.0 for i, (label, color) in enumerate(zip(region_labels, region_colors, strict=False)): value = float(region_percentages.iloc[i]) if vertical: bar_args = { "x": D.v_bar_x, "height": value, "width": D.v_bar_width, "bottom": cumulative, "color": color, **bar_opts, } bar = ax.bar(**bar_args) else: bar_args = { "y": D.h_bar_y, "width": value, "left": cumulative, "height": D.h_bar_height, "color": color, **bar_opts, } bar = ax.barh(**bar_args) artists.append(bar) if value >= D.pct_min_to_show: is_light = _is_light_color(color) pct_color = "black" if is_light else "white" label_color = "dimgray" if is_light else color if vertical: center_y = cumulative + value / 2 pct_x, pct_y = D.v_bar_x, center_y lbl_x, lbl_y = D.v_bar_x + D.v_label_x_offset, center_y lbl_ha, lbl_va = "left", "center" else: pct_x, pct_y = cumulative + value / 2, D.h_bar_y lbl_x, lbl_y = cumulative + value / 2, D.h_label_y lbl_ha, lbl_va = "center", "bottom" artists.append( _add_center_text( ax, x=pct_x, y=pct_y, text=f"{value:.1f}%", color=pct_color ) ) if show_region_labels: artists.append( ax.text( lbl_x, lbl_y, label, ha=lbl_ha, va=lbl_va, fontsize=D.label_fontsize, color=label_color, ) ) cumulative += value return artists def _default_legend_ncol(*, vertical: bool, n_labels: int) -> int: """Return a compact default legend column count.""" D = _PlotDefaults.Summary if vertical: return D.vertical_legend_ncol return min(n_labels, D.legend_ncol_max) # ── public API ─────────────────────────────────────────────────────────────
[docs] class SummaryPlot(BasePlot): """Build and render a threshold summary plot from tabular model outputs. The class works with an existing DataFrame that already contains the target model output column (e.g., ``pmv`` or ``utci``). """
[docs] def __init__(self, df: pd.DataFrame) -> None: """Initialize a summary plot builder from a DataFrame. Parameters ---------- df : DataFrame Input DataFrame containing at least one output column to summarize. Raises ------ TypeError If ``df`` is not a pandas DataFrame. ValueError If ``df`` is empty. """ super().__init__() _validate_dataframe(df) self._df = df
[docs] def set_regions( self, *, output: str, thresholds: Sequence[float], labels: Sequence[str] | None = None, colors: Sequence[str] | None = None, ) -> SummaryPlot: """Set output variable and threshold region configuration. Parameters ---------- output : str Name of the DataFrame column to categorize. thresholds : sequence of float Numeric boundary values that divide the output range into regions. labels : sequence of str, optional Region labels. Must have length ``len(thresholds) + 1`` when provided. colors : sequence of str, optional Region colors. Must have length ``len(thresholds) + 1`` when provided. Returns ------- SummaryPlot Self, to support method chaining. Raises ------ TypeError If ``output`` is not a string. ValueError If the output column is missing or has invalid values, or if thresholds/labels/colors are invalid. """ output_name = _validate_output_column(self._df, output) _validate_output_values(self._df, output_name) self._region_config = _configure_regions( output=output_name, thresholds=thresholds, labels=labels, colors=colors, ) return self
[docs] def plot( self, *, ax: Axes | None = None, title: str | None = None, vertical: bool = False, legend: bool = True, bar_kws: Mapping[str, Any] | None = None, legend_kws: Mapping[str, Any] | None = None, ) -> SummaryPlotResult: """Render a threshold summary plot for the configured output column. Parameters ---------- ax : Axes, optional Existing axis to draw on. If ``None``, a new figure/axis is created with a compact default size for the selected orientation. title : str, optional Optional axis title. When both *title* and *legend* are shown the legend sits just above the chart and the title floats above it, matching the spacing used by :class:`ThresholdPlot`. vertical : bool If ``True``, render a vertical stacked bar; otherwise horizontal. legend : bool Whether to draw a colour-coded legend above the bar. When ``True``, region labels are omitted from the bar itself. bar_kws : dict, optional Keyword overrides forwarded to ``ax.bar`` or ``ax.barh`` for the stacked bar segments after SummaryPlot defaults are applied. legend_kws : dict, optional Keyword overrides forwarded to ``ax.legend``. Returns ------- SummaryPlotResult Result with figure, axis, percentages, artists, and legend handle. Raises ------ ValueError If regions are not configured first via :meth:`set_regions`. """ with mpl.rc_context(_PYTHERMALCOMFORT_RC): if self._region_config is None: raise ValueError( "Regions are not set. Call set_regions(...) before plot(...)." ) rc = self._region_config show_region_labels = _should_show_region_labels( legend=legend, region_labels=rc.labels, ) created_figure = ax is None if created_figure: fig, ax = plt.subplots( figsize=_summary_figsize( vertical=vertical, legend=legend, show_region_labels=show_region_labels, ) ) else: fig = ax.figure percentages = _compute_region_percentages( self._df, output_column=rc.output_name, levels=rc.thresholds, region_labels=rc.labels, ) _prepare_axis(ax) artists = _plot_summary( ax, vertical=vertical, region_percentages=percentages, region_labels=rc.labels, region_colors=rc.colors, show_region_labels=show_region_labels, bar_kws=bar_kws or {}, ) legend_artist: Legend | None = None if legend: lg_opts = dict(legend_kws or {}) lg_opts.setdefault("loc", "lower center") lg_opts.setdefault( "bbox_to_anchor", _PlotDefaults.legend_bbox_to_anchor_with_title if title is not None else _PlotDefaults.Threshold.legend_bbox_to_anchor, ) lg_opts.setdefault( "ncol", _default_legend_ncol(vertical=vertical, n_labels=len(rc.labels)), ) handles = [ Patch(facecolor=color, label=label) for label, color in zip(rc.labels, rc.colors, strict=False) ] legend_artist = ax.legend(handles=handles, **lg_opts) if title is not None: ax.set_title( title, y=_PlotDefaults.title_y_with_legend if legend else None ) if created_figure: _apply_compact_layout(fig) _ensure_title_legend_spacing( fig, ax, legend_artist, adjust_layout=created_figure, ) return SummaryPlotResult( fig=fig, ax=ax, percentages=percentages, artists=artists, legend=legend_artist, )