# Source code for geeViz.chartingLib

"""
Zonal Summary & Charting Library for GEE

geeViz.chartingLib provides a Python pipeline for running zonal statistics on
ee.Image / ee.ImageCollection objects and producing Plotly charts (time series,
bar, sankey). It mirrors the logic in the geeView JS frontend so that both human users and AI
agents have a clean, efficient API for this common workflow.
"""

"""
   Copyright 2026 Ian Housman

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
"""

# --------------------------------------------------------------------------
#  Zonal summary + charting pipeline (ported from area-charting.js)
# --------------------------------------------------------------------------

import math
import ee
import pandas
import plotly.graph_objects as go
from geeViz.gee2Pandas import robust_featureCollection_to_df


###########################################################################
#                              Constants
###########################################################################

# Separator joining an x-axis label and a band name in stacked band names,
# e.g. "2020----Land_Cover".
SPLIT_STR = "----"
# Separator token between the "from" and "to" classes of a Sankey transition.
SANKEY_TRANSITION_SEP = "0990"

# Shared Plotly styling defaults.
DEFAULT_PLOT_BGCOLOR = "#d6d1ca"
DEFAULT_PLOT_FONT = "Roboto"
DEFAULT_CHART_WIDTH = 575
DEFAULT_CHART_HEIGHT = 350

# Area output formats for thematic summaries.
#   mult   - unit area of one pixel at ``scale`` meters (30 m pixel = 900 m^2
#            = 0.09 ha = 0.222395 acres); None for Percentage, which is
#            derived from pixel totals instead.
#   label  - display label for the unit.
#   places - rounding precision for converted values.
#   scale  - reference pixel size (meters) that ``mult`` is expressed for.
AREA_FORMAT_DICT = dict(
    Percentage=dict(mult=None, label="% Area", places=2, scale=30),
    Hectares=dict(mult=0.09, label="ha", places=0, scale=30),
    Acres=dict(mult=0.222395, label="Acres", places=0, scale=30),
    Pixels=dict(mult=1.0, label="Pixels", places=0, scale=30),
)


###########################################################################
#                          Private helpers
###########################################################################


def _ensure_hex_color(color):
    """Prepend '#' if missing from a hex color string."""
    if color is None:
        return None
    color = str(color)
    if not color.startswith("#"):
        color = "#" + color
    return color


def _interpolate_palette(palette, n):
    """Interpolate a color palette to *n* colors (continuous ramp).

    Given a list of hex color stops, linearly interpolate between them to
    produce exactly *n* evenly-spaced colors.  Matches the JS min/max/palette
    ramp behaviour for ordinal-thematic bar charts.
    """
    if not palette or n <= 0:
        return []
    palette = [_ensure_hex_color(c) for c in palette]
    if n == 1:
        return [palette[0]]
    if len(palette) >= n:
        # Down-sample evenly
        return [palette[round(i * (len(palette) - 1) / (n - 1))] for i in range(n)]

    out = []
    for i in range(n):
        t = i / (n - 1)  # 0 … 1
        pos = t * (len(palette) - 1)
        lo = int(math.floor(pos))
        hi = min(lo + 1, len(palette) - 1)
        frac = pos - lo
        c_lo = palette[lo].lstrip("#")
        c_hi = palette[hi].lstrip("#")
        # Expand 3-char hex to 6-char
        if len(c_lo) == 3:
            c_lo = "".join(ch * 2 for ch in c_lo)
        if len(c_hi) == 3:
            c_hi = "".join(ch * 2 for ch in c_hi)
        r = int(int(c_lo[0:2], 16) * (1 - frac) + int(c_hi[0:2], 16) * frac)
        g = int(int(c_lo[2:4], 16) * (1 - frac) + int(c_hi[2:4], 16) * frac)
        b = int(int(c_lo[4:6], 16) * (1 - frac) + int(c_hi[4:6], 16) * frac)
        out.append(f"#{r:02x}{g:02x}{b:02x}")
    return out


def _format_period(period):
    """Format a transition period list like [1985,1987] -> '1985-1987' or '1985' if equal."""
    if isinstance(period, (list, tuple)) and len(period) == 2:
        if period[0] == period[1]:
            return str(period[0])
        return f"{period[0]}-{period[1]}"
    return str(period)


def _expand_thematic_reduce_regions(df, band_names, class_info, area_format, scale, split_str):
    """Expand frequency-histogram columns from ``reduceRegions`` into class-name columns.

    Args:
        df (pandas.DataFrame): One row per feature. Histogram columns hold
            ``{class_value_str: pixel_count}`` dicts; non-dict cells (e.g. NaN
            for features with no pixels) are skipped.
        band_names (list): Histogram column names. For stacked collections
            these look like ``"2020----Land_Cover"``.
        class_info (dict): ``{band_name: {class_values, class_names, ...}}``
            keyed by the original (un-stacked) band name.
        area_format (str): ``'Percentage'``, ``'Hectares'``, ``'Acres'`` or
            ``'Pixels'``.
        scale (int | float): Pixel scale (meters) used for the reduction.
        split_str (str): Separator between the x-axis label and band name.

    Returns:
        pandas.DataFrame: One row per input row. Non-histogram columns are
        carried over, and each histogram becomes one column per class.
    """
    scale_mult = (scale / AREA_FORMAT_DICT["Hectares"]["scale"]) ** 2

    # reduceRegions on a single-band image names the output property after
    # the reducer output ("histogram") instead of the band, so fall back to
    # that column when the band-named column is missing.
    source_cols = {bn: bn for bn in band_names}
    if len(band_names) == 1 and band_names[0] not in df.columns and "histogram" in df.columns:
        source_cols[band_names[0]] = "histogram"
    histogram_cols = set(source_cols.values())

    out_rows = []
    for _, row in df.iterrows():
        # preserve any non-histogram columns (e.g. label/id)
        out_row = {col: row[col] for col in df.columns if col not in histogram_cols}

        for bn in band_names:
            histogram = row.get(source_cols[bn])
            if not isinstance(histogram, dict):
                continue

            # For stacked band names like "2020----Land_Cover", look up class_info
            # by the original band name (the part after the SPLIT_STR prefix).
            original_bn = bn.split(split_str, 1)[-1] if split_str in bn else bn
            info = class_info.get(original_bn, {})
            class_values = info.get("class_values", [])
            class_names = info.get("class_names", [])
            value_to_name = dict(zip([str(v) for v in class_values], class_names))

            # `or 1` avoids dividing by zero for an empty histogram.
            pixel_total = sum(histogram.values()) or 1

            for str_val, count in histogram.items():
                name = value_to_name.get(str_val, str_val)
                # Prefix with the band name only when several bands could collide.
                col_name = f"{bn}{split_str}{name}" if len(band_names) > 1 else name
                if area_format == "Percentage":
                    out_row[col_name] = round((count / pixel_total) * 100, 2)
                elif area_format == "Pixels":
                    out_row[col_name] = count
                else:
                    mult = AREA_FORMAT_DICT[area_format]["mult"] * scale_mult
                    out_row[col_name] = round(count * mult, AREA_FORMAT_DICT[area_format]["places"])

        out_rows.append(out_row)

    return pandas.DataFrame(out_rows)


###########################################################################
#                       Data pipeline functions
###########################################################################


def get_obj_info(ee_obj, band_names=None):
    """
    Detect the type of a GEE object and read its thematic class metadata.

    Class metadata is read from the properties of the image (or the first
    image of a collection): ``{band}_class_values`` and ``{band}_class_names``
    mark a band as thematic, and ``{band}_class_palette`` is optional.

    Args:
        ee_obj (ee.Image or ee.ImageCollection): The GEE object to inspect.
        band_names (list or str, optional): Override the band names to use.
            A single string is treated as a one-element list.

    Returns:
        dict: Keys ``obj_type``, ``band_names``, ``is_thematic``,
        ``class_info``, ``size``. ``class_info`` is
        ``{band_name: {class_values, class_names, class_palette}}``.
    """
    obj_type = type(ee_obj).__name__
    is_collection = obj_type == "ImageCollection"
    first_img = ee.Image(ee_obj.first()) if is_collection else ee.Image(ee_obj)

    if isinstance(band_names, str):
        band_names = [band_names]

    # Fetch properties, collection size and band names in ONE server round
    # trip instead of up to three separate getInfo() calls.
    request = {"props": first_img.toDictionary()}
    if is_collection:
        request["size"] = ee_obj.size()
    if band_names is None:
        request["band_names"] = first_img.bandNames()
    info = ee.Dictionary(request).getInfo()

    props = info["props"]
    size = info.get("size", 1)
    if band_names is None:
        band_names = info["band_names"]

    # Read class metadata from image properties
    class_info = {}
    is_thematic = False
    for bn in band_names:
        values_key = f"{bn}_class_values"
        names_key = f"{bn}_class_names"
        palette_key = f"{bn}_class_palette"
        if values_key in props and names_key in props:
            is_thematic = True
            class_info[bn] = {
                "class_values": props[values_key],
                "class_names": props[names_key],
                "class_palette": props.get(palette_key, []),
            }

    return {
        "obj_type": obj_type,
        "band_names": band_names,
        "is_thematic": is_thematic,
        "class_info": class_info,
        "size": size,
    }
def detect_geometry_type(geometry):
    """
    Classify an input region as a single zone or as multiple zones.

    A FeatureCollection with more than one feature counts as ``'multi'``.
    Everything else (a Geometry, a Feature, a FeatureCollection with zero or
    one feature, or any other object coerced via ``ee.Geometry``) counts as
    ``'single'``.

    Args:
        geometry: An ``ee.Geometry``, ``ee.Feature``, or ``ee.FeatureCollection``.

    Returns:
        tuple: ``(geo_type, geometry)`` where geo_type is ``'single'`` or
        ``'multi'``. The geometry is an ``ee.Geometry`` for single zones and
        the original ``ee.FeatureCollection`` for multi.
    """
    kind = type(geometry).__name__

    if kind == "FeatureCollection":
        # Requires a server round trip to count the features.
        if geometry.size().getInfo() > 1:
            return ("multi", geometry)
        return ("single", geometry.geometry())

    if kind == "Feature":
        return ("single", geometry.geometry())

    if kind == "Geometry":
        return ("single", geometry)

    # Unknown wrapper: let ee.Geometry try to interpret it.
    return ("single", ee.Geometry(geometry))
def prepare_for_reduction(ee_obj, obj_info, x_axis_property="system:time_start", date_format="YYYY"):
    """
    Prepare a GEE object for reduction by stacking an ImageCollection into a
    single multi-band image.

    For a collection, each image's bands are stacked as
    ``"{x_label}{SPLIT_STR}{band}"``. Labels with more than one image are
    mosaicked first. When ``x_axis_property`` is ``'year'``, ``'date'`` or
    ``'system:time_start'``, labels come from the image date formatted with
    ``date_format`` (stored in a ``'year'`` property).

    Args:
        ee_obj: ``ee.Image`` or ``ee.ImageCollection``.
        obj_info (dict): Output of :func:`get_obj_info`.
        x_axis_property (str): Property name to use for x-axis labels.
        date_format (str): Earth Engine date format string (e.g. ``'YYYY'``).

    Returns:
        tuple: ``(stacked_image, stack_band_names, x_axis_labels)``.
        For a single image this is ``(image.select(band_names), band_names, [])``.
    """
    band_names = obj_info["band_names"]

    if obj_info["obj_type"] != "ImageCollection":
        # Single image - pass through
        return (ee.Image(ee_obj).select(band_names), band_names, [])

    ic = ee_obj
    # Tag images with a formatted date label if the x-axis is date-derived.
    if x_axis_property in ("year", "date", "system:time_start"):
        ic = ic.map(lambda img: img.set("year", img.date().format(date_format)))
        x_axis_property = "year"

    # One round trip gives both the labels and the number of images per label.
    label_counts = ic.aggregate_histogram(x_axis_property).getInfo()
    # Lexicographic order; it must match the server-side sort below.
    # NOTE(review): a numeric label property can sort differently on the
    # server (9 < 10) than these string keys ("10" < "9"); string labels
    # (such as formatted dates) are safe.
    x_axis_labels = sorted(label_counts)

    # Select only the bands we care about
    ic = ic.select(band_names)

    if any(count > 1 for count in label_counts.values()):
        print("Auto-mosaicking ImageCollection for x-axis labels...")
        source = ic
        prop = x_axis_property

        def _mosaic_for_label(label):
            label = ee.String(label)
            filtered = source.filter(ee.Filter.eq(prop, label))
            return filtered.mosaic().copyProperties(filtered.first()).set(prop, label)

        # Built in x_axis_labels order, so toBands() order matches the labels.
        ic = ee.ImageCollection(ee.List(x_axis_labels).map(_mosaic_for_label))
    else:
        # toBands() stacks images in collection order, which need not be
        # label order (e.g. merged collections); sort so band order lines up
        # with x_axis_labels.
        ic = ic.sort(x_axis_property)

    # Pre-compute expected band names: "label----band" for each label × band
    expected_names = [
        f"{x_label}{SPLIT_STR}{bn}" for x_label in x_axis_labels for bn in band_names
    ]

    # toBands() prefixes each band with the image's system:index + "_", which
    # is unpredictable for native collections (e.g. "LC09_038029_20230613"),
    # so rename outright to the names we already know.
    stacked = ic.toBands().rename(expected_names)
    return (stacked, expected_names, x_axis_labels)
def reduce_region(image, geometry, reducer, scale=30, crs=None, transform=None, tile_scale=4):
    """
    Run ``image.reduceRegion`` with sensible defaults and fetch the result.

    When a ``transform`` is supplied, ``scale`` is dropped (sent as None)
    because the transform takes precedence in GEE. The call always uses
    ``bestEffort=True`` and ``maxPixels=1e13``.

    Args:
        image (ee.Image): The image to reduce.
        geometry: An ``ee.Geometry`` or ``ee.Feature``.
        reducer (ee.Reducer): The reducer to apply.
        scale (int, optional): Pixel scale in meters. Defaults to 30.
        crs (str, optional): CRS string. Defaults to None.
        transform (list, optional): Affine transform. Defaults to None.
        tile_scale (int, optional): Tile scale for parallelism. Defaults to 4.

    Returns:
        dict: The reduction result dictionary (client-side, via ``getInfo``).
    """
    effective_scale = None if transform is not None else scale
    options = dict(
        reducer=reducer,
        geometry=geometry,
        scale=effective_scale,
        crs=crs,
        crsTransform=transform,
        bestEffort=True,
        maxPixels=1e13,
        tileScale=tile_scale,
    )
    reduced = image.reduceRegion(**options)
    return reduced.getInfo()
def reduce_regions(image, features, reducer, scale=30, crs=None, transform=None, tile_scale=4):
    """
    Run ``image.reduceRegions`` over a set of zones and return a DataFrame.

    When a ``transform`` is supplied, ``scale`` is dropped (sent as None)
    because the transform takes precedence in GEE.

    Args:
        image (ee.Image): The image to reduce.
        features (ee.FeatureCollection): The zones.
        reducer (ee.Reducer): The reducer to apply.
        scale (int, optional): Pixel scale in meters. Defaults to 30.
        crs (str, optional): CRS string. Defaults to None.
        transform (list, optional): Affine transform. Defaults to None.
        tile_scale (int, optional): Tile scale for parallelism. Defaults to 4.

    Returns:
        pandas.DataFrame: The reduction results, converted with
        ``robust_featureCollection_to_df``.
    """
    effective_scale = None if transform is not None else scale
    options = dict(
        collection=features,
        reducer=reducer,
        scale=effective_scale,
        crs=crs,
        crsTransform=transform,
        tileScale=tile_scale,
    )
    return robust_featureCollection_to_df(image.reduceRegions(**options))
def _histogram_to_columns(histogram, band_name, band_names, class_info, area_format, scale_mult, split_str):
    """Convert one ``{class_value_str: count}`` histogram into ``{column_name: value}``.

    Column names are class names (falling back to the raw value string). They
    are prefixed with ``"{band_name}{split_str}"`` when several bands are
    summarized. Values are formatted per ``area_format``. A missing or
    non-dict histogram yields no columns.
    """
    if not isinstance(histogram, dict):
        histogram = {}
    info = class_info.get(band_name, {})
    class_values = info.get("class_values", [])
    class_names = info.get("class_names", [])
    value_to_name = dict(zip([str(v) for v in class_values], class_names))

    # `or 1` avoids dividing by zero for an empty histogram.
    pixel_total = sum(histogram.values()) or 1

    columns = {}
    for str_val, count in histogram.items():
        name = value_to_name.get(str_val, str_val)
        col_name = f"{band_name}{split_str}{name}" if len(band_names) > 1 else name
        if area_format == "Percentage":
            columns[col_name] = round((count / pixel_total) * 100, 2)
        elif area_format == "Pixels":
            columns[col_name] = count
        else:
            mult = AREA_FORMAT_DICT[area_format]["mult"] * scale_mult
            columns[col_name] = round(count * mult, AREA_FORMAT_DICT[area_format]["places"])
    return columns


def parse_thematic_results(raw_dict, obj_info, x_axis_labels, area_format="Percentage", scale=30, split_str=SPLIT_STR):
    """
    Parse frequency histogram reduction results into a DataFrame with class
    names as columns.

    Args:
        raw_dict (dict): Output of :func:`reduce_region` using ``frequencyHistogram``.
        obj_info (dict): Output of :func:`get_obj_info`.
        x_axis_labels (list): Labels for the x-axis (e.g. years). Empty for a
            single Image.
        area_format (str): One of ``'Percentage'``, ``'Hectares'``,
            ``'Acres'``, ``'Pixels'``.
        scale (int): Pixel scale used in reduction; area units are scaled by
            ``(scale / 30) ** 2``.
        split_str (str): Band name separator.

    Returns:
        pandas.DataFrame: Rows are x-axis labels (or a single row for Image),
        columns are class names, and missing classes are filled with 0.
    """
    class_info = obj_info["class_info"]
    band_names = obj_info["band_names"]
    scale_mult = (scale / AREA_FORMAT_DICT["Hectares"]["scale"]) ** 2

    def _row_for(key_for_band):
        """Merge the class columns of every band; *key_for_band* maps a band to its raw_dict key."""
        row = {}
        for bn in band_names:
            row.update(
                _histogram_to_columns(
                    raw_dict.get(key_for_band(bn)),
                    bn,
                    band_names,
                    class_info,
                    area_format,
                    scale_mult,
                    split_str,
                )
            )
        return row

    if x_axis_labels:
        # ImageCollection path - histogram keys are like "2020----Land_Cover"
        rows = []
        for x_label in x_axis_labels:
            row = {"x": x_label}
            row.update(_row_for(lambda bn, x_label=x_label: f"{x_label}{split_str}{bn}"))
            rows.append(row)
        df = pandas.DataFrame(rows).set_index("x")
        df.index.name = None
        return df.fillna(0)

    # Single Image path - histogram keys are band names directly
    return pandas.DataFrame([_row_for(lambda bn: bn)]).fillna(0)
def parse_continuous_results(raw_dict, obj_info, x_axis_labels, split_str=SPLIT_STR):
    """
    Turn continuous (mean/median/etc.) reduction output into a DataFrame.

    Args:
        raw_dict (dict): Output of :func:`reduce_region`.
        obj_info (dict): Output of :func:`get_obj_info`.
        x_axis_labels (list): Labels for the x-axis. Empty for a single Image.
        split_str (str): Band name separator.

    Returns:
        pandas.DataFrame: One row per x-axis label (index = label), or a
        single row when ``x_axis_labels`` is empty. Columns are band names,
        and values missing from ``raw_dict`` become None.
    """
    band_names = obj_info["band_names"]

    if not x_axis_labels:
        return pandas.DataFrame([{bn: raw_dict.get(bn) for bn in band_names}])

    records = [
        {"x": label, **{bn: raw_dict.get(f"{label}{split_str}{bn}") for bn in band_names}}
        for label in x_axis_labels
    ]
    table = pandas.DataFrame(records).set_index("x")
    table.index.name = None
    return table
def zonal_stats(
    ee_obj,
    geometry,
    band_names=None,
    reducer=None,
    scale=30,
    crs=None,
    transform=None,
    tile_scale=4,
    area_format="Percentage",
    x_axis_property="system:time_start",
    date_format="YYYY",
):
    """
    Compute zonal statistics for a GEE Image or ImageCollection over a geometry.

    This is the main entry point for the data pipeline. It auto-detects the
    object type, whether data is thematic or continuous, the appropriate
    reducer (``frequencyHistogram`` for thematic, ``mean`` otherwise), and the
    geometry type.

    Args:
        ee_obj: ``ee.Image`` or ``ee.ImageCollection``.
        geometry: ``ee.Geometry``, ``ee.Feature``, or ``ee.FeatureCollection``.
        band_names (list, optional): Bands to include. Auto-detected if None.
        reducer (ee.Reducer, optional): Override the auto-selected reducer.
        scale (int): Pixel scale in meters. Defaults to 30.
        crs (str, optional): CRS string.
        transform (list, optional): Affine transform.
        tile_scale (int): Tile scale for parallelism. Defaults to 4.
        area_format (str): Area unit for thematic data. One of
            ``'Percentage'``, ``'Hectares'``, ``'Acres'``, ``'Pixels'``.
        x_axis_property (str): Property for x-axis labels (ImageCollection).
        date_format (str): Date format string for x-axis labels.

    Returns:
        pandas.DataFrame: The zonal statistics table.
    """
    # Only collections have filterBounds(); an ee.Image is used as-is.
    if type(ee_obj).__name__ == "ImageCollection":
        ee_obj = ee_obj.filterBounds(geometry)

    obj_info = get_obj_info(ee_obj, band_names)
    geo_type, geo = detect_geometry_type(geometry)

    if reducer is None:
        # Auto-selected reducer: whether it is a histogram is known locally,
        # so no server round trip is needed.
        is_histogram = obj_info["is_thematic"]
        reducer = ee.Reducer.frequencyHistogram() if is_histogram else ee.Reducer.mean()
    else:
        # User-supplied reducer: ask the server what it is (best effort).
        is_histogram = False
        try:
            is_histogram = "frequencyHistogram" in reducer.getInfo()["type"]
        except Exception:
            # Introspection failed; treat the reducer as continuous.
            pass

    # Prepare image
    stacked, stack_bands, x_axis_labels = prepare_for_reduction(
        ee_obj, obj_info, x_axis_property, date_format
    )

    if geo_type == "single":
        raw = reduce_region(stacked, geo, reducer, scale, crs, transform, tile_scale)
        if is_histogram:
            return parse_thematic_results(raw, obj_info, x_axis_labels, area_format, scale)
        return parse_continuous_results(raw, obj_info, x_axis_labels)

    # Multi-region: reduceRegions
    df = reduce_regions(stacked, geo, reducer, scale, crs, transform, tile_scale)
    if is_histogram:
        return _expand_thematic_reduce_regions(
            df, stack_bands, obj_info["class_info"], area_format, scale, SPLIT_STR
        )
    return df
def prepare_sankey_data(
    ee_collection,
    band_name,
    transition_periods,
    class_info,
    geometry,
    scale=30,
    crs=None,
    transform=None,
    tile_scale=4,
    area_format="Percentage",
    min_percentage=0.2,
):
    """
    Build a Sankey diagram dataset from class transitions across time periods.

    For each consecutive pair of periods, this function:

    1. Filters the collection to each period (inclusive calendar-year range)
    2. Computes the mode of ``band_name`` for each period
    3. Encodes each pixel's transition as the integer
       ``from_class * multiplier + (to_class + shift)``. ``multiplier`` is at
       least 10000 and grows with the class value range, and ``shift`` makes
       negative class values non-negative, so every pair of known class values
       decodes unambiguously (including values >= 100)
    4. Runs ``frequencyHistogram`` to count transitions
    5. Parses results into both a source/target/value DataFrame and a
       transition matrix DataFrame

    Args:
        ee_collection (ee.ImageCollection): The input collection.
        band_name (str): The thematic band to analyze.
        transition_periods (list): Periods, each a year or a
            ``[start_year, end_year]`` pair (first and last elements are used).
        class_info (dict): Either the ``class_info`` mapping from
            :func:`get_obj_info` (``{band_name: {...}}``) or one band's info
            dict with ``class_values``/``class_names``/``class_palette``. If
            ``band_name`` is not a key, the first entry's info is used.
        geometry: ``ee.Geometry`` or ``ee.Feature``.
        scale (int): Pixel scale in meters.
        crs (str, optional): CRS string.
        transform (list, optional): Affine transform.
        tile_scale (int): Tile scale for parallelism.
        area_format (str): Area unit.
        min_percentage (float): Minimum percentage threshold for including a
            flow in the source-target table. The transition matrix always
            includes all observed transitions regardless of this threshold.

    Returns:
        tuple: ``(sankey_df, matrix_df)``

            - **sankey_df** (``pandas.DataFrame``): Source-target-value table
              with columns ``source``, ``target``, ``value``, ``source_name``,
              ``target_name``, ``source_color``, ``target_color``, ``period``.
              Node indices are ``class_index + period_index * num_classes``.
              Flows below ``min_percentage`` are excluded.
            - **matrix_df** (``pandas.DataFrame``): Transition matrix where
              rows are "from" classes (labelled ``"{class_name} {period}"``),
              columns are "to" classes (same label format), and values are
              the converted counts. There is one block of rows per
              consecutive period pair.

        Both DataFrames are empty when fewer than two periods are given.
    """
    empty_cols = ["source", "target", "value", "source_name", "target_name",
                  "source_color", "target_color", "period"]

    _, geo = detect_geometry_type(geometry)

    # Accept a single band's info dict directly, or look the band up.
    if "class_values" in class_info:
        info = class_info
    elif band_name in class_info:
        info = class_info[band_name]
    else:
        info = next(iter(class_info.values()), {})

    class_values = info.get("class_values", [])
    class_names = info.get("class_names", [])
    class_palette = info.get("class_palette", [])

    value_to_idx = {v: i for i, v in enumerate(class_values)}
    idx_to_name = dict(enumerate(class_names))
    idx_to_color = {i: _ensure_hex_color(c) for i, c in enumerate(class_palette)}
    num_classes = len(class_values)
    # Matrix labels fall back to the class value when a name is missing.
    class_labels = [idx_to_name.get(i, str(v)) for i, v in enumerate(class_values)]

    scale_mult = (scale / AREA_FORMAT_DICT["Hectares"]["scale"]) ** 2

    if len(transition_periods) < 2:
        return (pandas.DataFrame(columns=empty_cols), pandas.DataFrame())

    # Encoding parameters: shift negative "to" values up to >= 0 and make the
    # multiplier larger than any shifted "to" value so from/to never overlap.
    int_values = [int(v) for v in class_values]
    value_shift = -min(min(int_values, default=0), 0)
    multiplier = 10000
    while multiplier <= max(int_values, default=0) + value_shift:
        multiplier *= 10

    def _period_mode(period, out_name):
        """Mode of ``band_name`` over the calendar years spanned by *period*."""
        if isinstance(period, (list, tuple)):
            start, end = period[0], period[-1]
        else:
            start, end = period, period
        filtered = ee_collection.filter(
            ee.Filter.calendarRange(int(start), int(end), "year")
        ).select([band_name])
        return filtered.mode().rename([out_name])

    # Build transition images for each consecutive period pair
    transition_images = []
    transition_band_names = []
    period_labels = []
    for p1, p2 in zip(transition_periods[:-1], transition_periods[1:]):
        p1_label, p2_label = _format_period(p1), _format_period(p2)
        t_band = f"{p1_label}---{p2_label}"
        combined = _period_mode(p1, "from").addBands(_period_mode(p2, "to"))
        transition = (
            combined.select("from").multiply(multiplier)
            .add(combined.select("to").add(value_shift))
            .rename([t_band])
        )
        transition_images.append(transition)
        transition_band_names.append(t_band)
        period_labels.append((p1_label, p2_label))

    # Stack all transition images into one multi-band image
    stacked = transition_images[0]
    for t_img in transition_images[1:]:
        stacked = stacked.addBands(t_img)

    # Run frequency histogram
    raw = reduce_region(
        stacked.toInt(), geo, ee.Reducer.frequencyHistogram(), scale, crs, transform, tile_scale
    )

    # Parse results — build both the source-target table and the transition matrix
    all_rows = []
    matrix_rows = []
    for pair_idx, (t_band, (p1_label, p2_label)) in enumerate(zip(transition_band_names, period_labels)):
        histogram = raw.get(t_band) or {}
        # `or 1` avoids dividing by zero for an empty histogram.
        pixel_total = sum(histogram.values()) or 1
        src_base = pair_idx * num_classes
        dst_base = (pair_idx + 1) * num_classes

        # (from_idx, to_idx) -> display value for ALL known-class transitions
        count_lookup = {}
        for encoded_str, count in histogram.items():
            # Floor division keeps decoding correct for negative "from" values.
            from_class, shifted_to = divmod(int(float(encoded_str)), multiplier)
            to_class = shifted_to - value_shift
            from_idx = value_to_idx.get(from_class)
            to_idx = value_to_idx.get(to_class)
            if from_idx is None or to_idx is None:
                continue

            pct = (count / pixel_total) * 100
            if area_format == "Percentage":
                display_val = round(pct, 2)
            elif area_format == "Pixels":
                display_val = count
            else:
                mult = AREA_FORMAT_DICT[area_format]["mult"] * scale_mult
                display_val = round(count * mult, AREA_FORMAT_DICT[area_format]["places"])
            count_lookup[(from_idx, to_idx)] = display_val

            # Source-target table: only include flows above min_percentage
            if pct >= min_percentage:
                all_rows.append(
                    {
                        "source": from_idx + src_base,
                        "target": to_idx + dst_base,
                        "value": display_val,
                        "source_name": f"{p1_label} {idx_to_name.get(from_idx, str(from_class))}",
                        "target_name": f"{p2_label} {idx_to_name.get(to_idx, str(to_class))}",
                        "source_color": idx_to_color.get(from_idx, "#888888"),
                        "target_color": idx_to_color.get(to_idx, "#888888"),
                        "period": f"{p1_label} -> {p2_label}",
                    }
                )

        # Transition matrix rows for this period pair: rows "from", columns "to"
        for from_idx in range(num_classes):
            row_data = {"": f"{class_labels[from_idx]} {p1_label}"}
            for to_idx in range(num_classes):
                row_data[f"{class_labels[to_idx]} {p2_label}"] = count_lookup.get((from_idx, to_idx), 0)
            matrix_rows.append(row_data)

    sankey_df = pandas.DataFrame(all_rows) if all_rows else pandas.DataFrame(columns=empty_cols)

    if matrix_rows:
        matrix_df = pandas.DataFrame(matrix_rows).set_index("")
        matrix_df.index.name = None
    else:
        matrix_df = pandas.DataFrame()

    return (sankey_df, matrix_df)
########################################################################### # Chart functions ###########################################################################
def chart_time_series(
    df,
    colors=None,
    chart_type="lines+markers",
    title="Time Series",
    x_label="Year",
    y_label=None,
    stacked=False,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    label_max_length=30,
):
    """
    Create a Plotly time series chart from a zonal stats DataFrame.

    Args:
        df (pandas.DataFrame): Output of :func:`zonal_stats` for an
            ImageCollection. Index = x-axis labels, columns = data series.
        colors (list, optional): Hex color strings for each column.
        chart_type (str): ``'lines'``, ``'bar'``, or ``'lines+markers'``
            (any other value is drawn as ``'lines+markers'``).
        title (str): Chart title.
        x_label (str): X-axis label.
        y_label (str, optional): Y-axis label.
        stacked (bool): Whether to stack the series.
        width (int): Chart width in pixels.
        height (int): Chart height in pixels.
        label_max_length (int): Max characters for legend labels.

    Returns:
        plotly.graph_objects.Figure
    """
    fig = go.Figure()

    x_values = list(df.index)
    # Convert integral labels (e.g. "2020" or 2020.0) to int so Plotly uses a
    # linear axis with automatic tick spacing instead of a categorical axis
    # that crams every label together. int() silently truncates floats such
    # as 2020.5, so only exact conversions are accepted; if any label is not
    # integral, all labels are left unchanged.
    try:
        as_ints = []
        for v in x_values:
            iv = int(v)
            if not isinstance(v, str) and iv != v:
                raise ValueError(f"non-integral x value: {v!r}")
            as_ints.append(iv)
        x_values = as_ints
    except (ValueError, TypeError):
        pass

    for i, col in enumerate(df.columns):
        color = _ensure_hex_color(colors[i]) if colors and i < len(colors) else None
        # str() so non-string column names (e.g. ints) can be truncated.
        label = str(col)[:label_max_length]
        if chart_type == "bar":
            fig.add_trace(
                go.Bar(
                    x=x_values,
                    y=df[col].values,
                    name=label,
                    marker_color=color,
                )
            )
        else:
            mode = chart_type if chart_type in ("lines", "lines+markers") else "lines+markers"
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=df[col].values,
                    mode=mode,
                    name=label,
                    line=dict(color=color, width=1),
                    marker=dict(color=color, size=3),
                    stackgroup="one" if stacked else None,
                )
            )

    if chart_type == "bar":
        bar_mode = "stack" if stacked else "group"
    else:
        bar_mode = None

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(
            title=x_label,
            tickangle=45,
            # Integer x-values (years) get an integer tick format and
            # automatic spacing instead of one tick per value.
            tickformat="d" if all(isinstance(v, int) for v in x_values) else None,
        ),
        yaxis=dict(
            title=y_label,
            automargin=True,
        ),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        margin=dict(l=35, r=25, b=50, t=50, pad=5),
        barmode=bar_mode,
        hovermode="x unified",
    )
    return fig
def chart_bar(
    df,
    colors=None,
    title="Class Distribution",
    y_label=None,
    max_classes=30,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
):
    """
    Draw a Plotly bar chart of one value per class.

    A single-row DataFrame is plotted as-is. A multi-row DataFrame is first
    summed column-wise. Only the ``max_classes`` largest values are kept, in
    their original order. Bars are horizontal when the longest label is
    longer than ``max(number_of_labels, 6)`` characters, otherwise vertical.

    Args:
        df (pandas.DataFrame): Output of :func:`zonal_stats` for a single
            Image. Single row, columns = class names.
        colors (list, optional): Hex color strings for each bar. If fewer
            colors than bars are given, they are interpolated into a ramp.
        title (str): Chart title.
        y_label (str, optional): Value axis label.
        max_classes (int): Maximum number of classes to display.
        width (int): Chart width in pixels.
        height (int): Chart height in pixels.

    Returns:
        plotly.graph_objects.Figure
    """
    series = df.iloc[0] if len(df) == 1 else df.sum()
    labels = list(series.index)
    vals = list(series.values)

    if len(labels) > max_classes:
        # Rank by value (largest first), keep the top N, then restore the
        # original column order using the carried position.
        ranked = sorted(zip(vals, labels, range(len(labels))), reverse=True)[:max_classes]
        ranked.sort(key=lambda item: item[2])
        vals = [value for value, _, _ in ranked]
        labels = [name for _, name, _ in ranked]
        if colors:
            kept_positions = [pos for _, _, pos in ranked]
            colors = [_ensure_hex_color(colors[pos]) for pos in kept_positions if pos < len(colors)]

    if colors:
        if len(colors) < len(labels):
            # Too few colors: stretch them into a continuous ramp.
            colors = _interpolate_palette(colors, len(labels))
        else:
            colors = [_ensure_hex_color(c) for c in colors[:len(labels)]]

    longest_label = max((len(str(name)) for name in labels), default=0)
    horizontal = longest_label > max(len(labels), 6)

    if horizontal:
        bar = go.Bar(y=labels, x=vals, orientation="h", marker_color=colors)
        axis_layout = dict(
            xaxis=dict(title=y_label, automargin=True),
            yaxis=dict(automargin=True),
            margin=dict(l=80, r=25, b=30, t=50, pad=5),
        )
    else:
        bar = go.Bar(x=labels, y=vals, orientation="v", marker_color=colors)
        axis_layout = dict(
            xaxis=dict(tickangle=45, automargin=True),
            yaxis=dict(title=y_label, automargin=True),
            margin=dict(l=35, r=25, b=80, t=50, pad=5),
        )

    fig = go.Figure(data=[bar])
    fig.update_layout(**axis_layout)
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        hovermode="closest",
    )
    return fig
def chart_grouped_bar(
    df,
    colors=None,
    title="Zonal Summary by Feature",
    y_label=None,
    stacked=False,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
):
    """
    Draw a grouped (or stacked) bar chart for multi-feature zonal stats.

    Each x-axis group is one DataFrame row (feature, labelled by ``str`` of
    its index value). Each bar or segment in the group is one column (class).
    This fits the one-row-per-zone output of ``reduceRegions``.

    Args:
        df (pandas.DataFrame): Rows = features (index used as labels),
            columns = class names, values = numeric area/percentage.
        colors (list, optional): Hex color strings, one per column (class).
            Missing or None entries leave Plotly's default color.
        title (str): Chart title.
        y_label (str, optional): Y-axis label.
        stacked (bool): Stack bars instead of grouping. Defaults to False.
        width (int): Chart width in pixels.
        height (int): Chart height in pixels.

    Returns:
        plotly.graph_objects.Figure
    """
    group_labels = [str(idx) for idx in df.index]
    palette = colors or []

    traces = []
    for position, column in enumerate(df.columns):
        raw_color = palette[position] if position < len(palette) else None
        traces.append(
            go.Bar(
                name=str(column),
                x=group_labels,
                y=df[column].values,
                # _ensure_hex_color passes None straight through.
                marker_color=_ensure_hex_color(raw_color),
            )
        )

    fig = go.Figure(data=traces)
    fig.update_layout(
        barmode="stack" if stacked else "group",
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(title="Feature", tickangle=45, automargin=True),
        yaxis=dict(title=y_label or "", automargin=True),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        margin=dict(l=35, r=25, b=80, t=50, pad=5),
        hovermode="x unified",
    )
    return fig
def chart_sankey(
    sankey_df,
    class_names,
    class_palette,
    transition_periods,
    title="Class Transitions",
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    node_thickness=35,
    node_pad=15,
):
    """
    Create a Plotly Sankey diagram from transition data.

    One node is created per (period, class) pair, labelled
    ``"{period} {class_name}"`` and ordered period-major, so node index =
    ``class_index + period_index * len(class_names)``. Links are drawn in
    the link's source color at 40% opacity.

    Args:
        sankey_df (pandas.DataFrame): Output of :func:`prepare_sankey_data`.
        class_names (list): List of class names.
        class_palette (list): List of hex color strings (missing entries are
            drawn grey).
        transition_periods (list): The transition period list used to
            generate the data.
        title (str): Chart title.
        width (int): Chart width in pixels.
        height (int): Chart height in pixels.
        node_thickness (int): Sankey node bar thickness.
        node_pad (int): Padding between Sankey nodes.

    Returns:
        plotly.graph_objects.Figure: An annotated empty figure when
        ``sankey_df`` is empty.
    """
    if sankey_df.empty:
        fig = go.Figure()
        fig.update_layout(title=title, annotations=[dict(text="No transitions found", showarrow=False)])
        return fig

    def _translucent(color, alpha=0.4):
        """Turn a '#rgb'/'#rrggbb' color into an rgba() string; leave other values unchanged."""
        if not isinstance(color, str):
            # Missing/NaN colors fall back to grey instead of crashing.
            color = "#888888"
        digits = color.lstrip("#")
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) != 6:
            return color
        try:
            r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
        except ValueError:
            return color
        return f"rgba({r},{g},{b},{alpha})"

    # Build node labels and colors for all period slots
    labels = []
    node_colors = []
    for period in transition_periods:
        period_label = _format_period(period)
        for i, name in enumerate(class_names):
            labels.append(f"{period_label} {name}")
            node_colors.append(_ensure_hex_color(class_palette[i]) if i < len(class_palette) else "#888888")

    # Link colors: source color at reduced opacity
    if "source_color" in sankey_df.columns:
        source_colors = list(sankey_df["source_color"])
    else:
        source_colors = ["#888888"] * len(sankey_df)
    link_colors = [_translucent(c) for c in source_colors]

    fig = go.Figure(
        data=[
            go.Sankey(
                textfont=dict(size=10),
                orientation="h",
                node=dict(
                    pad=node_pad,
                    thickness=node_thickness,
                    line=dict(color="black", width=0.5),
                    label=labels,
                    color=node_colors,
                ),
                link=dict(
                    source=list(sankey_df["source"]),
                    target=list(sankey_df["target"]),
                    value=list(sankey_df["value"]),
                    color=link_colors,
                ),
            )
        ]
    )
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        font=dict(family=DEFAULT_PLOT_FONT, size=12),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        width=width,
        height=height,
        margin=dict(l=25, r=25, b=25, t=50, pad=0),
    )
    return fig
########################################################################### # Convenience function ###########################################################################
def _class_column_name(band_name, class_name, multi_band):
    """Return the zonal-stats column name for one class of a thematic band.

    With more than one band, the class name is prefixed with the band name
    and ``SPLIT_STR``. With a single band, the bare class name is used.
    """
    if multi_band:
        return f"{band_name}{SPLIT_STR}{class_name}"
    return class_name


def _class_columns_present(band_names, class_info, columns):
    """List the class columns (from ``class_info``) that exist in ``columns``.

    Order follows ``band_names``, then each band's ``class_names``.
    """
    multi_band = len(band_names) > 1
    present = []
    for bn in band_names:
        info = class_info.get(bn, {})
        for name in info.get("class_names", []):
            col_name = _class_column_name(bn, name, multi_band)
            if col_name in columns:
                present.append(col_name)
    return present


def _class_colors_for_columns(band_names, class_info, columns):
    """Map each of ``columns`` to its class palette color.

    Returns a list aligned with ``columns``. Columns without a known class
    color get ``None``. Returns ``None`` outright when ``class_info`` yields
    no colors at all.
    """
    multi_band = len(band_names) > 1
    color_lookup = {}
    for bn in band_names:
        info = class_info.get(bn, {})
        class_palette = info.get("class_palette", [])
        for i, name in enumerate(info.get("class_names", [])):
            if i < len(class_palette):
                color_lookup[_class_column_name(bn, name, multi_band)] = class_palette[i]
    if not color_lookup:
        return None
    return [color_lookup.get(col) for col in columns]


def summarize_and_chart(
    ee_obj,
    geometry,
    band_names=None,
    reducer=None,
    scale=30,
    crs=None,
    transform=None,
    tile_scale=4,
    area_format="Percentage",
    x_axis_property="system:time_start",
    date_format="YYYY",
    title=None,
    chart_type="lines+markers",
    stacked=False,
    sankey=False,
    transition_periods=None,
    sankey_band_name=None,
    min_percentage=0.2,
    palette=None,
    feature_label=None,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
):
    """
    Run zonal statistics and produce a chart in one call.

    Orchestrates :func:`zonal_stats` (or :func:`prepare_sankey_data`) and the
    appropriate chart function. Auto-picks the chart type:
    - a bar chart for a single Image,
    - a time series for an ImageCollection,
    - a Sankey diagram if ``sankey=True``.

    When ``feature_label`` is provided and the geometry is an
    ``ee.FeatureCollection`` with multiple features, the function uses
    ``reduceRegions`` to compute per-feature statistics. It then produces a
    grouped bar chart via :func:`chart_grouped_bar`, with each feature
    labeled by the given property.

    ImageCollections are first filtered to ``geometry`` with
    ``filterBounds``. Single Images are used as-is, because ``ee.Image`` has
    no ``filterBounds`` method.

    Args:
        ee_obj: ``ee.Image`` or ``ee.ImageCollection``.
        geometry: ``ee.Geometry``, ``ee.Feature``, or ``ee.FeatureCollection``.
        band_names (list, optional): Bands to include.
        reducer (ee.Reducer, optional): Override the auto-selected reducer.
        scale (int): Pixel scale in meters.
        crs (str, optional): CRS string.
        transform (list, optional): Affine transform.
        tile_scale (int): Tile scale for parallelism.
        area_format (str): Area unit for thematic data.
        x_axis_property (str): Property for x-axis labels.
        date_format (str): Date format string.
        title (str, optional): Chart title. Auto-generated if None.
        chart_type (str): ``'lines'``, ``'lines+markers'``, or ``'bar'``.
        stacked (bool): Whether to stack series. Defaults to False.
        sankey (bool): Whether to produce a Sankey diagram. Only honored for
            an ImageCollection with class info. Otherwise the regular path
            is used.
        transition_periods (list, optional): Period list for Sankey.
        sankey_band_name (str, optional): Band for Sankey analysis.
        min_percentage (float): Minimum percentage for Sankey flows.
        palette (list, optional): Hex color strings for each series/band.
            Overrides the auto-detected class palette when provided.
        feature_label (str, optional): Property name to use as row labels
            when the geometry is a multi-feature ``ee.FeatureCollection``.
            Triggers the ``reduceRegions`` path and produces a grouped bar
            chart.
        width (int): Chart width in pixels.
        height (int): Chart height in pixels.

    Returns:
        tuple: For non-Sankey charts: ``(DataFrame, Figure)``.
            For Sankey charts: ``(sankey_df, Figure, matrix_df)``, where
            ``sankey_df`` is the source-target-value table and ``matrix_df``
            is the from-class x to-class transition matrix.

    Raises:
        ValueError: If a Sankey chart is requested without
            ``transition_periods``.
    """
    # Only collections support filterBounds. Calling it on an ee.Image raised
    # AttributeError and broke the documented single-Image bar-chart path.
    if isinstance(ee_obj, ee.ImageCollection):
        ee_obj = ee_obj.filterBounds(geometry)

    obj_info = get_obj_info(ee_obj, band_names)
    class_info = obj_info["class_info"]
    obj_band_names = obj_info["band_names"]
    y_label = (
        AREA_FORMAT_DICT.get(area_format, {}).get("label", area_format)
        if obj_info["is_thematic"]
        else None
    )

    # ---- Sankey path -------------------------------------------------------
    if sankey and obj_info["obj_type"] == "ImageCollection" and class_info:
        bn = sankey_band_name or obj_band_names[0]
        if transition_periods is None:
            raise ValueError("transition_periods is required for Sankey charts")
        if title is None:
            title = f"{bn} Class Transitions"
        sankey_df, matrix_df = prepare_sankey_data(
            ee_obj,
            bn,
            transition_periods,
            class_info,
            geometry,
            scale=scale,
            crs=crs,
            transform=transform,
            tile_scale=tile_scale,
            area_format=area_format,
            min_percentage=min_percentage,
        )
        info = class_info.get(bn, {})
        fig = chart_sankey(
            sankey_df,
            class_names=info.get("class_names", []),
            class_palette=info.get("class_palette", []),
            transition_periods=transition_periods,
            title=title,
            width=width,
            height=height,
        )
        return (sankey_df, fig, matrix_df)

    zonal_kwargs = dict(
        band_names=band_names,
        reducer=reducer,
        scale=scale,
        crs=crs,
        transform=transform,
        tile_scale=tile_scale,
        area_format=area_format,
        x_axis_property=x_axis_property,
        date_format=date_format,
    )

    # ---- Multi-feature path: reduceRegions + grouped bar chart -------------
    geo_type, _ = detect_geometry_type(geometry)
    if geo_type == "multi" and feature_label:
        df = zonal_stats(ee_obj, geometry, **zonal_kwargs)

        # Use the feature label property as the row labels.
        if feature_label in df.columns:
            df = df.set_index(feature_label)

        # Prefer the known class columns. Otherwise keep numeric columns that
        # are not geometry/system properties.
        class_cols = (
            _class_columns_present(obj_band_names, class_info, df.columns)
            if class_info
            else []
        )
        if not class_cols:
            class_cols = [
                c
                for c in df.columns
                if pandas.api.types.is_numeric_dtype(df[c])
                and not str(c).startswith("geometry")
                and c not in ("system:index",)
            ]

        chart_df = df[class_cols].fillna(0)

        colors = palette
        if colors is None and class_info:
            colors = _class_colors_for_columns(obj_band_names, class_info, chart_df.columns)

        if title is None:
            title = "Zonal Summary by Feature"
        fig = chart_grouped_bar(
            chart_df,
            colors=colors,
            title=title,
            y_label=y_label,
            stacked=stacked,
            width=width,
            height=height,
        )
        return (chart_df, fig)

    # ---- Standard single-region zonal stats path ---------------------------
    df = zonal_stats(ee_obj, geometry, **zonal_kwargs)

    # Colors follow the actual DataFrame column order, so multi-band thematic
    # charts (e.g. Change + Land_Cover + Land_Use) get the right color per
    # class. A caller-supplied palette wins.
    colors = palette
    if colors is None and class_info:
        colors = _class_colors_for_columns(obj_band_names, class_info, df.columns)

    if obj_info["obj_type"] == "ImageCollection":
        if title is None:
            title = "Zonal Summary"
        fig = chart_time_series(
            df,
            colors=colors,
            chart_type=chart_type,
            title=title,
            x_label=x_axis_property.replace("_", " ").title() if x_axis_property != "year" else "Year",
            y_label=y_label,
            stacked=stacked,
            width=width,
            height=height,
        )
    else:
        if title is None:
            title = "Class Distribution"
        fig = chart_bar(
            df,
            colors=colors,
            title=title,
            y_label=y_label,
            width=width,
            height=height,
        )

    return (df, fig)