# Source code for geeViz.outputLib.charts

"""
Zonal Summary & Charting Library for GEE

geeViz.outputLib.charts provides a Python pipeline for running zonal statistics on
ee.Image / ee.ImageCollection objects and producing Plotly charts (time series,
bar, sankey). It mirrors the logic in the geeView JS frontend so that both
human users and AI agents have a clean, efficient API for this common workflow.

Quick start:

>>> import geeViz.geeView as gv
>>> from geeViz.outputLib import charts as cl
>>> ee = gv.ee
>>> study_area = ee.Geometry.Polygon(
...     [[[-106, 39.5], [-105, 39.5], [-105, 40.5], [-106, 40.5]]]
... )
>>> lcms = ee.ImageCollection("USFS/GTAC/LCMS/v2024-10")
>>> df, fig = cl.summarize_and_chart(
...     lcms.select(['Land_Cover']),
...     study_area,
...     stacked=True,
... )
>>> print(df.to_markdown())
>>> fig.write_html("chart.html", include_plotlyjs="cdn")

See :func:`summarize_and_chart` for the full API and more examples.
"""

"""
   Copyright 2026 Ian Housman

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
"""

# --------------------------------------------------------------------------
#  Zonal summary + charting pipeline (ported from area-charting.js)
# --------------------------------------------------------------------------

import math
import ee
import pandas
import plotly.graph_objects as go
from geeViz.gee2Pandas import robust_featureCollection_to_df


###########################################################################
#                              Constants
###########################################################################

#: Separator used to build composite band names like ``"2020----Land_Cover"``
#: when an ImageCollection is stacked into a single multi-band image.
SPLIT_STR = "----"
#: Separator embedded in sankey transition codes.
#: NOTE(review): the exact encoding is consumed by sankey helpers elsewhere
#: in this module — confirm before changing.
SANKEY_TRANSITION_SEP = "0990"

#: Default Plotly styling — transparent backgrounds so a theme can supply
#: the final colors, plus default font and chart dimensions in pixels.
DEFAULT_PLOT_BGCOLOR = "rgba(0,0,0,0)"
DEFAULT_PLOT_FONT = "Roboto"
DEFAULT_CHART_WIDTH = 800
DEFAULT_CHART_HEIGHT = 600

#: Valid chart type strings for ``chart_type`` / ``chart_types`` parameters.
#: Use these in ``summarize_and_chart(chart_type=...)`` or
#: ``Report.add_section(chart_types=[...])``.
CHART_TYPES = [
    "bar",
    "stacked_bar",
    "line",
    "line+markers",
    "stacked_line",
    "stacked_line+markers",
    "donut",
    "scatter",
    "sankey",
]


def _legend_kwargs(legend_position):
    """Return Plotly legend layout dict.

    Args:
        legend_position: Either a dict of raw Plotly legend properties
            (e.g. ``{"orientation": "h", "x": 0.5, "y": -0.1}``),
            or ``None`` for Plotly defaults.

    Returns:
        dict: Plotly ``legend`` layout dict.
    """
    if legend_position is None or legend_position == "right":
        return {}
    if isinstance(legend_position, dict):
        return dict(legend_position)
    return {}


#: Output-area formatting presets keyed by unit name.
#:
#: - ``mult``: pixel-count -> unit multiplier at the reference 30 m scale
#:   (a 30 m x 30 m pixel is 900 m^2 = 0.09 ha ~= 0.222395 acres);
#:   ``None`` for "Percentage", which is computed from the histogram total
#:   instead of a multiplier.
#: - ``label``: axis/legend label text.
#: - ``places``: decimal places used when rounding output values.
#: - ``scale``: reference pixel scale in meters that ``mult`` is defined
#:   against (rescaled as ``(scale / 30) ** 2`` for other resolutions).
AREA_FORMAT_DICT = {
    "Percentage": {"mult": None, "label": "% Area", "places": 2, "scale": 30},
    "Hectares": {"mult": 0.09, "label": "ha", "places": 0, "scale": 30},
    "Acres": {"mult": 0.222395, "label": "Acres", "places": 0, "scale": 30},
    "Pixels": {"mult": 1.0, "label": "Pixels", "places": 0, "scale": 30},
}


###########################################################################
#                          Private helpers
###########################################################################

# Unified chart_type values accepted by the x/y chart renderers.
# NOTE(review): _parse_chart_type itself also tolerates legacy aliases
# ("lines", "lines+markers") and unknown values; this set looks like it is
# meant for validation elsewhere in the module — confirm usage.
_VALID_CHART_TYPES = {
    "bar", "stacked_bar",
    "line", "stacked_line",
    "line+markers", "stacked_line+markers",
}


def _parse_chart_type(chart_type):
    """Parse a unified *chart_type* string into ``(plotly_mode, is_stacked)``.

    Args:
        chart_type (str): One of ``"bar"``, ``"stacked_bar"``, ``"line"``,
            ``"stacked_line"``, ``"line+markers"``,
            ``"stacked_line+markers"``.

    Returns:
        tuple: ``(plotly_mode, is_stacked)`` where *plotly_mode* is
        ``"bar"``, ``"lines"``, or ``"lines+markers"`` and *is_stacked* is
        a bool.
    """
    ct = str(chart_type).lower().strip()

    # Backward compat: old values without stacked_ prefix
    # "lines" -> "line", "lines+markers" -> "line+markers"
    if ct == "lines":
        ct = "line"
    elif ct == "lines+markers":
        ct = "line+markers"

    is_stacked = ct.startswith("stacked_")
    base = ct.removeprefix("stacked_")

    mode_map = {"bar": "bar", "line": "lines", "line+markers": "lines+markers"}
    plotly_mode = mode_map.get(base, "lines+markers")
    return plotly_mode, is_stacked


def _ensure_hex_color(color):
    """Prepend '#' if missing from a hex color string."""
    if color is None:
        return None
    color = str(color)
    if not color.startswith("#"):
        color = "#" + color
    return color


def _title_to_filename(title):
    """Convert a chart title to a safe filename (no extension)."""
    import re
    if not title:
        return "chart"
    return re.sub(r'[^a-zA-Z0-9_-]', '_', title).strip('_')[:80] or "chart"


def _plotly_download_config(fig):
    """Build a Plotly ``config`` dict whose PNG-download filename is derived
    from the figure's title (empty title falls back to ``"chart"``)."""
    layout_title = fig.layout.title
    text = layout_title.text if layout_title and layout_title.text else ""
    return {"toImageButtonOptions": {"filename": _title_to_filename(text)}}


def _set_download_filename(fig):
    """Patch a Plotly figure so ``fig.show()`` and ``fig.to_html()`` use the title as download filename.

    Wraps both methods to inject ``config={'toImageButtonOptions': {'filename': ...}}``
    automatically, so users don't need to pass config manually. A caller-supplied
    ``config`` is respected: only the missing ``toImageButtonOptions`` key is
    added, and the caller's dict is copied, never mutated.

    Args:
        fig: ``plotly.graph_objects.Figure`` to patch (modified in place).

    Returns:
        The same figure, with ``show`` and ``to_html`` wrapped.
    """

    def _with_download_config(kwargs):
        # Shared injection logic for both wrapped methods. Copy the caller's
        # config dict so we never mutate their object, then fill in the
        # title-derived download filename only if they didn't set one.
        cfg = dict(kwargs.get("config") or {})
        if "toImageButtonOptions" not in cfg:
            cfg["toImageButtonOptions"] = _plotly_download_config(fig)["toImageButtonOptions"]
        kwargs["config"] = cfg
        return kwargs

    _orig_show = fig.show
    _orig_to_html = fig.to_html

    def _patched_show(*args, **kwargs):
        return _orig_show(*args, **_with_download_config(kwargs))

    def _patched_to_html(*args, **kwargs):
        return _orig_to_html(*args, **_with_download_config(kwargs))

    fig.show = _patched_show
    fig.to_html = _patched_to_html
    return fig


def _thin_tick_vals(tick_vals, max_ticks=10):
    """Return a subset of *tick_vals* so that at most *max_ticks* are shown.

    Chooses a stride of 1, 2, 5, 10, 20, 50, … (the smallest that keeps
    the count at or below *max_ticks*), always including the first and last
    values.  Returns ``None`` when no thinning is needed.
    """
    if max_ticks is None or max_ticks <= 0 or len(tick_vals) <= max_ticks:
        return None  # no thinning needed
    n = len(tick_vals)
    # Generate nice strides: 1, 2, 5, 10, 20, 50, 100, 200, 500, …
    magnitude = 1
    while magnitude < n:
        for base in [1, 2, 5]:
            stride = base * magnitude
            # Ticks: always include first & last, plus every stride-th index
            kept = [tick_vals[0]] + [tick_vals[i] for i in range(stride, n - 1, stride)] + [tick_vals[-1]]
            if len(kept) <= max_ticks:
                return kept
        magnitude *= 10
    # Fallback: just first and last
    return [tick_vals[0], tick_vals[-1]]


def _interpolate_palette(palette, n):
    """Interpolate a color palette to *n* colors (continuous ramp).

    Given a list of hex color stops, linearly interpolate between them to
    produce exactly *n* evenly-spaced colors.  Matches the JS min/max/palette
    ramp behaviour for ordinal-thematic bar charts.
    """
    if not palette or n <= 0:
        return []
    palette = [_ensure_hex_color(c) for c in palette]
    if n == 1:
        return [palette[0]]
    if len(palette) >= n:
        # Down-sample evenly
        return [palette[round(i * (len(palette) - 1) / (n - 1))] for i in range(n)]

    out = []
    for i in range(n):
        t = i / (n - 1)  # 0 … 1
        pos = t * (len(palette) - 1)
        lo = int(math.floor(pos))
        hi = min(lo + 1, len(palette) - 1)
        frac = pos - lo
        c_lo = palette[lo].lstrip("#")
        c_hi = palette[hi].lstrip("#")
        # Expand 3-char hex to 6-char
        if len(c_lo) == 3:
            c_lo = "".join(ch * 2 for ch in c_lo)
        if len(c_hi) == 3:
            c_hi = "".join(ch * 2 for ch in c_hi)
        r = int(int(c_lo[0:2], 16) * (1 - frac) + int(c_hi[0:2], 16) * frac)
        g = int(int(c_lo[2:4], 16) * (1 - frac) + int(c_hi[2:4], 16) * frac)
        b = int(int(c_lo[4:6], 16) * (1 - frac) + int(c_hi[4:6], 16) * frac)
        out.append(f"#{r:02x}{g:02x}{b:02x}")
    return out


def _format_period(period):
    """Format a transition period list like [1985,1987] -> '1985-1987' or '1985' if equal."""
    if isinstance(period, (list, tuple)) and len(period) == 2:
        if period[0] == period[1]:
            return str(period[0])
        return f"{period[0]}-{period[1]}"
    return str(period)


from geeViz.outputLib import themes as _themes
from geeViz.outputLib._templates import (
    render_chart_style as _render_chart_style,
    render_d3_sankey as _render_d3_sankey,
)


#: Pinned plotly.js build served from a CDN; used as the default
#: ``include_plotlyjs`` value when saving charts as standalone HTML.
#: NOTE(review): v1.33.1 is an old plotly.js release — confirm the pin is
#: intentional before upgrading.
_PLOTLY_CDN_URL = "https://cdnjs.cloudflare.com/ajax/libs/plotly.js/1.33.1/plotly.min.js"


def save_chart_html(fig, filename, include_plotlyjs=_PLOTLY_CDN_URL, sankey=False, theme="dark", bg_color=None, font_color=None, **kwargs):
    """Save a chart to an HTML file.

    Accepts either a Plotly ``Figure`` or an HTML string (from
    ``summarize_and_chart(chart_type='sankey')``). Applies a theme so all
    chart types have a consistent look. Works both inside and outside the
    MCP sandbox.

    Args:
        fig: ``plotly.graph_objects.Figure`` or ``str`` (D3 sankey HTML from
            ``summarize_and_chart(chart_type='sankey')``).
        filename (str): Output filename (e.g. ``"chart.html"``). In the MCP
            sandbox, files are saved to ``generated_outputs/``.
        include_plotlyjs: How to include Plotly.js. Defaults to a pinned
            plotly.js CDN URL (``_PLOTLY_CDN_URL``).
        sankey (bool): Deprecated — ignored. Sankey charts are now returned
            as HTML strings and detected automatically.
        theme: Theme preset name, :class:`~geeViz.outputLib.themes.Theme`
            instance, or color string. Default ``"dark"``.
        bg_color: Background color override.
        font_color: Font/text color override.

    Returns:
        str: Path to the saved file.

    Examples:
        >>> path = cl.save_chart_html(fig, "ndvi_trend.html")
        >>> path = cl.save_chart_html(sankey_html, "sankey.html")
        >>> path = cl.save_chart_html(fig, "chart.html", theme="light")
    """
    _t = _themes.get_theme(theme, bg_color=bg_color, font_color=font_color)
    # If fig is already an HTML string (from chart_sankey_d3), save directly.
    if isinstance(fig, str):
        html = fig
    else:
        # Apply theme to a deep copy so we don't mutate the caller's figure.
        import copy
        themed_fig = copy.deepcopy(fig)
        _themes.apply_plotly_theme(themed_fig, _t)
        html = themed_fig.to_html(
            full_html=True,
            include_plotlyjs=include_plotlyjs,
            config=_plotly_download_config(themed_fig),
        )
    # Inject body background style into the <head> (or <body> as fallback).
    _chart_style = _render_chart_style(_t)
    if "</head>" in html:
        html = html.replace("</head>", _chart_style + "</head>")
    elif "<body>" in html:
        html = html.replace("<body>", "<body>" + _chart_style)
    # Try the MCP sandbox's save_file first, fall back to a direct write.
    import builtins as _builtins
    _save_fn = _builtins.__dict__.get("save_file") if hasattr(_builtins, "__dict__") else None
    if _save_fn is None:
        # Check if save_file is in the caller's globals (MCP REPL injects it).
        import inspect
        frame = inspect.currentframe()
        try:
            caller_globals = frame.f_back.f_globals if frame.f_back else {}
            _save_fn = caller_globals.get("save_file")
        finally:
            # Break the frame reference cycle (see inspect docs).
            del frame
    if _save_fn is not None:
        return _save_fn(filename, html)
    else:
        # Outside MCP sandbox — direct file write.
        with open(filename, "w", encoding="utf-8") as f:
            f.write(html)
        return filename
def sankey_to_html(fig, full_html=True, include_plotlyjs=_PLOTLY_CDN_URL, renderer="d3", theme="dark", bg_color=None, font_color=None, hide_toolbar=False):
    """Return sankey HTML, accepting either a raw HTML string or legacy Plotly figure.

    Sankey charts from ``summarize_and_chart(chart_type='sankey')`` are now
    returned as D3 HTML strings directly. This function is kept for backward
    compatibility — it passes HTML strings through unchanged.

    Args:
        fig: D3 HTML string (preferred) or legacy Plotly ``Figure``.
        full_html (bool): Ignored for HTML strings.
        include_plotlyjs: Ignored for HTML strings.
        renderer (str): Ignored (always D3).
        theme: Theme preset for legacy Plotly figures.
        bg_color: Background color override.
        font_color: Font/text color override.
        hide_toolbar (bool): Hide the download button.

    Returns:
        str: HTML string.
    """
    # New-style sankeys are already rendered HTML — pass through untouched.
    if isinstance(fig, str):
        return fig
    # Legacy path: Plotly figure carrying a _gradient_color_map attribute.
    _t = _themes.get_theme(theme, bg_color=bg_color, font_color=font_color)
    return _sankey_plotly_fig_to_d3(fig, theme=_t, hide_toolbar=hide_toolbar)
def _sankey_plotly_fig_to_d3(fig, theme=None, hide_toolbar=False):
    """D3.js / d3-sankey based Sankey HTML with native SVG gradients.

    Converts a legacy Plotly sankey ``Figure`` (node/link trace plus the
    ``_gradient_color_map`` attribute set by the chart builder) into a
    standalone D3 HTML document rendered from the ``_templates`` module.

    Args:
        fig: Plotly ``Figure`` whose first trace is a sankey trace.
        theme: A resolved Theme instance, or ``None`` for the "dark" preset.
        hide_toolbar (bool): Hide the toolbar div (download button).

    Returns:
        str: Self-contained HTML string.
    """
    import json as _json

    _t = theme if theme is not None else _themes.get_theme("dark")
    # Extract data from the Plotly figure.
    trace = fig.data[0]
    node_labels = list(trace.node.label)
    node_colors_raw = list(trace.node.color)
    link_sources = list(trace.link.source)
    link_targets = list(trace.link.target)
    link_values = list(trace.link.value)
    # Gradient color map: rgb key -> [source_hex, target_hex] pair.
    gradient_map = getattr(fig, "_gradient_color_map", {})
    opacity = getattr(fig, "_gradient_link_opacity", 0.9)
    # Build link colors from gradient map or fall back to node colors.
    link_colors = []
    link_colors_raw = list(trace.link.color) if trace.link.color else []
    for i in range(len(link_sources)):
        src_idx = link_sources[i]
        tgt_idx = link_targets[i]
        # Try to find source/target hex from the gradient map via the
        # link's rgba color string.
        if i < len(link_colors_raw):
            raw = link_colors_raw[i]
            import re
            m = re.match(r'rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)', raw)
            if m:
                key = f"{m.group(1)},{m.group(2)},{m.group(3)}"
                if key in gradient_map:
                    link_colors.append(gradient_map[key])
                    continue
        # Fallback: use node colors ("#888" when the index is out of range).
        sc = node_colors_raw[src_idx] if src_idx < len(node_colors_raw) else "#888"
        tc = node_colors_raw[tgt_idx] if tgt_idx < len(node_colors_raw) else "#888"
        link_colors.append([sc, tc])
    # Resolve node colors to hex (they may be rgba strings).
    node_colors_hex = []
    for c in node_colors_raw:
        if c.startswith("rgba"):
            import re
            m = re.match(r'rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)', c)
            if m:
                node_colors_hex.append(
                    f"#{int(m.group(1)):02x}{int(m.group(2)):02x}{int(m.group(3)):02x}"
                )
                continue
        node_colors_hex.append(c)
    # Extract layout info (falling back to module defaults when unset).
    layout = fig.layout
    title = layout.title.text if layout.title and layout.title.text else ""
    width = layout.width or 800
    height = layout.height or 600
    node_thickness = trace.node.thickness or 20
    node_pad = trace.node.pad or 15
    # Filter out nodes that have no links (0-value orphans clutter the chart).
    used_indices = set()
    for i in range(len(link_sources)):
        if link_values[i] > 0:
            used_indices.add(link_sources[i])
            used_indices.add(link_targets[i])
    # Build old -> new index mapping for used nodes only.
    old_to_new = {}
    new_idx = 0
    for old_idx in range(len(node_labels)):
        if old_idx in used_indices:
            old_to_new[old_idx] = new_idx
            new_idx += 1
    # Build JSON data for the D3 template.
    d3_data = {
        "nodes": [
            {"name": node_labels[i], "color": node_colors_hex[i]}
            for i in range(len(node_labels))
            if i in used_indices
        ],
        "links": [
            {
                "source": old_to_new[link_sources[i]],
                "target": old_to_new[link_targets[i]],
                "value": link_values[i],
                "sourceColor": link_colors[i][0],
                "targetColor": link_colors[i][1],
            }
            for i in range(len(link_sources))
            if link_values[i] > 0 and link_sources[i] in old_to_new and link_targets[i] in old_to_new
        ],
    }
    d3_config = {
        "title": title,
        "width": width,
        "height": height,
        "nodeWidth": node_thickness,
        "nodePadding": node_pad,
        "opacity": opacity,
        "bgColor": _t.bg_hex,
        "textColor": _t.text_hex,
    }
    # render_d3_sankey fills CSS color placeholders; we still need data/config.
    html = _render_d3_sankey(_t)
    result = html.replace(
        "__D3_DATA_JSON__", _json.dumps(d3_data)
    ).replace(
        "__D3_CONFIG_JSON__", _json.dumps(d3_config)
    )
    if hide_toolbar:
        result = result.replace(
            '<div id="toolbar">',
            '<div id="toolbar" style="display:none">',
        )
    return result


def _expand_thematic_reduce_regions(df, band_names, class_info, area_format, scale, split_str):
    """Expand histogram dict columns from reduceRegions into class-name columns.

    Args:
        df (pandas.DataFrame): One row per feature; thematic band columns hold
            ``{str_class_value: pixel_count}`` histogram dicts.
        band_names (list): Band (column) names to expand.
        class_info (dict): ``{band_name: {class_values, class_names, ...}}``
            as produced by :func:`get_obj_info`.
        area_format (str): Key into ``AREA_FORMAT_DICT``
            ("Percentage", "Hectares", "Acres", or "Pixels").
        scale (float): Pixel scale in meters used for the reduction.
        split_str (str): Band-name separator (see ``SPLIT_STR``).

    Returns:
        pandas.DataFrame: One row per input row with one column per class.
    """
    # Area multipliers are defined at 30 m; rescale quadratically for other
    # pixel sizes.
    scale_mult = (scale / AREA_FORMAT_DICT["Hectares"]["scale"]) ** 2
    out_rows = []
    for _, row in df.iterrows():
        out_row = {}
        # Preserve any non-histogram columns (e.g. label/id).
        for col in df.columns:
            if col not in band_names:
                out_row[col] = row[col]
        for bn in band_names:
            histogram = row.get(bn)
            if histogram is None or not isinstance(histogram, dict):
                continue
            # For stacked band names like "2020----Land_Cover", look up class_info
            # by the original band name (the part after the SPLIT_STR prefix).
            original_bn = bn.split(split_str, 1)[-1] if split_str in bn else bn
            info = class_info.get(original_bn, {})
            class_values = info.get("class_values", [])
            class_names = info.get("class_names", [])
            value_to_name = dict(zip([str(v) for v in class_values], class_names))
            # `or 1` guards against division by zero on empty histograms.
            pixel_total = sum(histogram.values()) or 1
            for str_val, count in histogram.items():
                name = value_to_name.get(str_val, str_val)
                # Only prefix with the band name when multiple bands could collide.
                col_name = f"{bn}{split_str}{name}" if len(band_names) > 1 else name
                if area_format == "Percentage":
                    out_row[col_name] = round((count / pixel_total) * 100, 2)
                elif area_format == "Pixels":
                    out_row[col_name] = count
                else:
                    mult = AREA_FORMAT_DICT[area_format]["mult"] * scale_mult
                    out_row[col_name] = round(count * mult, AREA_FORMAT_DICT[area_format]["places"])
        out_rows.append(out_row)
    return pandas.DataFrame(out_rows)


def _pivot_multi_feature_timeseries(df, x_axis_labels, obj_info, feature_label, split_str=SPLIT_STR):
    """Pivot a multi-feature reduceRegions DataFrame (with stacked band columns)
    into per-feature time-series DataFrames.

    After ``reduceRegions`` on a stacked image, the DataFrame has one row per
    feature and columns like ``2020----Land_Cover----Trees`` (thematic) or
    ``2020----NDVI`` (continuous). This function restructures each row into a
    time-series DataFrame where the index is the time label and the columns
    are the class names (thematic) or band names (continuous).

    Args:
        df (pandas.DataFrame): Output of ``zonal_stats`` for multi-feature +
            ImageCollection. Must already have ``feature_label`` as a column
            or index.
        x_axis_labels (list): Time-step labels (e.g. ``['2020', '2021', ...]``).
        obj_info (dict): Output of :func:`get_obj_info`.
        feature_label (str): Column/index name identifying each feature.
        split_str (str, optional): Band name separator.

    Returns:
        dict: ``{feature_name: pandas.DataFrame}`` where each DataFrame has
        index = time labels, columns = class/band names.
    """
    band_names = obj_info["band_names"]
    class_info = obj_info["class_info"]
    is_thematic = obj_info["is_thematic"]
    # Ensure feature_label is a column (not the index).
    if feature_label not in df.columns and df.index.name == feature_label:
        df = df.reset_index()
    result = {}
    for _, row in df.iterrows():
        # Fall back to the positional row label when the column is missing.
        feat_name = str(row.get(feature_label, row.name))
        rows_out = []
        for x_label in x_axis_labels:
            ts_row = {"x": x_label}
            if is_thematic:
                # Columns are like "2020----Land_Cover----Trees".
                for bn in band_names:
                    info = class_info.get(bn, {})
                    class_names = info.get("class_names", [])
                    for cn in class_names:
                        # Both single- and multi-band stacks include the
                        # x_label prefix, so the lookup key is uniform.
                        col = f"{x_label}{split_str}{bn}{split_str}{cn}"
                        if col in row.index:
                            ts_row[cn] = row[col]
                        else:
                            # Missing class at this time step counts as 0 area.
                            ts_row[cn] = 0
            else:
                # Columns are like "2020----NDVI".
                for bn in band_names:
                    col = f"{x_label}{split_str}{bn}"
                    if col in row.index:
                        ts_row[bn] = row[col]
                    else:
                        # Missing continuous value stays None (gap in series).
                        ts_row[bn] = None
            rows_out.append(ts_row)
        feat_df = pandas.DataFrame(rows_out).set_index("x")
        feat_df.index.name = None
        result[feat_name] = feat_df
    return result
def chart_multi_feature_timeseries(
    per_feature_dfs,
    colors=None,
    chart_type="line+markers",
    title="Time Series by Feature",
    x_label="Year",
    y_label=None,
    width=DEFAULT_CHART_WIDTH,
    height=None,
    columns=2,
    legend_position="bottom",
    line_width=2,
    marker_size=5,
    max_x_tick_labels=10,
    max_y_tick_labels=None,
):
    """Create a subplot figure with one time-series chart per feature.

    Features are arranged in a grid with *columns* columns (default 2).
    Each subplot gets ``height`` pixels tall (total height scales with
    number of rows). The legend defaults to ``"bottom"``.

    Args:
        per_feature_dfs (dict): ``{feature_name: DataFrame}`` from
            :func:`_pivot_multi_feature_timeseries`.
        colors (list, optional): Hex color strings for each column.
        chart_type (str, optional): ``"line+markers"`` (default), ``"line"``,
            ``"bar"``, ``"stacked_line"``, ``"stacked_line+markers"``, or
            ``"stacked_bar"``.
        title (str, optional): Overall chart title.
        x_label (str, optional): X-axis label.
        y_label (str, optional): Y-axis label.
        width (int, optional): Per-subplot width in pixels.
        height (int, optional): Per-subplot height. When ``None`` each
            subplot gets 400 px.
        columns (int, optional): Number of grid columns. Defaults to ``2``.
        legend_position (dict or str, optional): Legend layout. Default
            ``"bottom"``.
        line_width (int or float, optional): Line width in pixels. Defaults
            to ``2``.
        marker_size (int or float, optional): Marker diameter in pixels.
            Defaults to ``5``.
        max_x_tick_labels (int, optional): Maximum number of x-axis tick
            labels per subplot. Labels are thinned to every 2nd, 5th, 10th,
            etc. value when exceeded. Defaults to ``10``. Set to ``None`` or
            ``0`` to disable.
        max_y_tick_labels (int, optional): Maximum number of y-axis tick
            labels per subplot. Uses Plotly's ``nticks``. Defaults to
            ``None`` (automatic).

    Returns:
        plotly.graph_objects.Figure
    """
    from plotly.subplots import make_subplots

    plotly_mode, is_stacked = _parse_chart_type(chart_type)
    n = len(per_feature_dfs)
    if n == 0:
        return go.Figure()
    n_cols = min(columns, n)
    n_rows = -(-n // n_cols)  # ceil division
    # height/width are per-cell; scale to full grid.
    cell_h = height if height is not None else 400
    cell_w = width if width is not None else DEFAULT_CHART_WIDTH
    height = n_rows * cell_h
    width = n_cols * cell_w
    feature_names = list(per_feature_dfs.keys())
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=feature_names,
        shared_xaxes=False,
        vertical_spacing=max(0.02, 0.10 / max(n_rows, 1)),
        horizontal_spacing=max(0.03, 0.08 / max(n_cols, 1)),
    )
    # Track which legend entries we've already added so we only show each once.
    legend_added = set()
    for feat_idx, (feat_name, feat_df) in enumerate(per_feature_dfs.items()):
        # Grid placement (1-based row/col indices for plotly).
        row_idx = feat_idx // n_cols + 1
        col_idx_grid = feat_idx % n_cols + 1
        x_values = list(feat_df.index)
        # Use numeric x values when the labels are all integers (e.g. years).
        try:
            x_values = [int(v) for v in x_values]
        except (ValueError, TypeError):
            pass
        for col_idx, col in enumerate(feat_df.columns):
            color = None
            if colors and col_idx < len(colors):
                color = _ensure_hex_color(colors[col_idx])
            show_legend = col not in legend_added
            legend_added.add(col)
            if plotly_mode == "bar":
                fig.add_trace(
                    go.Bar(
                        x=x_values,
                        y=feat_df[col].values,
                        name=col,
                        marker_color=color,
                        showlegend=show_legend,
                        legendgroup=col,
                    ),
                    row=row_idx,
                    col=col_idx_grid,
                )
            else:
                fig.add_trace(
                    go.Scatter(
                        x=x_values,
                        y=feat_df[col].values,
                        mode=plotly_mode,
                        name=col,
                        line=dict(color=color, width=line_width),
                        marker=dict(color=color, size=marker_size),
                        stackgroup="one" if is_stacked else None,
                        showlegend=show_legend,
                        legendgroup=col,
                    ),
                    row=row_idx,
                    col=col_idx_grid,
                )
    bar_mode = "stack" if is_stacked and plotly_mode == "bar" else ("group" if plotly_mode == "bar" else None)
    # Legend: "bottom" -> horizontal below chart.
    legend_kw = _legend_kwargs(legend_position)
    if legend_position == "bottom":
        legend_kw = {"orientation": "h", "yanchor": "top", "y": -0.05, "xanchor": "center", "x": 0.5}
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        legend=legend_kw,
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        barmode=bar_mode,
        hovermode="x unified",
    )
    # Fix x-axis ticks: only show actual data values, no interpolated ticks.
    sample_idx = list(per_feature_dfs.values())[0].index if per_feature_dfs else []
    is_int_axis = all(str(v).lstrip("-").isdigit() for v in sample_idx)
    if is_int_axis:
        all_tick_vals = sorted(set(int(v) for v in sample_idx))
        tick_vals = all_tick_vals
    else:
        all_tick_vals = None
        tick_vals = None
    # Thin x-axis ticks if too many.
    if tick_vals is not None:
        thinned = _thin_tick_vals(tick_vals, max_x_tick_labels)
        if thinned is not None:
            tick_vals = thinned
    for r in range(1, n_rows + 1):
        for c in range(1, n_cols + 1):
            kw = {}
            if tick_vals is not None:
                kw["tickvals"] = tick_vals
                kw["tickformat"] = "d"
                # Constrain range to full data extent to eliminate dead space.
                kw["range"] = [min(all_tick_vals) - 0.5, max(all_tick_vals) + 0.5]
            if r == n_rows:
                # Only the bottom row carries the x-axis title.
                kw["title_text"] = x_label
            fig.update_xaxes(row=r, col=c, **kw)
    # Label left column y-axes; add '%' suffix for percentage labels.
    y_kw = {}
    if y_label and "%" in y_label:
        y_kw["ticksuffix"] = "%"
    if max_y_tick_labels is not None and max_y_tick_labels > 0:
        y_kw["nticks"] = max_y_tick_labels
    for r in range(1, n_rows + 1):
        if y_label:
            fig.update_yaxes(title_text=y_label, row=r, col=1, **y_kw)
        elif y_kw:
            fig.update_yaxes(row=r, col=1, **y_kw)
    _themes.apply_plotly_theme(fig, "dark")
    return fig
###########################################################################
#                        Data pipeline functions
###########################################################################
def get_obj_info(ee_obj, band_names=None):
    """
    Detect the type of a GEE object and read its thematic class metadata.

    Args:
        ee_obj (ee.Image or ee.ImageCollection): The GEE object to inspect.
        band_names (list, optional): Override the band names to use.

    Returns:
        dict: Keys ``obj_type``, ``band_names``, ``is_thematic``,
        ``class_info``, ``size``. ``class_info`` is
        ``{band_name: {class_values, class_names, class_palette}}``

    Examples:
        Inspect LCMS to see its thematic class metadata:

        >>> import geeViz.geeView as gv
        >>> from geeViz.outputLib import charts as cl
        >>> ee = gv.ee
        >>> lcms = ee.ImageCollection("USFS/GTAC/LCMS/v2024-10")
        >>> info = cl.get_obj_info(lcms.select(['Land_Cover']))
        >>> print(info['is_thematic'])
        True
        >>> print(info['class_info']['Land_Cover']['class_names'])
        ['Trees', 'Tall Shrubs & Trees Mix (AK Only)', ...]
    """
    # Distinguish Image vs ImageCollection by the Python wrapper class name.
    obj_type = type(ee_obj).__name__
    if obj_type == "ImageCollection":
        first_img = ee.Image(ee_obj.first())
        size = ee_obj.size().getInfo()
    else:
        first_img = ee.Image(ee_obj)
        size = 1
    if band_names is None:
        band_names = first_img.bandNames().getInfo()
    # Read class metadata from image properties (<band>_class_values etc.,
    # the standard GEE thematic metadata convention).
    props = first_img.toDictionary().getInfo()
    class_info = {}
    is_thematic = False
    for bn in band_names:
        values_key = f"{bn}_class_values"
        names_key = f"{bn}_class_names"
        palette_key = f"{bn}_class_palette"
        if values_key in props and names_key in props:
            # Any band with class metadata marks the whole object thematic.
            is_thematic = True
            class_info[bn] = {
                "class_values": props[values_key],
                "class_names": props[names_key],
                "class_palette": props.get(palette_key, []),
            }
    return {
        "obj_type": obj_type,
        "band_names": band_names,
        "is_thematic": is_thematic,
        "class_info": class_info,
        "size": size,
    }
def _detect_feature_label(fc): """Auto-detect a suitable label property from an ee.FeatureCollection. Looks for a property containing 'name' (case-insensitive), excluding system properties. Falls back to ``'system:index'``. Args: fc: ee.FeatureCollection. Returns: str: Property name to use as feature label. """ try: props = fc.first().propertyNames().getInfo() # Filter out system/geometry properties candidates = [p for p in props if not p.startswith("system:") and p != "geo"] # Prefer properties with "name" in them (case-insensitive) name_props = [p for p in candidates if "name" in p.lower()] if name_props: return name_props[0] except Exception: pass return "system:index"
def detect_geometry_type(geometry):
    """
    Determine whether the input geometry represents a single region or multiple.

    Args:
        geometry: An ``ee.Geometry``, ``ee.Feature``, or ``ee.FeatureCollection``.

    Returns:
        tuple: ``(geo_type, geometry)`` where geo_type is ``'single'`` or
        ``'multi'``, and geometry is an ``ee.Geometry`` (single) or
        ``ee.FeatureCollection`` (multi).
    """
    if isinstance(geometry, ee.Geometry):
        return ("single", geometry)
    if isinstance(geometry, ee.Feature):
        return ("single", geometry.geometry())
    if isinstance(geometry, ee.FeatureCollection):
        # An empty or one-feature collection collapses to a single region.
        if geometry.size().getInfo() <= 1:
            return ("single", geometry.geometry())
        return ("multi", geometry)
    # Fallback: ee.Element (e.g. from fc.first()) or another ComputedObject.
    # Wrapping in ee.Feature extracts the geometry safely; as a last resort
    # interpret the object as a raw geometry.
    try:
        return ("single", ee.Feature(geometry).geometry())
    except Exception:
        return ("single", ee.Geometry(geometry))
def prepare_for_reduction(ee_obj, obj_info, x_axis_property="system:time_start", date_format="YYYY"):
    """
    Prepare a GEE object for reduction by stacking an ImageCollection into a
    single multi-band image.

    For ImageCollections, each image's bands are renamed to
    ``"{x_label}----{band}"`` (using ``SPLIT_STR``) before ``toBands()`` so the
    reduction output can later be mapped back to (x label, band) pairs.

    Args:
        ee_obj: ``ee.Image`` or ``ee.ImageCollection``.
        obj_info (dict): Output of :func:`get_obj_info`.
        x_axis_property (str, optional): Property name to use for x-axis labels.
        date_format (str, optional): Earth Engine date format string (e.g. ``'YYYY'``).

    Returns:
        tuple: ``(stacked_image, stack_band_names, x_axis_labels)`` —
        ``x_axis_labels`` is empty for a single Image.
    """
    band_names = obj_info["band_names"]
    if obj_info["obj_type"] == "ImageCollection":
        ic = ee_obj
        # Tag images with x_axis_property if it's a date-derived field.
        if x_axis_property in ("year", "date", "system:time_start"):
            ic = ic.map(lambda img: img.set("year", img.date().format(date_format)))
            if x_axis_property in ("date", "system:time_start"):
                # From here on the derived "year" property is the x-axis key.
                x_axis_property = "year"
        # Get the x-axis labels (histogram keys are the distinct label values).
        x_axis_labels = ic.aggregate_histogram(x_axis_property).keys().getInfo()
        # Select only the bands we care about.
        ic = ic.select(band_names)
        # Group by x_axis_property - if multiple images per label, mosaic them.
        label_counts = ic.aggregate_histogram(x_axis_property).getInfo()
        needs_mosaic = any(v > 1 for v in label_counts.values())
        if needs_mosaic:
            print("Auto-mosaicking ImageCollection for x-axis labels...")

            def _mosaic_for_label(label):
                # Mosaic all images sharing one x-axis label into one image,
                # carrying over the first image's properties plus the label.
                label = ee.String(label)
                filtered = ic.filter(ee.Filter.eq(x_axis_property, label))
                return filtered.mosaic().copyProperties(filtered.first()).set(x_axis_property, label)

            ic = ee.ImageCollection(ee.List(x_axis_labels).map(_mosaic_for_label))

        # Stack into single image with band names like "2020----forest".
        def _rename_bands(img):
            label = ee.String(img.get(x_axis_property))
            new_names = ee.List(band_names).map(
                lambda bn: label.cat(SPLIT_STR).cat(ee.String(bn))
            )
            return img.select(band_names).rename(new_names)

        # Pre-compute expected band names: "label----band" for each label x band.
        expected_names = []
        for x_label in x_axis_labels:
            for bn in band_names:
                expected_names.append(f"{x_label}{SPLIT_STR}{bn}")
        ic = ic.map(_rename_bands)
        stacked = ic.toBands()
        # toBands() prefixes each band with the image's system:index + "_".
        # For programmatically-built collections (e.g. from the mosaic branch)
        # this is "0_", "1_", etc. But for collections with original system:index
        # values (e.g. "LC09_038029_20230613") the prefix is unpredictable.
        # Instead of trying to strip the prefix, rename to the expected names
        # we already know. NOTE: this relies on toBands() preserving
        # label-major, band-minor ordering matching expected_names.
        stacked = stacked.rename(expected_names)
        return (stacked, expected_names, x_axis_labels)
    else:
        # Single image - pass through; no x-axis labels apply.
        return (ee.Image(ee_obj).select(band_names), band_names, [])
def reduce_region(image, geometry, reducer, scale=30, crs=None, transform=None, tile_scale=4):
    """
    Run ``image.reduceRegion`` with sensible defaults.

    If both ``scale`` and ``transform`` are provided, ``scale`` is dropped
    (transform takes precedence in GEE).

    Args:
        image (ee.Image): The image to reduce.
        geometry: An ``ee.Geometry`` or ``ee.Feature``.
        reducer (ee.Reducer): The reducer to apply.
        scale (int, optional): Pixel scale in meters. Defaults to 30.
        crs (str, optional): CRS string. Defaults to None.
        transform (list, optional): Affine transform. Defaults to None.
        tile_scale (int, optional): Tile scale for parallelism. Defaults to 4.

    Returns:
        dict: The reduction result dictionary.
    """
    # A supplied crsTransform wins over scale; never pass both to GEE.
    effective_scale = None if transform is not None else scale
    reduced = image.reduceRegion(
        reducer=reducer,
        geometry=geometry,
        scale=effective_scale,
        crs=crs,
        crsTransform=transform,
        bestEffort=True,
        maxPixels=1e13,
        tileScale=tile_scale,
    )
    return reduced.getInfo()
def reduce_regions(image, features, reducer, scale=30, crs=None, transform=None, tile_scale=4):
    """
    Run ``image.reduceRegions`` and return the result as a DataFrame.

    If both ``scale`` and ``transform`` are provided, ``scale`` is dropped
    (transform takes precedence in GEE).

    Args:
        image (ee.Image): The image to reduce.
        features (ee.FeatureCollection): The zones.
        reducer (ee.Reducer): The reducer to apply.
        scale (int, optional): Pixel scale in meters. Defaults to 30.
        crs (str, optional): CRS string. Defaults to None.
        transform (list, optional): Affine transform. Defaults to None.
        tile_scale (int, optional): Tile scale for parallelism. Defaults to 4.

    Returns:
        pandas.DataFrame: The reduction results.
    """
    # A supplied crsTransform wins over scale; never pass both to GEE.
    effective_scale = None if transform is not None else scale
    reduced = image.reduceRegions(
        collection=features,
        reducer=reducer,
        scale=effective_scale,
        crs=crs,
        crsTransform=transform,
        tileScale=tile_scale,
    )
    return robust_featureCollection_to_df(reduced)
def parse_thematic_results(raw_dict, obj_info, x_axis_labels, area_format="Percentage", scale=30, split_str=SPLIT_STR):
    """
    Parse frequency histogram reduction results into a DataFrame with class
    names as columns.

    Args:
        raw_dict (dict): Output of :func:`reduce_region` using ``frequencyHistogram``.
        obj_info (dict): Output of :func:`get_obj_info`.
        x_axis_labels (list): Labels for the x-axis (e.g. years). Empty for a
            single Image.
        area_format (str, optional): One of ``'Percentage'``, ``'Hectares'``,
            ``'Acres'``, ``'Pixels'``.
        scale (int, optional): Pixel scale used in reduction.
        split_str (str, optional): Band name separator.

    Returns:
        pandas.DataFrame: Rows are x-axis labels (or a single row for Image),
        columns are class names.
    """
    class_info = obj_info["class_info"]
    band_names = obj_info["band_names"]
    # Area multipliers in AREA_FORMAT_DICT are defined at the Hectares
    # reference scale; rescale quadratically for the actual pixel scale.
    scale_mult = (scale / AREA_FORMAT_DICT["Hectares"]["scale"]) ** 2

    def _convert(count, pixel_total):
        # Convert a raw pixel count into the requested output unit.
        if area_format == "Percentage":
            return round((count / pixel_total) * 100, 2)
        if area_format == "Pixels":
            return count
        mult = AREA_FORMAT_DICT[area_format]["mult"] * scale_mult
        return round(count * mult, AREA_FORMAT_DICT[area_format]["places"])

    def _band_columns(histogram, bn):
        # Map one band's {class_value_str: count} histogram to
        # {column_name: converted_value}. Shared by both the
        # ImageCollection and single-Image paths (previously duplicated).
        histogram = histogram or {}
        info = class_info.get(bn, {})
        class_values = info.get("class_values", [])
        class_names = info.get("class_names", [])
        value_to_name = dict(zip([str(v) for v in class_values], class_names))
        pixel_total = sum(histogram.values()) or 1  # guard divide-by-zero
        cols = {}
        for str_val, count in histogram.items():
            name = value_to_name.get(str_val, str_val)
            # Prefix with the band name only when it's ambiguous (multi-band).
            col_name = f"{bn}{split_str}{name}" if len(band_names) > 1 else name
            cols[col_name] = _convert(count, pixel_total)
        return cols

    if x_axis_labels:
        # ImageCollection path - histogram keys are like "2020----Land_Cover".
        rows = []
        for x_label in x_axis_labels:
            row = {"x": x_label}
            for bn in band_names:
                row.update(_band_columns(raw_dict.get(f"{x_label}{split_str}{bn}", {}), bn))
            rows.append(row)
        df = pandas.DataFrame(rows).set_index("x")
        df.index.name = None
        return df.fillna(0)

    # Single Image path - histogram keys are band names directly.
    row = {}
    for bn in band_names:
        row.update(_band_columns(raw_dict.get(bn, {}), bn))
    return pandas.DataFrame([row]).fillna(0)
def parse_continuous_results(raw_dict, obj_info, x_axis_labels, split_str=SPLIT_STR):
    """
    Parse continuous (mean/median/etc.) reduction results into a DataFrame.

    Args:
        raw_dict (dict): Output of :func:`reduce_region`.
        obj_info (dict): Output of :func:`get_obj_info`.
        x_axis_labels (list): Labels for the x-axis. Empty for a single Image.
        split_str (str, optional): Band name separator.

    Returns:
        pandas.DataFrame: Rows are x-axis labels (or single row), columns are
        band names.
    """
    band_names = obj_info["band_names"]
    if not x_axis_labels:
        # Single-Image path: result keys are the band names themselves.
        return pandas.DataFrame([{bn: raw_dict.get(bn) for bn in band_names}])
    # ImageCollection path: keys look like "{label}----{band}".
    rows = [
        {
            "x": x_label,
            **{bn: raw_dict.get(f"{x_label}{split_str}{bn}") for bn in band_names},
        }
        for x_label in x_axis_labels
    ]
    df = pandas.DataFrame(rows).set_index("x")
    df.index.name = None
    return df
def zonal_stats(
    ee_obj,
    geometry,
    band_names=None,
    reducer=None,
    scale=30,
    crs=None,
    transform=None,
    tile_scale=4,
    area_format="Percentage",
    x_axis_property="system:time_start",
    date_format="YYYY",
    include_masked_area=True,
):
    """
    Compute zonal statistics for a GEE Image or ImageCollection over a geometry.

    This is the main entry point for the data pipeline. It auto-detects the
    object type, whether data is thematic or continuous, the appropriate
    reducer, and the geometry type.

    Args:
        ee_obj: ``ee.Image`` or ``ee.ImageCollection``.
        geometry: ``ee.Geometry``, ``ee.Feature``, or ``ee.FeatureCollection``.
        band_names (list, optional): Bands to include. Auto-detected if None.
        reducer (ee.Reducer, optional): Override the auto-selected reducer.
        scale (int, optional): Pixel scale in meters. Defaults to 30.
        crs (str, optional): CRS string.
        transform (list, optional): Affine transform.
        tile_scale (int, optional): Tile scale for parallelism. Defaults to 4.
        area_format (str, optional): Area unit for thematic data. One of
            ``'Percentage'``, ``'Hectares'``, ``'Acres'``, ``'Pixels'``.
        x_axis_property (str, optional): Property for x-axis labels (ImageCollection).
        date_format (str, optional): Date format string for x-axis labels.
        include_masked_area (bool, optional): When True and a frequency-histogram
            reducer is used, masked pixels are unmasked to a sentinel class (0)
            so they count toward the area denominator; the sentinel columns are
            stripped from the output. Defaults to True.

    Returns:
        pandas.DataFrame: The zonal statistics table.

    Examples:
        Get just the data (no chart) for an LCMS land cover time series:

        >>> lcms = ee.ImageCollection("USFS/GTAC/LCMS/v2024-10")
        >>> df = cl.zonal_stats(
        ...     lcms.select(['Land_Cover']), study_area, area_format='Percentage',
        ... )
        >>> print(df.to_markdown())

        Continuous data with a custom reducer:

        >>> df = cl.zonal_stats(
        ...     lcms.select(['Change_Raw_Probability_Slow_Loss']),
        ...     study_area,
        ...     reducer=ee.Reducer.mean(),
        ... )
    """
    # filterBounds only applies to ImageCollections, not single Images.
    if isinstance(ee_obj, ee.ImageCollection):
        ee_obj = ee_obj.filterBounds(geometry)
    obj_info = get_obj_info(ee_obj, band_names)
    geo_type, geo = detect_geometry_type(geometry)
    # Choose reducer: thematic data gets a class-count histogram, continuous
    # data gets a mean.
    if reducer is None:
        if obj_info["is_thematic"]:
            reducer = ee.Reducer.frequencyHistogram()
        else:
            reducer = ee.Reducer.mean()
    # Determine if using frequency histogram (drives the parsing path below).
    is_histogram = False
    try:
        reducer_type = reducer.getInfo()["type"]
        is_histogram = "frequencyHistogram" in reducer_type
    except Exception:
        # Best-effort introspection: fall back to continuous parsing.
        pass
    # When include_masked_area=True and using histogram reducer, unmask
    # with a sentinel value (0) so masked pixels count toward the total
    # area denominator. The sentinel class is removed from results after.
    _unmask_sentinel = None
    if include_masked_area and is_histogram:
        _unmask_sentinel = 0
        if isinstance(ee_obj, ee.ImageCollection):
            ee_obj = ee_obj.map(lambda img: img.unmask(_unmask_sentinel).copyProperties(img, ["system:time_start"]))
        else:
            ee_obj = ee.Image(ee_obj).unmask(_unmask_sentinel)
        # Re-get obj_info after unmask (band structure unchanged).
        obj_info = get_obj_info(ee_obj, band_names)
    # Prepare image (stack ImageCollection into one multi-band image).
    stacked, stack_bands, x_axis_labels = prepare_for_reduction(
        ee_obj, obj_info, x_axis_property, date_format
    )

    def _strip_sentinel_cols(df):
        """Remove sentinel unmask class columns (0, 0.0) from results."""
        if _unmask_sentinel is not None:
            drop = [c for c in df.columns if str(c).strip() in ("0", "0.0")]
            if drop:
                df = df.drop(columns=drop)
        return df

    if geo_type == "single":
        raw = reduce_region(stacked, geo, reducer, scale, crs, transform, tile_scale)
        if is_histogram:
            return _strip_sentinel_cols(
                parse_thematic_results(raw, obj_info, x_axis_labels, area_format, scale))
        else:
            return parse_continuous_results(raw, obj_info, x_axis_labels)
    else:
        # Multi-region: reduceRegions
        if is_histogram:
            # frequencyHistogram + reduceRegions requires special handling:
            # 1. For single-band images, EE names the output "histogram" instead
            #    of the band name. Use setOutputs() to force band-name keys.
            #    For multi-band images (stacked ImageCollections), EE already
            #    names outputs by band name, and setOutputs() would fail.
            # 2. robust_featureCollection_to_df flattens nested dicts, destroying
            #    the histogram structure. Get features directly to preserve dicts.
            if len(stack_bands) == 1:
                hist_reducer = reducer.setOutputs(stack_bands)
            else:
                hist_reducer = reducer
            # transform takes precedence over scale in GEE.
            if transform is not None and scale is not None:
                scale = None
            fc_result = stacked.reduceRegions(
                collection=geo,
                reducer=hist_reducer,
                scale=scale,
                crs=crs,
                crsTransform=transform,
                tileScale=tile_scale,
            )
            # Fetch features preserving nested histogram dicts.
            # Batch in chunks of 5000 to avoid EE's feature limit.
            n_features = fc_result.size().getInfo()
            rows = []
            fc_list = fc_result.toList(n_features)
            batch_size = 5000
            for start in range(0, n_features, batch_size):
                end = min(start + batch_size, n_features)
                batch = ee.FeatureCollection(fc_list.slice(start, end))
                features = batch.getInfo()["features"]
                for f in features:
                    rows.append(f.get("properties", {}))
            df = pandas.DataFrame(rows)
            return _strip_sentinel_cols(_expand_thematic_reduce_regions(
                df, stack_bands, obj_info["class_info"], area_format, scale, SPLIT_STR
            ))
        else:
            df = reduce_regions(stacked, geo, reducer, scale, crs, transform, tile_scale)
            return df
def prepare_sankey_data(
    ee_collection,
    band_name,
    transition_periods,
    class_info,
    geometry,
    scale=30,
    crs=None,
    transform=None,
    tile_scale=4,
    area_format="Percentage",
    min_percentage=0.2,
):
    """
    Build a Sankey diagram dataset from class transitions across time periods.

    For each consecutive pair of periods, this function:

    1. Filters the collection to each period
    2. Computes the mode for each period
    3. Creates a transition image encoding ``from*10000 + 9900 + to``
    4. Runs ``frequencyHistogram`` to count transitions
    5. Parses results into both a source/target/value DataFrame and a
       transition matrix DataFrame

    Args:
        ee_collection (ee.ImageCollection): The input collection.
        band_name (str): The thematic band to analyze.
        transition_periods (list): List of ``[start_year, end_year]`` pairs.
        class_info (dict): Class info dict for the band (from :func:`get_obj_info`).
        geometry: ``ee.Geometry`` or ``ee.Feature``.
        scale (int, optional): Pixel scale in meters.
        crs (str, optional): CRS string.
        transform (list, optional): Affine transform.
        tile_scale (int, optional): Tile scale for parallelism.
        area_format (str, optional): Area unit.
        min_percentage (float, optional): Minimum percentage threshold for
            including a flow in the source-target table. The transition matrix
            always includes all observed transitions regardless of this threshold.

    Returns:
        tuple: ``(sankey_df, matrix_dict)``

        - **sankey_df** (``pandas.DataFrame``): Source-target-value table with
          columns ``source``, ``target``, ``value``, ``source_name``,
          ``target_name``, ``source_color``, ``target_color``, ``period``.
          Flows below ``min_percentage`` are excluded.
        - **matrix_dict** (``dict[str, pandas.DataFrame]``): One transition
          matrix per consecutive period pair, keyed by
          ``"{from_period} \u2192 {to_period}"``. Each DataFrame has class names as
          both row and column labels, with values as converted counts.

    Examples:
        Typically called via ``summarize_and_chart(chart_type='sankey')``, but
        can be used directly for custom sankey workflows:

        >>> lcms = ee.ImageCollection("USFS/GTAC/LCMS/v2024-10")
        >>> info = cl.get_obj_info(lcms.select(['Land_Use']))
        >>> sankey_df, matrix_dict = cl.prepare_sankey_data(
        ...     lcms.select(['Land_Use']),
        ...     'Land_Use',
        ...     transition_periods=[[1990, 2000], [2000, 2010], [2010, 2023]],
        ...     class_info=info['class_info']['Land_Use'],
        ...     geometry=study_area,
        ...     scale=30,
        ... )
        >>> print(sankey_df.head().to_markdown())
        >>> for label, mdf in matrix_dict.items():
        ...     print(f"\\n{label}")
        ...     print(mdf.to_markdown())
    """
    _, geo = detect_geometry_type(geometry)
    # NOTE(review): the docstring example passes the per-band dict
    # (info['class_info']['Land_Use']), in which case band_name is not a key
    # and this fallback resolves to class_info[first_key] (e.g. the
    # class_values list), not a per-band dict — verify which shape callers
    # actually pass and whether this fallback is correct.
    info = class_info.get(band_name, class_info.get(list(class_info.keys())[0], {}))
    class_values = info.get("class_values", [])
    class_names = info.get("class_names", [])
    class_palette = info.get("class_palette", [])
    # Lookup tables: raw class value -> dense index -> display name/color.
    value_to_idx = {v: i for i, v in enumerate(class_values)}
    idx_to_name = {i: n for i, n in enumerate(class_names)}
    idx_to_color = {i: _ensure_hex_color(c) for i, c in enumerate(class_palette)}
    num_classes = len(class_values)
    # Area multipliers are defined at the Hectares reference scale.
    scale_mult = (scale / AREA_FORMAT_DICT["Hectares"]["scale"]) ** 2
    all_rows = []
    transition_band_names = []
    # Build transition images for each consecutive period pair.
    transition_images = []
    period_labels = []
    for i in range(len(transition_periods) - 1):
        p1 = transition_periods[i]
        p2 = transition_periods[i + 1]
        # Periods may be scalars (single year) or [start, end] ranges.
        p1_start, p1_end = (p1, p1) if not isinstance(p1, (list, tuple)) else (p1[0], p1[-1])
        p2_start, p2_end = (p2, p2) if not isinstance(p2, (list, tuple)) else (p2[0], p2[-1])
        # Filter and compute mode for each period.
        filtered1 = ee_collection.filter(
            ee.Filter.calendarRange(int(p1_start), int(p1_end), "year")
        ).select([band_name])
        filtered2 = ee_collection.filter(
            ee.Filter.calendarRange(int(p2_start), int(p2_end), "year")
        ).select([band_name])
        mode1 = filtered1.mode().rename(["from"])
        mode2 = filtered2.mode().rename(["to"])
        # Encode transition: from_class * 10000 + 9900 + to_class.
        combined = mode1.addBands(mode2)
        transition = (
            combined.select("from").multiply(10000)
            .add(9900)
            .add(combined.select("to"))
            .rename([f"{_format_period(p1)}---{_format_period(p2)}"])
        )
        transition_images.append(transition)
        transition_band_names.append(f"{_format_period(p1)}---{_format_period(p2)}")
        period_labels.append((_format_period(p1), _format_period(p2)))
    # Stack all transition images into one multi-band image.
    if len(transition_images) == 1:
        stacked = transition_images[0]
    else:
        stacked = transition_images[0]
        for t_img in transition_images[1:]:
            stacked = stacked.addBands(t_img)
    # Run frequency histogram (one server call for all period pairs).
    raw = reduce_region(
        stacked.toInt(), geo, ee.Reducer.frequencyHistogram(), scale, crs, transform, tile_scale
    )
    # Parse results — build both the source-target table and per-period matrices.
    matrix_dict = {}  # {period_label: DataFrame} — one matrix per transition pair
    for ti, t_bn in enumerate(transition_band_names):
        histogram = raw.get(t_bn, {})
        if histogram is None:
            histogram = {}
        pixel_total = sum(histogram.values()) or 1  # guard divide-by-zero
        p1_label, p2_label = period_labels[ti]
        # Sankey node ids are unique per period column: period ti's classes
        # occupy [ti*num_classes, (ti+1)*num_classes).
        offset1 = ti * num_classes
        offset2 = (ti + 1) * num_classes
        # Build count_lookup: (from_idx, to_idx) -> display_val for ALL transitions.
        count_lookup = {}
        for encoded_str, count in histogram.items():
            # Decode "from*10000 + 9900 + to" back to class values.
            encoded = int(float(encoded_str))
            from_class = encoded // 10000
            to_class = encoded % 10000 - 9900
            from_idx = value_to_idx.get(from_class)
            to_idx = value_to_idx.get(to_class)
            if from_idx is None or to_idx is None:
                # Unknown class value (e.g. sentinel/nodata) — skip.
                continue
            # Compute display value in the requested unit.
            pct = (count / pixel_total) * 100
            if area_format == "Percentage":
                display_val = round(pct, 2)
            elif area_format == "Pixels":
                display_val = count
            else:
                mult = AREA_FORMAT_DICT[area_format]["mult"] * scale_mult
                display_val = round(count * mult, AREA_FORMAT_DICT[area_format]["places"])
            count_lookup[(from_idx, to_idx)] = display_val
            # Source-target table: only include flows above min_percentage.
            if pct >= min_percentage:
                all_rows.append(
                    {
                        "source": from_idx + offset1,
                        "target": to_idx + offset2,
                        "value": display_val,
                        "source_name": f"{p1_label} {idx_to_name.get(from_idx, str(from_class))}",
                        "target_name": f"{p2_label} {idx_to_name.get(to_idx, str(to_class))}",
                        "source_color": idx_to_color.get(from_idx, "#888888"),
                        "target_color": idx_to_color.get(to_idx, "#888888"),
                        "period": f"{p1_label} -> {p2_label}",
                    }
                )
        # Build transition matrix for this period pair.
        # Columns are "to" class names, rows are "from" class names.
        period_rows = []
        for fi in range(num_classes):
            row_label = idx_to_name.get(fi, str(fi))
            row_data = {"": row_label}
            for ti2 in range(num_classes):
                col_label = idx_to_name.get(ti2, str(ti2))
                row_data[col_label] = count_lookup.get((fi, ti2), 0)
            period_rows.append(row_data)
        period_key = f"{p1_label} \u2192 {p2_label}"
        if period_rows:
            mdf = pandas.DataFrame(period_rows).set_index("")
            mdf.index.name = None
            matrix_dict[period_key] = mdf
    # Build sankey_df (empty frame keeps the expected column schema).
    empty_cols = ["source", "target", "value", "source_name", "target_name",
                  "source_color", "target_color", "period"]
    if not all_rows:
        sankey_df = pandas.DataFrame(columns=empty_cols)
    else:
        sankey_df = pandas.DataFrame(all_rows)
    return (sankey_df, matrix_dict)
########################################################################### # Chart functions ###########################################################################
def chart_time_series(
    df,
    colors=None,
    chart_type="line+markers",
    title="Time Series",
    x_label="Year",
    y_label=None,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    label_max_length=30,
    legend_position="right",
    line_width=2,
    marker_size=5,
    max_x_tick_labels=10,
    max_y_tick_labels=None,
):
    """
    Create a Plotly time series chart from a zonal stats DataFrame.

    Args:
        df (pandas.DataFrame): Output of :func:`zonal_stats` for an
            ImageCollection. Index = x-axis labels, columns = data series.
        colors (list, optional): Hex color strings for each column.
        chart_type (str, optional): ``"line+markers"`` (default), ``"line"``,
            ``"bar"``, ``"stacked_line"``, ``"stacked_line+markers"``, or
            ``"stacked_bar"``.
        title (str, optional): Chart title.
        x_label (str, optional): X-axis label.
        y_label (str, optional): Y-axis label.
        width (int, optional): Chart width in pixels.
        height (int, optional): Chart height in pixels.
        label_max_length (int, optional): Max characters for legend labels.
        legend_position (dict or str, optional): Plotly legend layout dict
            (e.g. ``{"orientation": "h", "x": 0.5, "y": -0.1}``), or
            ``"right"``/``None`` for the Plotly default.
        line_width (int or float, optional): Line width in pixels for
            line/scatter traces. Defaults to ``2``.
        marker_size (int or float, optional): Marker diameter in pixels for
            traces that include markers. Defaults to ``5``.
        max_x_tick_labels (int, optional): Maximum number of x-axis tick labels
            to display. When the number of x values exceeds this, labels are
            thinned to every 2nd, 5th, 10th, etc. value. Defaults to ``10``.
            Set to ``None`` or ``0`` to disable.
        max_y_tick_labels (int, optional): Maximum number of y-axis tick labels.
            Uses Plotly's ``nticks``. Defaults to ``None`` (automatic).

    Returns:
        plotly.graph_objects.Figure

    Examples:
        >>> info = cl.get_obj_info(lcms.select(['Land_Cover']))
        >>> df = cl.zonal_stats(lcms.select(['Land_Cover']), study_area)
        >>> colors = info['class_info']['Land_Cover']['class_palette']
        >>> fig = cl.chart_time_series(df, colors=colors,
        ...                            title='LCMS Land Cover', y_label='% Area')
        >>> fig.show()
    """
    plotly_mode, is_stacked = _parse_chart_type(chart_type)
    fig = go.Figure()
    x_values = list(df.index)
    # Convert pure-integer labels (e.g. years) to int so Plotly uses a
    # linear axis with automatic tick spacing instead of a categorical axis
    # that crams every label together. Mirrors the JS parseInt() logic.
    try:
        x_values = [int(v) for v in x_values]
    except (ValueError, TypeError):
        # Non-numeric labels: leave as categorical strings.
        pass
    columns = list(df.columns)
    for i, col in enumerate(columns):
        color = None
        if colors and i < len(colors):
            color = _ensure_hex_color(colors[i])
        # Truncate long class names so the legend stays readable.
        label = col[:label_max_length]
        if plotly_mode == "bar":
            fig.add_trace(
                go.Bar(
                    x=x_values,
                    y=df[col].values,
                    name=label,
                    marker_color=color,
                )
            )
        else:
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=df[col].values,
                    mode=plotly_mode,
                    name=label,
                    line=dict(color=color, width=line_width),
                    marker=dict(color=color, size=marker_size),
                    # stackgroup turns the scatter into a stacked-area series.
                    stackgroup="one" if is_stacked else None,
                )
            )
    # barmode only applies to bar traces; stacked scatters use stackgroup.
    bar_mode = "stack" if is_stacked and plotly_mode == "bar" else ("group" if plotly_mode == "bar" else None)
    # Determine x tick values — thin if there are too many.
    is_int_x = all(isinstance(v, int) for v in x_values)
    x_tick_vals = x_values if is_int_x else None
    if x_tick_vals is not None:
        thinned = _thin_tick_vals(x_tick_vals, max_x_tick_labels)
        if thinned is not None:
            x_tick_vals = thinned
    # Y-axis: add '%' suffix when label indicates percentage.
    y_kw = dict(title=y_label, automargin=True)
    if y_label and "%" in y_label:
        y_kw["ticksuffix"] = "%"
    if max_y_tick_labels is not None and max_y_tick_labels > 0:
        y_kw["nticks"] = max_y_tick_labels
    # Build x-axis kwargs — constrain range to eliminate dead space.
    x_kw = dict(title=x_label, tickangle=45, tickvals=x_tick_vals, tickformat="d" if is_int_x else None)
    if is_int_x and x_values:
        # Half-unit padding keeps end markers/bars from being clipped.
        x_kw["range"] = [min(x_values) - 0.5, max(x_values) + 0.5]
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=x_kw,
        yaxis=y_kw,
        legend=_legend_kwargs(legend_position),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        margin=dict(l=35, r=25, b=50, t=50, pad=5),
        barmode=bar_mode,
        hovermode="x unified",
    )
    _themes.apply_plotly_theme(fig, "dark")
    return fig
def chart_bar(
    df,
    colors=None,
    title="Class Distribution",
    y_label=None,
    max_classes=30,
    chart_type="bar",
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    legend_position="right",
):
    """
    Create a Plotly bar chart from a single-Image zonal stats DataFrame.

    Automatically chooses horizontal or vertical orientation based on label length.

    Args:
        df (pandas.DataFrame): Output of :func:`zonal_stats` for a single Image.
            Single row, columns = class names.
        colors (list, optional): Hex color strings for each bar.
        title (str, optional): Chart title.
        y_label (str, optional): Value axis label.
        max_classes (int, optional): Maximum number of classes to display.
        chart_type (str, optional): ``"bar"`` (default) or a stacked variant;
            any value whose parsed form is stacked produces one trace per class.
        width (int, optional): Chart width in pixels.
        height (int, optional): Chart height in pixels.
        legend_position (dict or str, optional): Plotly legend layout dict
            (e.g. ``{"orientation": "h", "x": 0.5, "y": -0.1}``), or
            ``"right"``/``None`` for the Plotly default.

    Returns:
        plotly.graph_objects.Figure

    Examples:
        >>> info = cl.get_obj_info(nlcd)
        >>> df = cl.zonal_stats(nlcd, study_area)
        >>> colors = info['class_info']['landcover']['class_palette']
        >>> fig = cl.chart_bar(df, colors=colors, title='NLCD Land Cover')
        >>> fig.show()
    """
    # Flatten to series (sum across rows when df has more than one).
    if len(df) == 1:
        values = df.iloc[0]
    else:
        values = df.sum()
    labels = list(values.index)
    vals = list(values.values)
    # Cap at max_classes (keep top N by value).
    if len(labels) > max_classes:
        # Sort by value descending, keep top N, then restore original order
        # using the carried original index (third tuple element).
        sorted_pairs = sorted(zip(vals, labels, range(len(labels))), reverse=True)
        sorted_pairs = sorted_pairs[:max_classes]
        sorted_pairs.sort(key=lambda x: x[2])  # restore original order
        vals = [p[0] for p in sorted_pairs]
        labels = [p[1] for p in sorted_pairs]
        # Also filter colors.
        if colors:
            idxs = [p[2] for p in sorted_pairs]
            # NOTE(review): if any kept idx >= len(colors), this drops entries
            # and may misalign remaining colors with labels — confirm palettes
            # always cover every class.
            colors = [_ensure_hex_color(colors[i]) for i in idxs if i < len(colors)]
    if colors:
        if len(colors) < len(labels):
            # Interpolate palette as a continuous ramp (matches JS min/max/palette).
            colors = _interpolate_palette(colors, len(labels))
        else:
            colors = [_ensure_hex_color(c) for c in colors[:len(labels)]]
    _, is_stacked = _parse_chart_type(chart_type)
    # Determine orientation: long labels read better on a horizontal chart.
    max_label_len = max((len(str(l)) for l in labels), default=0)
    orientation = "h" if max_label_len > max(len(labels), 6) else "v"
    fig = go.Figure()
    if is_stacked:
        # Stacked bar: one trace per class so barmode="stack" works.
        for i, (lbl, val) in enumerate(zip(labels, vals)):
            color = colors[i] if colors and i < len(colors) else None
            if orientation == "h":
                fig.add_trace(go.Bar(
                    y=[""], x=[val], name=str(lbl), orientation="h",
                    marker_color=color,
                ))
            else:
                fig.add_trace(go.Bar(
                    x=[""], y=[val], name=str(lbl),
                    marker_color=color,
                ))
        barmode = "stack"
    else:
        # Standard: single trace with per-bar colors.
        if orientation == "h":
            fig.add_trace(go.Bar(
                y=labels, x=vals, orientation="h",
                marker_color=colors,
            ))
        else:
            fig.add_trace(go.Bar(
                x=labels, y=vals, orientation="v",
                marker_color=colors,
            ))
        barmode = None
    # Axis titles/margins differ by orientation (value axis swaps sides).
    if orientation == "h":
        fig.update_layout(
            xaxis=dict(title=y_label, automargin=True),
            yaxis=dict(automargin=True),
            margin=dict(l=80, r=25, b=30, t=50, pad=5),
        )
    else:
        fig.update_layout(
            xaxis=dict(tickangle=45, automargin=True),
            yaxis=dict(title=y_label, automargin=True),
            margin=dict(l=35, r=25, b=80, t=50, pad=5),
        )
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        legend=_legend_kwargs(legend_position),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        barmode=barmode,
        hovermode="closest",
    )
    _themes.apply_plotly_theme(fig, "dark")
    return fig
# --------------------------------------------------------------------------- # Donut chart # ---------------------------------------------------------------------------
def chart_donut(
    df,
    colors=None,
    title="Class Distribution",
    max_classes=30,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    legend_position="right",
    hole=0.45,
):
    """Create a Plotly donut chart from a single-Image zonal stats DataFrame.

    Only valid for **thematic** (categorical) data from a single ``ee.Image``.

    Args:
        df (pandas.DataFrame): Output of :func:`zonal_stats` for a single Image.
            Single row, columns = class names, values = area/%.
        colors (list, optional): Hex colour strings, one per class.
        title (str, optional): Chart title.
        max_classes (int, optional): Maximum number of classes to display.
            Smaller classes are grouped into "Other". Defaults to ``30``.
        width (int, optional): Chart width in pixels.
        height (int, optional): Chart height in pixels.
        legend_position (dict or str, optional): Plotly legend dict or
            ``"right"`` / ``"bottom"``.
        hole (float, optional): Size of the centre hole (0–1). Defaults to ``0.45``.

    Returns:
        plotly.graph_objects.Figure
    """
    import plotly.graph_objects as go

    # Flatten to series (sum across rows when df has more than one).
    if len(df) == 1:
        values = df.iloc[0]
    else:
        values = df.sum()
    labels = list(values.index)
    vals = list(values.values)
    # Cap at max_classes — group the rest into "Other".
    if len(labels) > max_classes:
        # Sort by value descending; keep top N, aggregate the tail.
        sorted_pairs = sorted(zip(vals, labels, range(len(labels))), reverse=True)
        top = sorted_pairs[:max_classes]
        other_val = sum(p[0] for p in sorted_pairs[max_classes:])
        top.sort(key=lambda x: x[2])  # restore original order
        vals = [p[0] for p in top] + [other_val]
        labels = [p[1] for p in top] + ["Other"]
        if colors:
            idxs = [p[2] for p in top]
            # Grey slice for the aggregated "Other" class.
            colors = [_ensure_hex_color(colors[i]) for i in idxs if i < len(colors)] + ["#888888"]
    if colors:
        if len(colors) < len(labels):
            # Stretch a short palette into a continuous ramp.
            colors = _interpolate_palette(colors, len(labels))
        else:
            colors = [_ensure_hex_color(c) for c in colors[:len(labels)]]
    # Filter out zero-value slices (zip keeps label/value/colour aligned).
    filtered = [(l, v, c) for l, v, c in zip(labels, vals, colors or [None] * len(labels)) if v > 0]
    if filtered:
        labels, vals, _colors = zip(*filtered)
        labels, vals = list(labels), list(vals)
        if colors:
            colors = list(_colors)
    fig = go.Figure(data=[go.Pie(
        labels=labels,
        values=vals,
        hole=hole,
        marker=dict(colors=colors) if colors else {},
        textinfo="percent",
        hoverinfo="label+value+percent",
        textfont=dict(size=12),
    )])
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        legend=_legend_kwargs(legend_position),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        margin=dict(l=10, r=10, b=10, t=40, pad=0),
    )
    _themes.apply_plotly_theme(fig, "dark")
    return fig
def chart_donut_multi_feature(
    df,
    colors=None,
    title="Class Distribution by Feature",
    max_classes=30,
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    columns=2,
    legend_position="bottom",
    hole=0.45,
):
    """Create a subplot grid of donut charts, one per feature.

    For multi-feature ``reduceRegions`` output where the DataFrame index is
    the feature label and columns are class names.

    Args:
        df (pandas.DataFrame): Output of :func:`zonal_stats` with
            ``feature_label`` set. Index = feature names, columns = class
            names, values = area/%.
        colors (list, optional): Hex colour strings, one per class.
        title (str, optional): Overall chart title.
        max_classes (int, optional): Max classes per donut.
        width (int, optional): Chart width in pixels (per subplot cell).
        height (int, optional): Chart height in pixels (per subplot cell).
        columns (int, optional): Number of subplot columns.
        legend_position (dict or str, optional): Legend position.
        hole (float, optional): Centre hole size.

    Returns:
        plotly.graph_objects.Figure
    """
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go

    feature_names = list(df.index)
    n_features = len(feature_names)
    n_cols = min(columns, n_features)
    n_rows = -(-n_features // n_cols)  # ceil division
    # Every grid cell must be declared type "pie" for go.Pie traces.
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        specs=[[{"type": "pie"}] * n_cols for _ in range(n_rows)],
        subplot_titles=feature_names,
    )
    class_labels = list(df.columns)

    # Prepare one shared palette for all donuts so colours stay consistent
    # across features. Too few colours -> interpolate as a continuous ramp.
    pal = None
    if colors:
        if len(colors) < len(class_labels):
            pal = _interpolate_palette(colors, len(class_labels))
        else:
            pal = [_ensure_hex_color(c) for c in colors[:len(class_labels)]]

    for idx, feat_name in enumerate(feature_names):
        # Map flat index -> (row, col), both 1-based as Plotly requires.
        row_i = idx // n_cols + 1
        col_i = idx % n_cols + 1
        vals = list(df.loc[feat_name])
        # Filter zero-value slices per feature.
        filtered = [(l, v, c) for l, v, c in zip(class_labels, vals, pal or [None] * len(class_labels)) if v > 0]
        if filtered:
            f_labels, f_vals, f_colors = zip(*filtered)
            f_labels, f_vals = list(f_labels), list(f_vals)
            if pal:
                f_colors = list(f_colors)
            else:
                f_colors = None
        else:
            # All values zero (or negative): fall back to the raw columns.
            f_labels, f_vals, f_colors = class_labels, vals, pal
        fig.add_trace(
            go.Pie(
                labels=f_labels,
                values=f_vals,
                hole=hole,
                marker=dict(colors=f_colors) if f_colors else {},
                textinfo="percent",
                hoverinfo="label+value+percent",
                textfont=dict(size=11),
                showlegend=(idx == 0),  # legend from first trace only
                name=feat_name,
            ),
            row=row_i,
            col=col_i,
        )

    # Scale overall figure size by the grid dimensions (width/height are
    # per-cell sizes).
    fig_w = width * n_cols
    fig_h = height * n_rows
    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        legend=_legend_kwargs(legend_position),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=fig_w,
        height=fig_h,
    )
    _themes.apply_plotly_theme(fig, "dark")
    return fig
# --------------------------------------------------------------------------- # Scatter chart # ---------------------------------------------------------------------------
def chart_scatter(
    df,
    x_band,
    y_band,
    feature_label=None,
    title="Scatter Plot",
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    legend_position="right",
    trendline=True,
    opacity=0.7,
    show_labels=None,
    thematic_col=None,
    class_names=None,
    class_palette=None,
    class_values=None,
):
    """Create a scatter plot of two bands across features.

    Each point represents one feature (e.g. a county, fire perimeter, or
    watershed). The x- and y-axes show the mean (or other reduced) value of
    two image bands over that feature.

    When *thematic_col* is provided, points are colored by the thematic
    class value in that column, using the class palette and names from
    image properties.

    Args:
        df (pandas.DataFrame): DataFrame with at least two numeric columns
            for the x and y bands. Optionally a *thematic_col* column with
            integer class values.
        x_band (str): Column name for the x-axis.
        y_band (str): Column name for the y-axis.
        feature_label (str, optional): Name of the index (used in hover).
        title (str, optional): Chart title.
        width (int, optional): Chart width in pixels.
        height (int, optional): Chart height in pixels.
        legend_position (dict or str, optional): Legend position.
        trendline (bool, optional): Draw a linear trendline. Defaults to
            ``True``.
        opacity (float, optional): Point opacity (0-1). Lower values help
            visualize overlapping points. Defaults to ``0.7``.
        show_labels (bool, optional): Label each point with the feature
            name. When ``None`` (default), labels are shown only when the
            DataFrame has fewer than 30 rows.
        thematic_col (str, optional): Column containing thematic class
            values used to color each point. Defaults to ``None``.
        class_names (list, optional): Class name strings matching
            *class_values*.
        class_palette (list, optional): Hex colour strings matching
            *class_values*.
        class_values (list, optional): Integer class values that map to
            *class_names* and *class_palette*.

    Returns:
        plotly.graph_objects.Figure
    """
    import plotly.graph_objects as go
    import numpy as np

    x_vals = df[x_band].values.astype(float)
    y_vals = df[y_band].values.astype(float)
    labels = list(df.index)

    # Auto-decide whether to show text labels: small datasets only,
    # otherwise labels overlap and clutter the plot.
    if show_labels is None:
        show_labels = len(df) < 30
    mode = "markers+text" if show_labels else "markers"
    # Shrink markers slightly for dense plots.
    marker_size = 8 if len(df) > 50 else 10

    fig = go.Figure()

    # --- Thematic color: one trace per class so each class gets a legend
    # entry and can be toggled independently ---
    if thematic_col is not None and thematic_col in df.columns:
        cat_vals = df[thematic_col].values
        # Build lookup: class_value -> (name, color)
        _val_to_name = {}
        _val_to_color = {}
        if class_values and class_names:
            for v, n in zip(class_values, class_names):
                _val_to_name[v] = n
        if class_values and class_palette:
            for v, c in zip(class_values, class_palette):
                _val_to_color[v] = _ensure_hex_color(c)
        # Coerce to int so float class values (e.g. from a mode reducer)
        # match integer class_values keys; NaN/inf rows are skipped.
        unique_classes = sorted(set(int(v) for v in cat_vals if np.isfinite(v)))
        for cls_val in unique_classes:
            mask = cat_vals == cls_val
            cls_name = _val_to_name.get(cls_val, str(cls_val))
            cls_color = _val_to_color.get(cls_val, None)
            cls_x = x_vals[mask]
            cls_y = y_vals[mask]
            cls_labels = [labels[i] for i, m in enumerate(mask) if m]
            # Plotly hovertemplate: %{...} placeholders must survive the
            # f-string, hence the doubled braces.
            hover_parts = []
            if show_labels:
                hover_parts.append("<b>%{text}</b><br>")
            hover_parts.append(
                f"{x_band}: %{{x:.2f}}<br>"
                f"{y_band}: %{{y:.2f}}<br>"
                f"{thematic_col}: {cls_name}"
                "<extra></extra>"
            )
            fig.add_trace(go.Scatter(
                x=cls_x,
                y=cls_y,
                mode=mode,
                name=cls_name,
                text=cls_labels if show_labels else None,
                textposition="top center" if show_labels else None,
                textfont=dict(size=9) if show_labels else None,
                marker=dict(
                    size=marker_size,
                    color=cls_color,
                    opacity=opacity,
                    line=dict(width=0.5, color="#333"),
                ),
                hovertemplate="".join(hover_parts),
                legendgroup=cls_name,
            ))
    else:
        # --- Single color (no thematic) ---
        hover_parts = []
        if show_labels:
            hover_parts.append("<b>%{text}</b><br>")
        else:
            hover_parts.append("<b>Point %{pointNumber}</b><br>")
        hover_parts.append(
            f"{x_band}: %{{x:.2f}}<br>"
            f"{y_band}: %{{y:.2f}}"
            "<extra></extra>"
        )
        fig.add_trace(go.Scatter(
            x=x_vals,
            y=y_vals,
            mode=mode,
            text=labels if show_labels else None,
            textposition="top center" if show_labels else None,
            textfont=dict(size=10) if show_labels else None,
            marker=dict(
                size=marker_size,
                color="#66c2a5",
                opacity=opacity,
                line=dict(width=0.5, color="#333"),
            ),
            hovertemplate="".join(hover_parts),
            showlegend=False,
        ))

    # Ordinary least-squares trendline over finite points only; legend
    # entry shows the R-squared of the fit.
    if trendline and len(x_vals) > 1:
        mask = np.isfinite(x_vals) & np.isfinite(y_vals)
        if mask.sum() > 1:
            coeffs = np.polyfit(x_vals[mask], y_vals[mask], 1)
            x_line = np.linspace(x_vals[mask].min(), x_vals[mask].max(), 50)
            y_line = np.polyval(coeffs, x_line)
            r_sq = np.corrcoef(x_vals[mask], y_vals[mask])[0, 1] ** 2
            fig.add_trace(go.Scatter(
                x=x_line,
                y=y_line,
                mode="lines",
                line=dict(color="#fc8d62", width=2, dash="dash"),
                name=f"R\u00b2 = {r_sq:.3f}",
                showlegend=True,
            ))

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis_title=x_band,
        yaxis_title=y_band,
        legend=_legend_kwargs(legend_position),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        hovermode="closest",
    )
    _themes.apply_plotly_theme(fig, "dark")
    return fig
def chart_grouped_bar(
    df,
    colors=None,
    title="Zonal Summary by Feature",
    y_label=None,
    chart_type="bar",
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    legend_position="right",
):
    """
    Create a grouped (or stacked) bar chart for multi-feature zonal stats.

    Each group on the x-axis is a feature (row) and each bar/segment within
    the group is a class (column). This is the natural chart type when
    ``reduceRegions`` returns one row per zone.

    Args:
        df (pandas.DataFrame): Rows = features (index used as labels),
            columns = class names, values = numeric area/percentage.
        colors (list, optional): Hex color strings, one per column (class).
        title (str, optional): Chart title.
        y_label (str, optional): Y-axis label.
        chart_type (str, optional): ``"bar"`` for grouped bars, or
            ``"stacked_bar"`` / ``"stacked"`` to stack the class segments
            within each feature group. Defaults to ``"bar"``.
        width (int, optional): Chart width in pixels.
        height (int, optional): Chart height in pixels.
        legend_position (dict or str, optional): Plotly legend layout dict
            (e.g. ``{"orientation": "h", "x": 0.5, "y": -0.1}``), or
            ``"right"``/``None`` for the Plotly default.

    Returns:
        plotly.graph_objects.Figure

    Examples:
        Compare land cover across the 5 largest MTBS fire perimeters:

        >>> import geeViz.geeView as gv
        >>> from geeViz.outputLib import charts as cl
        >>> ee = gv.ee
        >>> lcms = ee.ImageCollection("USFS/GTAC/LCMS/v2024-10")
        >>> fires = ee.FeatureCollection(
        ...     "USFS/GTAC/MTBS/burned_area_boundaries/v1"
        ... ).sort("BurnBndAc", False).limit(5)
        >>> lc_mode = lcms.select(["Land_Cover"]).mode().set(
        ...     lcms.first().toDictionary()
        ... )
        >>> # summarize_and_chart handles reduceRegions + grouped bar:
        >>> df, fig = cl.summarize_and_chart(
        ...     lc_mode, fires,
        ...     feature_label="Incid_Name",
        ...     title="Land Cover — 5 Largest Fires",
        ...     stacked=True, width=800,
        ... )
        >>> fig.show()
    """
    fig = go.Figure()
    feature_labels = [str(v) for v in df.index]
    # One trace per class column so Plotly's barmode can group/stack them.
    for i, col in enumerate(df.columns):
        color = None
        if colors and i < len(colors) and colors[i] is not None:
            color = _ensure_hex_color(colors[i])
        fig.add_trace(
            go.Bar(
                name=str(col),
                x=feature_labels,
                y=df[col].values,
                marker_color=color,
            )
        )
    fig.update_layout(
        # Accept both "stacked_bar" (canonical) and "stacked" (shorthand).
        barmode="stack" if chart_type in ("stacked_bar", "stacked") else "group",
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(title="Feature", tickangle=45, automargin=True),
        yaxis=dict(title=y_label or "", automargin=True),
        legend=_legend_kwargs(legend_position),
        plot_bgcolor=DEFAULT_PLOT_BGCOLOR,
        paper_bgcolor=DEFAULT_PLOT_BGCOLOR,
        font=dict(family=DEFAULT_PLOT_FONT),
        width=width,
        height=height,
        margin=dict(l=35, r=25, b=80, t=50, pad=5),
        hovermode="x unified",
    )
    _themes.apply_plotly_theme(fig, "dark")
    return fig
def chart_sankey_d3(
    sankey_df,
    class_names,
    class_palette,
    transition_periods,
    title="Class Transitions",
    width=DEFAULT_CHART_WIDTH,
    height=DEFAULT_CHART_HEIGHT,
    node_thickness=20,
    node_pad=15,
    opacity=0.9,
    theme="dark",
    bg_color=None,
    font_color=None,
    hide_toolbar=False,
):
    """Create a D3 Sankey diagram directly from transition data — no Plotly.

    Builds a self-contained HTML string with native SVG ``linearGradient``
    elements so each link fades from its source node color to its target
    node color. Uses ``d3-sankey`` for layout.

    This is the preferred rendering path for Sankey charts. Unlike
    :func:`chart_sankey` (which builds a Plotly figure that must be
    post-processed by :func:`sankey_to_html` for gradients), this function
    goes straight from the raw ``sankey_df`` to D3 HTML.

    Args:
        sankey_df (pandas.DataFrame): Output of :func:`prepare_sankey_data`.
            Columns: ``source``, ``target``, ``value``, ``source_name``,
            ``target_name``, ``source_color``, ``target_color``.
        class_names (list): List of class names.
        class_palette (list): List of hex color strings.
        transition_periods (list): Period list (for node labeling).
        title (str, optional): Chart title. Defaults to
            ``"Class Transitions"``.
        width (int, optional): Chart width in pixels.
        height (int, optional): Chart height in pixels.
        node_thickness (int, optional): Sankey node bar thickness.
        node_pad (int, optional): Padding between Sankey nodes.
        opacity (float, optional): Link opacity (0-1). Defaults to 0.9.
        theme (str, optional): Theme preset. Defaults to ``"dark"``.
        bg_color (str, optional): Background color override.
        font_color (str, optional): Font color override.
        hide_toolbar (bool, optional): Hide the download button.

    Returns:
        str: Self-contained HTML string with embedded D3 Sankey chart.

    Examples:
        >>> sankey_df, matrix_dict = cl.prepare_sankey_data(
        ...     lcms.select(['Land_Use']), 'Land_Use',
        ...     transition_periods=[1990, 2005, 2023],
        ...     class_info=info['class_info'], geometry=study_area,
        ... )
        >>> html = cl.chart_sankey_d3(
        ...     sankey_df, info['class_info']['Land_Use']['class_names'],
        ...     info['class_info']['Land_Use']['class_palette'],
        ...     transition_periods=[1990, 2005, 2023],
        ... )
    """
    import json as _json

    _t = _themes.get_theme(theme, bg_color=bg_color, font_color=font_color)
    # Nothing to draw: return a minimal themed page instead of raising.
    if sankey_df.empty:
        return f"<html><body style='background:{_t.bg_hex};color:{_t.text_hex}'><p>No transitions found</p></body></html>"

    # Build node labels and hex colors for every (period, class) slot.
    # Node index layout: period-major, i.e. index = p_idx * n_classes + c_idx.
    labels = []
    node_colors_hex = []
    for p in transition_periods:
        p_label = _format_period(p)
        for i, name in enumerate(class_names):
            labels.append(f"{p_label} {name}")
            node_colors_hex.append(
                _ensure_hex_color(class_palette[i]) if i < len(class_palette) else "#888888"
            )

    # Collect nodes that carry at least one non-zero flow, then remap to a
    # dense 0..N-1 index space (d3-sankey expects contiguous indices and
    # orphan nodes would render as empty bars).
    used_indices = set()
    for _, row in sankey_df.iterrows():
        if row["value"] > 0:
            used_indices.add(int(row["source"]))
            used_indices.add(int(row["target"]))
    old_to_new = {}
    new_idx = 0
    for old_idx in range(len(labels)):
        if old_idx in used_indices:
            old_to_new[old_idx] = new_idx
            new_idx += 1

    d3_data = {
        "nodes": [
            {"name": labels[i], "color": node_colors_hex[i]}
            for i in range(len(labels))
            if i in used_indices
        ],
        "links": [
            {
                "source": old_to_new[int(row["source"])],
                "target": old_to_new[int(row["target"])],
                "value": float(row["value"]),
                # Per-link colors drive the SVG gradient endpoints.
                "sourceColor": _ensure_hex_color(row.get("source_color", "#888")),
                "targetColor": _ensure_hex_color(row.get("target_color", "#888")),
            }
            for _, row in sankey_df.iterrows()
            if row["value"] > 0
            and int(row["source"]) in old_to_new
            and int(row["target"]) in old_to_new
        ],
    }
    d3_config = {
        "title": title,
        "width": width,
        "height": height,
        "nodeWidth": node_thickness,
        "nodePadding": node_pad,
        "opacity": opacity,
        "bgColor": _t.bg_hex,
        "textColor": _t.text_hex,
    }

    # The template carries two placeholders that are replaced with JSON.
    html = _render_d3_sankey(_t)
    result = html.replace(
        "__D3_DATA_JSON__", _json.dumps(d3_data)
    ).replace(
        "__D3_CONFIG_JSON__", _json.dumps(d3_config)
    )
    if hide_toolbar:
        # Keep the toolbar in the DOM (its JS still binds) but hide it.
        result = result.replace(
            '<div id="toolbar">',
            '<div id="toolbar" style="display:none">',
        )
    return result
[docs] def sankey_iframe(sankey_html, width=None, height=None): """Wrap sankey D3 HTML in an iframe for Jupyter notebook display. Jupyter sanitizes ``<script>`` tags in ``display(HTML(...))``, so D3 sankey charts must be embedded in an iframe. Uses a ``data:text/html;base64`` src for maximum compatibility across Jupyter environments (classic notebook, JupyterLab, VS Code). Args: sankey_html (str): Full HTML string from :func:`chart_sankey_d3` or ``summarize_and_chart(chart_type='sankey')``. width (int, optional): Iframe width in pixels. Auto-detected from the HTML when ``None``. height (int, optional): Iframe height in pixels. Auto-detected from the HTML when ``None``. Returns: str: HTML ``<iframe>`` element suitable for ``display(HTML(...))``. Example: >>> from IPython.display import HTML, display >>> display(HTML(cl.sankey_iframe(sankey_html))) """ import base64, re if width is None: m = re.search(r'"width"\s*:\s*(\d+)', sankey_html) width = int(m.group(1)) + 50 if m else 900 if height is None: m = re.search(r'"height"\s*:\s*(\d+)', sankey_html) height = int(m.group(1)) + 80 if m else 650 b64 = base64.b64encode(sankey_html.encode("utf-8")).decode("ascii") return ( f'<iframe src="data:text/html;base64,{b64}" ' f'style="width:{width}px;height:{height}px;border:none;overflow:hidden;">' f'</iframe>' )
########################################################################### # Convenience function ###########################################################################
[docs] def summarize_and_chart( ee_obj, geometry, band_names=None, reducer=None, scale=30, crs=None, transform=None, tile_scale=4, area_format="Percentage", x_axis_property="system:time_start", date_format="YYYY", title=None, chart_type=None, sankey=False, transition_periods=None, sankey_band_name=None, min_percentage=0.2, palette=None, feature_label=None, width=DEFAULT_CHART_WIDTH, height=DEFAULT_CHART_HEIGHT, opacity=0.9, legend_position="right", columns=2, include_masked_area=True, stacked=None, # deprecated — use chart_type instead thematic_band_name=None, line_width=2, marker_size=5, class_visible=None, max_x_tick_labels=10, max_y_tick_labels=None, ): """ Run zonal statistics and produce a chart in one call. Orchestrates :func:`zonal_stats` (or :func:`prepare_sankey_data`) and the appropriate chart function. The chart type is chosen automatically: * **ee.ImageCollection** -> **line chart** (default ``"line+markers"``). * **ee.Image** -> **bar chart** (default ``"bar"``). * **chart_type="donut"** -> **donut chart** (Image + thematic only). * **chart_type="scatter"** -> **scatter plot** (Image + FeatureCollection only; uses 2 continuous bands as x/y axes, optionally coloured by *thematic_band_name*). * **chart_type="sankey"** -> **Sankey transition diagram**. * **feature_label** + ``ee.FeatureCollection`` + ``ee.Image`` -> **grouped bar** or **per-feature donut** chart. * **feature_label** + ``ee.FeatureCollection`` + ``ee.ImageCollection`` -> **per-feature time series subplots**. Args: ee_obj: ``ee.Image`` or ``ee.ImageCollection``. geometry: ``ee.Geometry``, ``ee.Feature``, or ``ee.FeatureCollection``. band_names (list, optional): Bands to include. reducer (ee.Reducer, optional): Override the auto-selected reducer. scale (int, optional): Pixel scale in meters. crs (str, optional): CRS string. transform (list, optional): Affine transform. tile_scale (int, optional): Tile scale for parallelism. area_format (str, optional): Area unit for thematic data. 
x_axis_property (str, optional): Property for x-axis labels. date_format (str, optional): Date format string. title (str, optional): Chart title. Auto-generated if None. chart_type (str, optional): Chart type. One of ``"bar"``, ``"stacked_bar"``, ``"donut"`` (Image + thematic only), ``"scatter"`` (Image + FeatureCollection only), ``"sankey"`` (ImageCollection + thematic, requires ``transition_periods``), ``"line"``, ``"stacked_line"``, ``"line+markers"`` (default for ImageCollection), or ``"stacked_line+markers"``. Defaults to ``"bar"`` for single ``ee.Image``, ``"line+markers"`` for ``ee.ImageCollection``. stacked (bool, optional): **Deprecated** — use ``chart_type`` instead. When ``True``, prepends ``"stacked_"`` to ``chart_type``. Defaults to ``None``. sankey (bool, optional): Deprecated — use ``chart_type='sankey'`` instead. Still accepted for backward compatibility. transition_periods (list, optional): Period list for Sankey. sankey_band_name (str, optional): Band for Sankey analysis. min_percentage (float, optional): Minimum percentage for Sankey flows. palette (list, optional): Hex color strings for each series/band. Overrides auto-detected class palette when provided. feature_label (str, optional): Property name to use as row labels when the geometry is a multi-feature ``ee.FeatureCollection``. Triggers the ``reduceRegions`` path. For ``ee.Image`` input produces a grouped bar chart; for ``ee.ImageCollection`` input produces per-feature time series subplots. width (int, optional): Chart width in pixels (per cell for multi-feature subplots). height (int, optional): Chart height in pixels (per cell for multi-feature subplots). opacity (float, optional): Opacity for Sankey nodes and links (0-1). Defaults to 0.9. legend_position (dict or str, optional): Plotly legend layout dict for non-Sankey charts (e.g. ``{"orientation": "h", "x": 0.5, "y": -0.1}``), or ``"right"``/``None`` for the Plotly default. 
columns (int, optional): Number of subplot columns for multi-feature time series. Total width/height scale to ``n_cols * width`` / ``n_rows * height``. Defaults to 2. include_masked_area (bool, optional): When ``True`` (default) and using the histogram reducer, unmasked pixels with value 0 are included so percentages are relative to the total area, not just the unmasked portion. The sentinel class is removed from results. thematic_band_name (str, optional): For ``chart_type="scatter"`` only. Name of a thematic band in the image whose mode value per feature is used to colour each scatter point. The image must carry ``{band}_class_values``, ``{band}_class_names``, and ``{band}_class_palette`` properties for the colours and legend entries. Defaults to ``None`` (single-colour points). line_width (int or float, optional): Line width in pixels for time series traces. Defaults to ``2``. marker_size (int or float, optional): Marker diameter in pixels for time series traces. Defaults to ``5``. class_visible (dict, optional): Per-class visibility control. Maps class names to booleans. Classes set to ``False`` are toggled off in the chart legend (set to ``"legendonly"``). The traces remain in the figure — users can click the legend to re-enable them. Useful for hiding background, no-data, or stable classes by default. Works for all chart paths including single-geometry, multi-feature time series subplots, and multi-feature bar/donut charts. Example:: class_visible={ "Non-Processing Area Mask": False, "Stable": False, "Background": False, } When ``None`` (default), all classes are visible. max_x_tick_labels (int, optional): Maximum number of x-axis tick labels. When the data has more x values than this, tick labels are thinned to every 2nd, 5th, 10th, etc. value. Defaults to ``10``. Set to ``None`` or ``0`` to show all. max_y_tick_labels (int, optional): Maximum number of y-axis tick labels. Passed as Plotly's ``nticks``. Defaults to ``None`` (Plotly automatic). 
Returns: tuple: Depends on chart type: * **Standard (single geometry):** ``(DataFrame, Figure)`` * **Sankey:** ``(sankey_df, sankey_html, matrix_dict)`` where ``sankey_html`` is a D3 HTML string (display with ``display(HTML(cl.sankey_iframe(sankey_html)))``), and ``matrix_dict`` is ``{period_label: DataFrame}`` * **Multi-feature + ee.Image (bar/donut):** ``(DataFrame, Figure)`` * **Multi-feature + ee.ImageCollection:** ``(dict, Figure)`` where ``dict`` is ``{feature_name: DataFrame}`` * **Scatter:** ``(DataFrame, Figure)`` where the DataFrame has columns for the two bands (and optionally the thematic band) Examples: Stacked time series of thematic land cover (auto-detects class properties from the image collection): >>> import geeViz.geeView as gv >>> from geeViz.outputLib import charts as cl >>> ee = gv.ee >>> study_area = ee.Geometry.Polygon( ... [[[-106, 39.5], [-105, 39.5], [-105, 40.5], [-106, 40.5]]] ... ) >>> lcms = ee.ImageCollection("USFS/GTAC/LCMS/v2024-10") >>> df, fig = cl.summarize_and_chart( ... lcms.select(['Land_Cover']), ... study_area, ... title='LCMS Land Cover', ... stacked=True, ... ) >>> print(df.to_markdown()) >>> fig.write_html("lcms_land_cover.html", include_plotlyjs="cdn") Sankey transition diagram with D3 gradient-colored links: >>> df, sankey_html, matrix = cl.summarize_and_chart( ... lcms.select(['Land_Use']), ... study_area, ... chart_type='sankey', ... transition_periods=[1990, 2000, 2024], ... sankey_band_name='Land_Use', ... min_percentage=0.5, ... ) >>> # In notebooks: display(HTML(cl.sankey_iframe(sankey_html))) >>> # Save to file: >>> cl.save_chart_html(sankey_html, "land_use_transitions.html") Bar chart for a single image at a point (use ``ee.Reducer.first()``): >>> nlcd = ee.Image("USGS/NLCD_RELEASES/2021_REL/NLCD/2021") >>> point = ee.Geometry.Point([-104.99, 39.74]) >>> df, fig = cl.summarize_and_chart( ... nlcd, ... point, ... reducer=ee.Reducer.first(), ... scale=30, ... title='NLCD Land Cover', ... 
) Continuous time series (non-thematic bands auto-select ``ee.Reducer.mean()``): >>> import geeViz.getImagesLib as gil >>> composites = gil.getLandsatWrapper( ... study_area, 2000, 2024 ... )['composites'] >>> df, fig = cl.summarize_and_chart( ... composites, ... study_area, ... band_names=['nir', 'swir1', 'swir2'], ... title='Spectral Band Means', ... palette=['D0D', '0DD', 'DD0'], ... ) Grouped bar chart comparing multiple features (uses reduceRegions internally): >>> fires = ee.FeatureCollection( ... "USFS/GTAC/MTBS/burned_area_boundaries/v1" ... ) >>> top5 = fires.sort("BurnBndAc", False).limit(5) >>> lc_mode = lcms.select(["Land_Cover"]).mode().set( ... lcms.first().toDictionary() ... ) >>> df, fig = cl.summarize_and_chart( ... lc_mode, ... top5, ... feature_label="Incid_Name", ... title="Land Cover — 5 Largest MTBS Fires", ... stacked=True, ... width=800, ... ) Thematic data without class properties — force frequencyHistogram or set properties on-the-fly: >>> lcpri = ee.ImageCollection( ... "projects/sat-io/open-datasets/LCMAP/LCPRI" ... ).select(['b1'], ['LC']) >>> # Force thematic (class values used as labels): >>> df, fig = cl.summarize_and_chart( ... lcpri, ... study_area, ... reducer=ee.Reducer.frequencyHistogram(), ... title='LCMAP LC Primary', ... ) >>> # Or set properties for proper names and colors: >>> lcpri_named = lcpri.map(lambda img: img.set({ ... 'LC_class_values': list(range(1, 10)), ... 'LC_class_names': ['Developed', 'Cropland', 'Grass/Shrub', ... 'Tree Cover', 'Water', 'Wetlands', 'Ice/Snow', ... 'Barren', 'Class Change'], ... 'LC_class_palette': ['E60000', 'A87000', 'E3E3C2', '1D6330', ... '476BA1', 'BAD9EB', 'FFFFFF', 'B3B0A3', 'A201FF'], ... })) >>> df, fig = cl.summarize_and_chart( ... lcpri_named, study_area, stacked=True, ... ) Switch area format to hectares or acres: >>> df_ha, fig_ha = cl.summarize_and_chart( ... lcms.select(['Land_Cover']), ... study_area, ... area_format='Hectares', ... title='LCMS Land Cover (Hectares)', ... 
) """ # filterBounds only applies to ImageCollections, not single Images if isinstance(ee_obj, ee.ImageCollection): ee_obj = ee_obj.filterBounds(geometry) obj_info = get_obj_info(ee_obj, band_names) class_info = obj_info["class_info"] if obj_info["is_thematic"]: y_label = AREA_FORMAT_DICT.get(area_format, {}).get("label", area_format) else: # Auto-derive y_label from reducer for continuous data _reducer_labels = { "mean": "Mean", "median": "Median", "mode": "Mode", "sum": "Sum", "min": "Min", "max": "Max", "first": "Value", "stddev": "Std Dev", "count": "Count", "variance": "Variance", } y_label = None if reducer is not None: try: r_str = str(reducer.serialize()).lower() for key, label in _reducer_labels.items(): if key in r_str: y_label = label break except Exception: pass if y_label is None: # Default reducer is mean for continuous data y_label = "Mean" # --- Resolve chart_type --- # Backward compat: if old `stacked=True` was passed, merge into chart_type if stacked is not None and stacked and chart_type is None: chart_type = "stacked_line+markers" elif stacked is not None and stacked and chart_type is not None: # stacked=True + explicit chart_type → prepend stacked_ if not already ct = str(chart_type) if not ct.startswith("stacked_"): chart_type = f"stacked_{ct}" # Default chart_type based on object type if chart_type is None: if obj_info["obj_type"] == "ImageCollection": chart_type = "line+markers" else: chart_type = "bar" # Donut validation — Image-only and thematic-only if str(chart_type).lower().strip() == "donut": if obj_info["obj_type"] == "ImageCollection": raise ValueError( "chart_type='donut' is only supported for ee.Image inputs, " "not ee.ImageCollection. Use chart_type='bar', 'stacked_bar', 'line', 'line+markers', 'stacked_line' or 'stacked_line+markers' for " "ImageCollections." 
) if not obj_info.get("is_thematic") and not class_info: raise ValueError( "chart_type='donut' is only supported for thematic " "(categorical) data with class names and palette properties. " "Use chart_type='bar', 'stacked_bar', 'line', 'line+markers', 'stacked_line' or 'stacked_line+markers' for continuous data." ) # Scatter validation — Image + FeatureCollection only if str(chart_type).lower().strip() == "scatter": if obj_info["obj_type"] == "ImageCollection": raise ValueError( "chart_type='scatter' is only supported for ee.Image inputs, " "not ee.ImageCollection." ) geo_type_check, _ = detect_geometry_type(geometry) if geo_type_check != "multi": raise ValueError( "chart_type='scatter' requires a multi-feature " "ee.FeatureCollection as the geometry input (one point per " "feature). Pass a FeatureCollection with multiple features." ) # Sankey path — chart_type='sankey' (preferred) or legacy sankey=True if str(chart_type).lower().strip() == "sankey": sankey = True if sankey and obj_info["obj_type"] == "ImageCollection" and class_info: bn = sankey_band_name or obj_info["band_names"][0] if transition_periods is None: raise ValueError("transition_periods is required for Sankey charts") if title is None: title = f"{bn} Class Transitions" sankey_df, matrix_dict = prepare_sankey_data( ee_obj, bn, transition_periods, class_info, geometry, scale=scale, crs=crs, transform=transform, tile_scale=tile_scale, area_format=area_format, min_percentage=min_percentage, ) info = class_info.get(bn, {}) sankey_html = chart_sankey_d3( sankey_df, class_names=info.get("class_names", []), class_palette=info.get("class_palette", []), transition_periods=transition_periods, title=title, width=width, height=height, opacity=opacity, ) return (sankey_df, sankey_html, matrix_dict) # Multi-feature path: reduceRegions geo_type, _ = detect_geometry_type(geometry) # --- Scatter path: Image + FeatureCollection + 2 bands --- # Handled before feature_label gate since scatter works without labels. 
if geo_type == "multi" and str(chart_type).lower().strip() == "scatter": # Resolve the two bands to plot all_bands = obj_info["band_names"] if band_names and len(band_names) >= 2: x_band, y_band = band_names[0], band_names[1] elif len(all_bands) >= 2: x_band, y_band = all_bands[0], all_bands[1] else: raise ValueError( "chart_type='scatter' requires at least 2 bands. " f"Image only has: {all_bands}" ) # Use mean reducer for scatter (continuous per-feature values) _scatter_reducer = reducer if reducer is not None else ee.Reducer.first() fc = ee.FeatureCollection(geometry) # Reduce continuous bands continuous_img = ee.Image(ee_obj).select([x_band, y_band]) reduced = continuous_img.reduceRegions( collection=fc, reducer=_scatter_reducer, scale=scale, crs=crs, crsTransform=transform, tileScale=tile_scale, ) # If thematic band requested, reduce it separately with mode() # and join the result onto the continuous FC if thematic_band_name: thematic_img = ee.Image(ee_obj).select([thematic_band_name]) reduced_thematic = thematic_img.reduceRegions( collection=fc, reducer=ee.Reducer.mode(), scale=scale, crs=crs, crsTransform=transform, tileScale=tile_scale, ) # Add the mode column to each feature via zip reduced_list = reduced.toList(reduced.size()) thematic_list = reduced_thematic.toList(reduced_thematic.size()) def _merge(i): i = ee.Number(i).int() f = ee.Feature(reduced_list.get(i)) t = ee.Feature(thematic_list.get(i)) return f.set(thematic_band_name, t.get("mode")) reduced = ee.FeatureCollection( ee.List.sequence(0, reduced.size().subtract(1)).map(_merge) ) # Convert to DataFrame import geeViz.gee2Pandas as g2p scatter_df = g2p.robust_featureCollection_to_df(reduced) # Set index to feature label if available if feature_label and feature_label in scatter_df.columns: scatter_df = scatter_df.set_index(feature_label) # Ensure the two band columns exist if x_band not in scatter_df.columns or y_band not in scatter_df.columns: raise ValueError( f"Reduced DataFrame missing expected 
band columns. " f"Expected '{x_band}' and '{y_band}', got: {list(scatter_df.columns)}" ) if title is None: title = f"{y_band} vs {x_band}" # Resolve thematic class info for coloring _thematic_col = None _class_names = None _class_palette = None _class_values = None if thematic_band_name: if thematic_band_name in scatter_df.columns: _thematic_col = thematic_band_name # Get class metadata from the image if _thematic_col and class_info and thematic_band_name in class_info: ci = class_info[thematic_band_name] _class_names = ci.get("class_names", []) _class_palette = ci.get("class_palette", []) _class_values = ci.get("class_values", []) elif _thematic_col: # Try reading from image properties directly try: props = ee.Image(ee_obj).getInfo().get("properties", {}) _class_values = props.get(f"{thematic_band_name}_class_values") _class_names = props.get(f"{thematic_band_name}_class_names") _class_palette = props.get(f"{thematic_band_name}_class_palette") except Exception: pass # Build output columns out_cols = [x_band, y_band] if _thematic_col and _thematic_col in scatter_df.columns: out_cols.append(_thematic_col) fig = chart_scatter( scatter_df, x_band=x_band, y_band=y_band, feature_label=feature_label, title=title, width=width, height=height, legend_position=legend_position, opacity=opacity, thematic_col=_thematic_col, class_names=_class_names, class_palette=_class_palette, class_values=_class_values, ) return (scatter_df[[c for c in out_cols if c in scatter_df.columns]], _set_download_filename(fig)) # Auto-detect feature_label for multi-feature FeatureCollections if geo_type == "multi" and not feature_label: feature_label = _detect_feature_label(geometry) if geo_type == "multi" and feature_label: df = zonal_stats( ee_obj, geometry, band_names=band_names, reducer=reducer, scale=scale, crs=crs, transform=transform, tile_scale=tile_scale, area_format=area_format, x_axis_property=x_axis_property, date_format=date_format, include_masked_area=include_masked_area, ) # Build 
color list from class_info (shared by both sub-paths) colors = palette if colors is None and class_info: color_lookup = {} for bn in obj_info["band_names"]: info = class_info.get(bn, {}) cn = info.get("class_names", []) cp = info.get("class_palette", []) for i, name in enumerate(cn): if i < len(cp): color_lookup[name] = cp[i] if color_lookup: # Will be applied per-column below pass else: color_lookup = {} else: color_lookup = {} # --- ImageCollection + multi-feature: per-feature time series --- if obj_info["obj_type"] == "ImageCollection": # Recover x_axis_labels from prepare_for_reduction # (zonal_stats already called it internally; we re-derive from column names) x_axis_labels = [] seen = set() for col in df.columns: if SPLIT_STR in col: prefix = col.split(SPLIT_STR)[0] if prefix not in seen: seen.add(prefix) x_axis_labels.append(prefix) per_feature_dfs = _pivot_multi_feature_timeseries( df, x_axis_labels, obj_info, feature_label, SPLIT_STR ) # Build color list matching column order of per-feature DataFrames if per_feature_dfs: sample_cols = list(next(iter(per_feature_dfs.values())).columns) if colors is None and color_lookup: colors = [_ensure_hex_color(color_lookup.get(c)) if color_lookup.get(c) else None for c in sample_cols] elif colors is None: colors = None if title is None: title = "Time Series by Feature" # Pass per-cell width/height — chart_multi_feature_timeseries # scales to n_cols * width, n_rows * height internally fig = chart_multi_feature_timeseries( per_feature_dfs, colors=colors, chart_type=chart_type, title=title, x_label=( "Date" if x_axis_property == "system:time_start" else (x_axis_property.replace("_", " ").title() if x_axis_property != "year" else "Year") ), y_label=y_label, width=width, height=height, columns=columns, legend_position=legend_position, line_width=line_width, marker_size=marker_size, max_x_tick_labels=max_x_tick_labels, max_y_tick_labels=max_y_tick_labels, ) # Apply class_visible to multi-feature time series subplots if 
class_visible is not None and isinstance(class_visible, dict): hidden = {name for name, vis in class_visible.items() if not vis} if hidden: for trace in fig.data: trace_name = trace.name or "" if (trace_name in hidden or any(trace_name.endswith(SPLIT_STR + h) for h in hidden) or any(trace_name.replace(SPLIT_STR, " ").strip() in hidden for _ in [0])): trace.visible = "legendonly" return (per_feature_dfs, _set_download_filename(fig)) # --- ee.Image + multi-feature: grouped bar chart (existing behavior) --- # Set index to feature label column if feature_label in df.columns: df = df.set_index(feature_label) # Identify class columns from class_info class_cols = [] if class_info: for bn in obj_info["band_names"]: info = class_info.get(bn, {}) for name in info.get("class_names", []): col_name = f"{bn}{SPLIT_STR}{name}" if len(obj_info["band_names"]) > 1 else name if col_name in df.columns: class_cols.append(col_name) # Fallback: prefer columns matching image band names (avoids picking # up feature properties like ALAND, AWATER from reduceRegions output) if not class_cols: band_col_set = set(obj_info["band_names"]) class_cols = [c for c in df.columns if c in band_col_set] # Last resort: keep all numeric columns that aren't geometry/system properties if not class_cols: class_cols = [ c for c in df.columns if pandas.api.types.is_numeric_dtype(df[c]) and not c.startswith("geometry") and c not in ("system:index",) ] chart_df = df[class_cols].fillna(0) # Build colors for grouped bar if colors is None and class_info: bar_color_lookup = {} for bn in obj_info["band_names"]: info = class_info.get(bn, {}) cn = info.get("class_names", []) cp = info.get("class_palette", []) for i, name in enumerate(cn): col_name = f"{bn}{SPLIT_STR}{name}" if len(obj_info["band_names"]) > 1 else name if i < len(cp): bar_color_lookup[col_name] = cp[i] if bar_color_lookup: colors = [bar_color_lookup.get(col) for col in chart_df.columns] if title is None: title = "Zonal Summary by Feature" if 
str(chart_type).lower().strip() == "donut": fig = chart_donut_multi_feature( chart_df, colors=colors, title=title, width=width, height=height, columns=columns, legend_position=legend_position, ) else: fig = chart_grouped_bar( chart_df, colors=colors, title=title, y_label=y_label, chart_type=chart_type, width=width, height=height, legend_position=legend_position, ) # Apply '%' ticksuffix and max_y_tick_labels for multi-feature bar/donut y_kw = {} if y_label and "%" in y_label: y_kw["ticksuffix"] = "%" if max_y_tick_labels is not None and max_y_tick_labels > 0: y_kw["nticks"] = max_y_tick_labels if y_kw: fig.update_yaxes(**y_kw) # Apply class_visible to multi-feature bar/donut charts if class_visible is not None and isinstance(class_visible, dict): hidden = {name for name, vis in class_visible.items() if not vis} if hidden: for trace in fig.data: trace_name = trace.name or "" if (trace_name in hidden or any(trace_name.endswith(SPLIT_STR + h) for h in hidden) or any(trace_name.replace(SPLIT_STR, " ").strip() in hidden for _ in [0])): trace.visible = "legendonly" return (chart_df, _set_download_filename(fig)) # Standard single-region zonal stats path # Safety fallback: dissolve multi-feature FCs without a label (shouldn't # happen after auto-detection above, but guards against edge cases). if geo_type == "multi" and not feature_label: geometry = geometry.geometry() df = zonal_stats( ee_obj, geometry, band_names=band_names, reducer=reducer, scale=scale, crs=crs, transform=transform, tile_scale=tile_scale, area_format=area_format, x_axis_property=x_axis_property, date_format=date_format, include_masked_area=include_masked_area, ) df_full = df # Extract colors from class info (unless caller provided palette). # Build the color list to match actual DataFrame column order so that # multi-band thematic charts (e.g. Change + Land_Cover + Land_Use) # assign the correct color to each class. 
# --- Single-region path: resolve colors, then dispatch on chart type -----
colors = palette
if colors is None and class_info:
    # Build a lookup: column_name -> hex color. Column names carry a band
    # prefix (via SPLIT_STR) only when the image has more than one band,
    # matching how zonal_stats names its output columns.
    color_lookup = {}
    for bn in obj_info["band_names"]:
        info = class_info.get(bn, {})
        class_names = info.get("class_names", [])
        class_palette = info.get("class_palette", [])
        for i, name in enumerate(class_names):
            col_name = f"{bn}{SPLIT_STR}{name}" if len(obj_info["band_names"]) > 1 else name
            if i < len(class_palette):
                color_lookup[col_name] = class_palette[i]
    # Map each DataFrame column to its color (fall back to None)
    if color_lookup:
        colors = [color_lookup.get(col) for col in df.columns]
# Pick chart type: ImageCollections get a time series; single ee.Images get
# a bar or donut class-distribution chart.
if obj_info["obj_type"] == "ImageCollection":
    if title is None:
        title = "Zonal Summary"
    fig = chart_time_series(
        df,
        colors=colors,
        chart_type=chart_type,
        title=title,
        # Human-readable x-axis label derived from the reduction property:
        # system:time_start -> "Date", "year" -> "Year", else title-cased.
        x_label=(
            "Date" if x_axis_property == "system:time_start"
            else (x_axis_property.replace("_", " ").title() if x_axis_property != "year" else "Year")
        ),
        y_label=y_label,
        width=width,
        height=height,
        legend_position=legend_position,
        line_width=line_width,
        marker_size=marker_size,
        max_x_tick_labels=max_x_tick_labels,
        max_y_tick_labels=max_y_tick_labels,
    )
else:
    if title is None:
        title = "Class Distribution"
    if str(chart_type).lower().strip() == "donut":
        fig = chart_donut(
            df,
            colors=colors,
            title=title,
            width=width,
            height=height,
            legend_position=legend_position,
        )
    else:
        fig = chart_bar(
            df,
            colors=colors,
            title=title,
            y_label=y_label,
            chart_type=chart_type,
            width=width,
            height=height,
            legend_position=legend_position,
        )
# For thematic data with first() reducer, map numeric class values
# to class names on the y-axis. Other reducers (histogram, mean, etc.)
# produce area/count values that need standard numeric y-axis ticks.
# Heuristic detection of a first() reducer: inspect the serialized reducer
# for the substring "first". NOTE(review): substring match could also hit
# reducers whose serialization merely contains "first" — confirm acceptable.
_is_first_reducer = (
    reducer is not None
    and hasattr(reducer, "getInfo")
    and "first" in str(reducer.serialize()).lower()
)
if _is_first_reducer and obj_info["is_thematic"] and class_info:
    # Relabel the y-axis using the first band that has a complete,
    # equal-length class_values/class_names pair.
    for bn in obj_info["band_names"]:
        info = class_info.get(bn, {})
        vals = info.get("class_values", [])
        names = info.get("class_names", [])
        if vals and names and len(vals) == len(names):
            fig.update_yaxes(
                tickvals=vals,
                ticktext=names,
            )
            break
# Apply '%' ticksuffix and max_y_tick_labels for bar/donut charts
# (time series charts handle this internally)
if obj_info["obj_type"] != "ImageCollection":
    y_kw = {}
    if y_label and "%" in y_label:
        y_kw["ticksuffix"] = "%"
    if max_y_tick_labels is not None and max_y_tick_labels > 0:
        y_kw["nticks"] = max_y_tick_labels
    if y_kw:
        fig.update_yaxes(**y_kw)
# Apply class_visible: toggle trace visibility in the figure.
# Traces remain in the chart (user can click legend to re-enable).
if class_visible is not None and isinstance(class_visible, dict):
    hidden = {name for name, vis in class_visible.items() if not vis}
    if hidden:
        for trace in fig.data:
            trace_name = trace.name or ""
            # Check if trace name matches a hidden class (with or without band prefix)
            # NOTE(review): the `for _ in [0]` generator evaluates its
            # condition exactly once — equivalent to a plain `in` test.
            if (trace_name in hidden
                    or any(trace_name.endswith(SPLIT_STR + h) for h in hidden)
                    or any(trace_name.replace(SPLIT_STR, " ").strip() in hidden for _ in [0])):
                trace.visible = "legendonly"
# Return the full (un-subset) DataFrame plus the figure, with the figure's
# download filename set from the chart title.
return (df_full, _set_download_filename(fig))