# Source code for growthcurves.plot

"""
Plotting functions for growth curve analysis using Plotly.

This module provides modular functions for creating and annotating growth curve plots.
"""

from typing import Any, Dict, Optional, Tuple

import numpy as np
import plotly.graph_objects as go

from .models import (
    evaluate_parametric_model,
    get_all_parametric_models,
    spline_from_params,
)


def create_base_plot(
    t: np.ndarray,
    N: np.ndarray,
    scale: str = "linear",
    xlabel: str = "Time (hours)",
    ylabel: Optional[str] = None,
    marker_size: int = 5,
    marker_opacity: float = 0.3,
    marker_color: str = "gray",
) -> go.Figure:
    """
    Build the base scatter figure showing the raw measurements.

    Points with a non-finite time, a non-finite value or a non-positive value
    are dropped before plotting (for either scale).

    Parameters
    ----------
    t : numpy.ndarray
        Time points
    N : numpy.ndarray
        OD measurements or other growth values
    scale : str, optional
        'log' plots ln(N); any other value plots N unchanged. Only the exact
        value 'linear' pins the y-axis lower bound to 0 (default: 'linear')
    xlabel : str, optional
        X-axis label (default: 'Time (hours)')
    ylabel : str, optional
        Y-axis label. If None, 'ln(OD)' is used for log scale, 'OD' otherwise
    marker_size : int, optional
        Size of the data markers (default: 5)
    marker_opacity : float, optional
        Opacity of the data markers (default: 0.3)
    marker_color : str, optional
        Color of the data markers (default: 'gray')

    Returns
    -------
    plotly.graph_objects.Figure
        Figure holding a single markers-only trace named "Data"
    """
    times = np.asarray(t, dtype=float)
    values = np.asarray(N, dtype=float)

    # Keep only points that can be drawn (and log-transformed).
    keep = np.isfinite(times) & np.isfinite(values) & (values > 0)
    times, values = times[keep], values[keep]

    use_log = scale == "log"
    y_data = np.log(values) if use_log else values
    if ylabel is None:
        ylabel = "ln(OD)" if use_log else "OD"

    points = go.Scatter(
        x=times,
        y=y_data,
        mode="markers",
        name="Data",
        marker={
            "size": marker_size,
            "opacity": marker_opacity,
            "color": marker_color,
        },
        showlegend=False,
    )
    fig = go.Figure()
    fig.add_trace(points)

    # Linear plots start the y-axis at 0; everything else auto-ranges.
    yaxis_config = {"visible": True, "showline": True}
    if scale == "linear":
        yaxis_config["range"] = [0, None]

    fig.update_layout(
        xaxis_title=xlabel,
        yaxis_title=ylabel,
        hovermode="closest",
        template="plotly_white",
        showlegend=False,
        xaxis={"range": [0, None]},  # start the x-axis at 0 to remove the gap
        yaxis=yaxis_config,
    )
    return fig
def add_exponential_phase(
    fig: go.Figure,
    exp_start: float,
    exp_end: float,
    color: str = "lightgreen",
    opacity: float = 0.25,
    name: str = "Exponential phase",
    row: Optional[int] = None,
    col: Optional[int] = None,
) -> go.Figure:
    """
    Shade the exponential growth phase as a vertical band.

    The figure is returned untouched when either boundary is None or
    non-finite.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Plotly figure to annotate
    exp_start : float
        Start time of exponential phase
    exp_end : float
        End time of exponential phase
    color : str, optional
        Fill color of the band (default: 'lightgreen')
    opacity : float, optional
        Opacity of the band (default: 0.25)
    name : str, optional
        Legend name (default: 'Exponential phase'). Accepted but not applied
        to the drawn shape.
    row : int, optional
        Subplot row (for subplots)
    col : int, optional
        Subplot column (for subplots)

    Returns
    -------
    plotly.graph_objects.Figure
        The same figure, with the band added when the boundaries are usable
    """
    boundaries = (exp_start, exp_end)
    if any(bound is None for bound in boundaries):
        return fig
    if not all(np.isfinite(bound) for bound in boundaries):
        return fig

    # Zero-width outline, drawn below the traces so data stays visible.
    band = {
        "x0": exp_start,
        "x1": exp_end,
        "fillcolor": color,
        "opacity": opacity,
        "layer": "below",
        "line_width": 0,
    }
    fig.add_vrect(row=row, col=col, **band)
    return fig
def add_fitted_curve(
    fig: go.Figure,
    time_fit: np.ndarray,
    N_fit: np.ndarray,
    name: str = "Fitted curve",
    color: str = "blue",
    line_width: int = 5,
    window_start: Optional[float] = None,
    window_end: Optional[float] = None,
    scale: Optional[str] = "linear",
    row: Optional[int] = None,
    col: Optional[int] = None,
) -> go.Figure:
    """
    Draw a fitted curve as a line trace, optionally clipped to a time window.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Plotly figure to annotate
    time_fit : numpy.ndarray
        Time points of the fitted curve
    N_fit : numpy.ndarray
        Fitted y values (untransformed)
    name : str, optional
        Trace name (default: 'Fitted curve')
    color : str, optional
        Line color (default: 'blue')
    line_width : int, optional
        Line width (default: 5)
    window_start : float, optional
        Lower time bound; clipping is applied only when both bounds are given
    window_end : float, optional
        Upper time bound; clipping is applied only when both bounds are given
    scale : str, optional
        'log' plots ln(N_fit); any other value plots N_fit unchanged
        (default: 'linear')
    row : int, optional
        Subplot row (for subplots)
    col : int, optional
        Subplot column (for subplots)

    Returns
    -------
    plotly.graph_objects.Figure
        The same figure with the curve added, or unchanged if either input
        array is None
    """
    if any(arg is None for arg in (time_fit, N_fit)):
        return fig

    xs = np.asarray(time_fit, dtype=float)
    ys = np.asarray(N_fit, dtype=float)

    # Clip to the fitting window (bounds inclusive) when fully specified.
    has_window = window_start is not None and window_end is not None
    if has_window:
        inside = (xs >= window_start) & (xs <= window_end)
        xs, ys = xs[inside], ys[inside]

    if scale == "log":
        ys = np.log(ys)

    curve = go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        name=name,
        line={"color": color, "width": line_width},
        showlegend=False,
    )
    fig.add_trace(curve, row=row, col=col)
    return fig
def add_od_max_line(
    fig: go.Figure,
    od_max: float,
    scale: str = "linear",
    line_color: str = "black",
    line_dash: str = "dot",
    line_width: float = 2,
    line_opacity: float = 0.5,
    name: str = "ODmax",
    row: Optional[int] = None,
    col: Optional[int] = None,
) -> go.Figure:
    """
    Add horizontal line at maximum OD value.

    The figure is returned unchanged when ``od_max`` is None, non-finite, or
    (for log scale) non-positive, since ln(od_max) is undefined there.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Plotly figure to annotate
    od_max : float
        Maximum OD value (untransformed)
    scale : str, optional
        'linear' or 'log' - determines y-axis transformation
        (default: 'linear')
    line_color : str, optional
        Color of horizontal line (default: 'black')
    line_dash : str, optional
        Dash style for horizontal line (default: 'dot')
    line_width : float, optional
        Width of horizontal line (default: 2)
    line_opacity : float, optional
        Opacity of horizontal line (default: 0.5)
    name : str, optional
        Legend name (default: 'ODmax'). Currently not applied to the line.
    row : int, optional
        Subplot row (for subplots)
    col : int, optional
        Subplot column (for subplots)

    Returns
    -------
    plotly.graph_objects.Figure
        Updated figure with od_max horizontal line
    """
    if od_max is None or not np.isfinite(od_max):
        return fig

    # Transform y-value based on scale
    if scale == "log":
        # ln() of a non-positive value is -inf/NaN; skip the line instead of
        # handing plotly an undrawable coordinate.
        if od_max <= 0:
            return fig
        y_val = np.log(od_max)
    else:
        y_val = od_max

    fig.add_hline(
        y=y_val,
        line_color=line_color,
        line_dash=line_dash,
        line_width=line_width,
        opacity=line_opacity,
        row=row,
        col=col,
    )
    return fig
def add_N0_line(
    fig: go.Figure,
    N0: float,
    scale: str = "linear",
    line_color: str = "gray",
    line_dash: str = "dot",
    line_width: float = 2,
    line_opacity: float = 0.5,
    name: str = "N0",
    row: Optional[int] = None,
    col: Optional[int] = None,
) -> go.Figure:
    """
    Add horizontal line at initial OD value (N0).

    The figure is returned unchanged when ``N0`` is None, non-finite, or
    (for log scale) non-positive, since ln(N0) is undefined there.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Plotly figure to annotate
    N0 : float
        Initial OD value (untransformed)
    scale : str, optional
        'linear' or 'log' - determines y-axis transformation
        (default: 'linear')
    line_color : str, optional
        Color of horizontal line (default: 'gray')
    line_dash : str, optional
        Dash style for horizontal line (default: 'dot')
    line_width : float, optional
        Width of horizontal line (default: 2)
    line_opacity : float, optional
        Opacity of horizontal line (default: 0.5)
    name : str, optional
        Legend name (default: 'N0'). Currently not applied to the line.
    row : int, optional
        Subplot row (for subplots)
    col : int, optional
        Subplot column (for subplots)

    Returns
    -------
    plotly.graph_objects.Figure
        Updated figure with N0 horizontal line
    """
    if N0 is None or not np.isfinite(N0):
        return fig

    # Transform y-value based on scale
    if scale == "log":
        # ln() of a non-positive value is -inf/NaN; skip the line instead of
        # handing plotly an undrawable coordinate.
        if N0 <= 0:
            return fig
        y_val = np.log(N0)
    else:
        y_val = N0

    fig.add_hline(
        y=y_val,
        line_color=line_color,
        line_dash=line_dash,
        line_width=line_width,
        opacity=line_opacity,
        row=row,
        col=col,
    )
    return fig
def prepare_fitted_curve(
    fitted_model: Dict[str, Any], n_points: int = 200
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Convert a fitted model dictionary to curve data for plotting.

    Parameters
    ----------
    fitted_model : dict
        Fit result dictionary from fit_parametric() or fit_non_parametric().
        Read keys: 'model_type', 'params', and (when 'params' lacks
        'fit_t_min'/'fit_t_max') 'window_start'/'window_end'.
    n_points : int, optional
        Number of points to generate for the curve (default: 200)

    Returns
    -------
    tuple of (np.ndarray, np.ndarray) or None
        (time_points, od_values) ready for plotting. None is returned when
        the model dict is None, 'model_type' or 'params' is missing, the
        window bounds are missing or non-finite, the model type is unknown,
        or a sliding-window fit lacks 'slope'/'intercept'.
    """
    if fitted_model is None:
        return None

    model_type = fitted_model.get("model_type")
    params = fitted_model.get("params")
    if model_type is None or params is None:
        return None

    # Window bounds stored with the parameters take precedence over the
    # top-level window keys.
    if "fit_t_min" in params and "fit_t_max" in params:
        window_start = params["fit_t_min"]
        window_end = params["fit_t_max"]
    else:
        window_start = fitted_model.get("window_start")
        window_end = fitted_model.get("window_end")

    if window_start is None or window_end is None:
        return None
    # NaN/inf bounds would only produce an all-NaN time axis; treat them as
    # an invalid fit rather than returning an undrawable curve.
    if not (np.isfinite(window_start) and np.isfinite(window_end)):
        return None

    time_fit = np.linspace(window_start, window_end, n_points)

    if model_type in get_all_parametric_models():
        od_fit = evaluate_parametric_model(time_fit, model_type, params)
    elif model_type == "spline":
        # Evaluated spline is exponentiated, i.e. treated as ln(OD).
        spline = spline_from_params(params)
        od_fit = np.exp(spline(time_fit))
    elif model_type == "sliding_window":
        slope = params.get("slope")
        intercept = params.get("intercept")
        if slope is None or intercept is None:
            return None
        # Linear fit is exponentiated, i.e. treated as a fit in ln(OD) space.
        od_fit = np.exp(slope * time_fit + intercept)
    else:
        return None

    return (time_fit, od_fit)
def prepare_tangent_line(
    umax: float,
    time_umax: float,
    od_umax: float,
    fig: go.Figure,
    scale: Optional[str] = "linear",
    n_points: Optional[int] = 100,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """
    Calculate tangent line at maximum growth rate point.

    The tangent is the exponential OD(t) = od_umax * exp(umax * (t - time_umax)),
    sampled between the times at which it crosses the lowest and highest
    y-values currently present in ``fig``.

    Parameters
    ----------
    umax : float
        Maximum growth rate (μ_max)
    time_umax : float
        Time at maximum growth rate
    od_umax : float
        OD value at maximum growth rate (untransformed)
    fig : plotly.graph_objects.Figure
        Figure to extract data range from
    scale : str, optional
        'log' means the figure's y-values are ln(OD) and are exponentiated
        before use; any other value uses them as-is (default: 'linear')
    n_points : int, optional
        Number of points to generate for tangent line (default: 100)

    Returns
    -------
    tuple of (numpy.ndarray, numpy.ndarray) or None
        (time_points, od_values) for the tangent line in untransformed OD
        units. None is returned when any input is non-finite, ``umax`` is
        zero, ``od_umax`` or the derived bounds are non-positive, or the
        figure holds no finite y-values.
    """
    if not np.isfinite(umax) or not np.isfinite(time_umax) or not np.isfinite(od_umax):
        return None
    # A zero growth rate makes the crossing times below a division by zero
    # (flat tangent never reaches the baseline/plateau).
    if umax == 0:
        return None
    if od_umax <= 0:
        return None

    # Collect every finite y-value in the figure to find baseline and plateau.
    all_y_values = [
        y
        for trace in fig.data
        if trace.y is not None and len(trace.y) > 0
        for y in trace.y
        if y is not None and np.isfinite(y)
    ]
    if not all_y_values:
        return None

    if scale == "log":
        baseline_od = np.exp(min(all_y_values))
        plateau_od = np.exp(max(all_y_values))
    else:
        baseline_od = min(all_y_values)
        plateau_od = max(all_y_values)

    # Ensure baseline < od_umax < plateau (with safety margins)
    baseline_od = min(baseline_od, od_umax * 0.95)
    plateau_od = max(plateau_od, od_umax * 1.05)

    # Both bounds feed a logarithm below, so they must be strictly positive.
    if baseline_od <= 0 or plateau_od <= 0:
        return None

    # Solve OD(t) = bound for t on the tangent curve.
    t_start = time_umax + np.log(baseline_od / od_umax) / umax
    t_end = time_umax + np.log(plateau_od / od_umax) / umax

    t_tangent = np.linspace(t_start, t_end, n_points)
    od_tangent = od_umax * np.exp(umax * (t_tangent - time_umax))

    return (t_tangent, od_tangent)
def annotate_plot(
    fig: go.Figure,
    fit_result: Optional[Dict[str, Any]] = None,
    stats: Optional[Dict[str, Any]] = None,
    show_fitted_curve: bool = True,
    show_phase_boundaries: bool = True,
    show_crosshairs: bool = True,
    show_od_max_line: bool = True,
    show_n0_line: bool = True,
    show_umax_marker: bool = True,
    show_tangent: bool = True,
    scale: str = "linear",
    fitted_curve_color: str = "#8dcde0",
    fitted_curve_width: int = 5,
    row: Optional[int] = None,
    col: Optional[int] = None,
) -> go.Figure:
    """
    Add annotations to a growth curve plot.

    Annotations are added in this order: fitted curve, exponential phase
    shading, crosshairs to the umax point, ODmax line, N0 line, umax marker,
    umax tangent. Every annotation other than the fitted curve needs
    ``stats``; each one is skipped silently when its inputs are missing.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Plotly figure to annotate
    fit_result : dict, optional
        Fit result dictionary from fit_parametric() or fit_non_parametric()
    stats : dict, optional
        Statistics dictionary from extract_stats(). Read keys:
        'exp_phase_start', 'exp_phase_end', 'time_at_umax', 'od_at_umax',
        'max_od', 'N0', 'mu_max'.
    show_fitted_curve : bool, optional
        Whether to show the fitted curve (default: True)
    show_phase_boundaries : bool, optional
        Whether to show exponential phase boundaries (default: True)
    show_crosshairs : bool, optional
        Whether to show crosshairs to umax point (default: True)
    show_od_max_line : bool, optional
        Whether to show horizontal line at maximum OD (default: True)
    show_n0_line : bool, optional
        Whether to show horizontal line at initial OD (default: True)
    show_umax_marker : bool, optional
        Whether to show marker at umax point (default: True)
    show_tangent : bool, optional
        Whether to show tangent line at umax (default: True)
    scale : str, optional
        'linear' or 'log' for y-axis scale (default: 'linear')
    fitted_curve_color : str, optional
        Color of the fitted model curve (default: '#8dcde0')
    fitted_curve_width : int, optional
        Line width of the fitted model curve (default: 5)
    row : int, optional
        Subplot row for subplots
    col : int, optional
        Subplot column for subplots

    Returns
    -------
    plotly.graph_objects.Figure
        Updated figure with annotations
    """
    # Fitted curve
    if show_fitted_curve and fit_result is not None:
        fitted_curve = prepare_fitted_curve(fit_result)
        if fitted_curve is not None:
            time_fit, od_fit = fitted_curve
            fig = add_fitted_curve(
                fig,
                time_fit,
                od_fit,
                name="Fitted curve",
                color=fitted_curve_color,
                line_width=fitted_curve_width,
                scale=scale,
                row=row,
                col=col,
            )

    # Everything below is derived from the statistics dictionary.
    if stats is None:
        return fig

    # Exponential phase shading
    if show_phase_boundaries:
        exp_start = stats.get("exp_phase_start")
        exp_end = stats.get("exp_phase_end")
        if exp_start is not None and exp_end is not None:
            fig = add_exponential_phase(fig, exp_start, exp_end, row=row, col=col)

    time_umax = stats.get("time_at_umax")
    od_umax = stats.get("od_at_umax")

    # Plotted y-coordinate of the umax point; stays None when the point is
    # missing, non-finite, or cannot be log-transformed (od_umax <= 0).
    umax_y = None
    if (
        (show_crosshairs or show_umax_marker)
        and time_umax is not None
        and od_umax is not None
        and np.isfinite(time_umax)
        and np.isfinite(od_umax)
    ):
        if scale != "log":
            umax_y = od_umax
        elif od_umax > 0:
            umax_y = np.log(od_umax)

    # Crosshairs to umax point
    if show_crosshairs and umax_y is not None:
        if scale == "log":
            # In log space there is no natural zero, so the vertical line
            # ends at the lowest finite y currently in the figure. None
            # entries are skipped (np.isfinite(None) would raise).
            finite_y = [
                y
                for trace in fig.data
                if trace.y is not None and len(trace.y) > 0
                for y in trace.y
                if y is not None and np.isfinite(y)
            ]
            y_bottom = min(finite_y) if finite_y else umax_y
        else:
            y_bottom = 0

        # Vertical line
        fig.add_shape(
            type="line",
            x0=time_umax,
            y0=y_bottom,
            x1=time_umax,
            y1=umax_y,
            line=dict(color="black", dash="dot", width=2),
            opacity=0.5,
            row=row,
            col=col,
        )
        # Horizontal line
        fig.add_shape(
            type="line",
            x0=0,
            y0=umax_y,
            x1=time_umax,
            y1=umax_y,
            line=dict(color="black", dash="dot", width=2),
            opacity=0.5,
            row=row,
            col=col,
        )

    # od_max horizontal line
    if show_od_max_line:
        od_max = stats.get("max_od")
        if od_max is not None:
            fig = add_od_max_line(fig, od_max, scale=scale, row=row, col=col)

    # N0 horizontal line
    if show_n0_line:
        n0 = stats.get("N0")
        if n0 is not None:
            fig = add_N0_line(fig, n0, scale=scale, row=row, col=col)

    # umax marker point
    if show_umax_marker and umax_y is not None:
        fig.add_trace(
            go.Scatter(
                x=[time_umax],
                y=[umax_y],
                mode="markers",
                marker=dict(size=15, color="#66BB6A", symbol="circle"),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    # Tangent line at umax
    if show_tangent:
        umax = stats.get("mu_max")
        if umax is not None and time_umax is not None and od_umax is not None:
            tangent_data = prepare_tangent_line(umax, time_umax, od_umax, fig, scale)
            if tangent_data is not None:
                time_vals, od_vals = tangent_data
                # Tangent values come back in OD units; match the axis scale.
                y_vals = np.log(od_vals) if scale == "log" else od_vals
                fig.add_trace(
                    go.Scatter(
                        x=time_vals,
                        y=y_vals,
                        mode="lines",
                        line=dict(color="green", width=2, dash="dash"),
                        name="umax tangent",
                        showlegend=False,
                        hovertemplate="Tangent line at μmax<extra></extra>",
                    ),
                    row=row,
                    col=col,
                )

    return fig
def plot_derivative_metric(
    t: np.ndarray,
    N: np.ndarray,
    metric: str = "mu",
    fit_result: Optional[Dict[str, Any]] = None,
    sg_window: int = 11,
    sg_poly: int = 2,
    phase_boundaries: Optional[Tuple[float, float]] = None,
    title: Optional[str] = None,
    raw_line_width: int = 1,
    smoothed_line_width: int = 2,
    fitted_line_width: int = 2,
) -> go.Figure:
    """
    Plot either dN/dt or μ (specific growth rate) for a growth curve.

    This function generates up to three traces:

    1. Raw metric (light grey)
    2. Smoothed metric (main trace, pink/red)
    3. Model fit metric (light blue line, if fit_result provided)

    Parameters
    ----------
    t : numpy.ndarray
        Time array
    N : numpy.ndarray
        OD600 values (baseline-corrected)
    metric : str, optional
        Either "dndt" for dN/dt or "mu" for μ (default: "mu")
    fit_result : dict, optional
        Fit result dictionary from fit_parametric() or fit_non_parametric().
        If provided, the fitted model's derivative will be shown.
        Should contain 'model_type' and 'params' keys.
    sg_window : int, optional
        Savitzky-Golay window size for smoothing (default: 11)
    sg_poly : int, optional
        Savitzky-Golay polynomial order for smoothing (default: 2)
    phase_boundaries : tuple of (float, float), optional
        Tuple of (exp_start, exp_end) for exponential phase boundaries.
        If provided, adds shading for the phase.
    title : str, optional
        Plot title. If None, automatically generated based on metric.
    raw_line_width : int, optional
        Line width of the raw metric trace (default: 1)
    smoothed_line_width : int, optional
        Line width of the smoothed metric trace (default: 2)
    fitted_line_width : int, optional
        Line width of the fitted model metric trace (default: 2)

    Returns
    -------
    plotly.graph_objects.Figure
        Plotly figure with derivative metric plot. An empty figure is
        returned when fewer than 3 finite, positive data points remain.

    Raises
    ------
    ValueError
        If ``metric`` is neither "dndt" nor "mu".

    Examples
    --------
    >>> import numpy as np
    >>> from growthcurves import plot_derivative_metric, fit_non_parametric
    >>>
    >>> # Generate some example N
    >>> t = np.linspace(0, 24, 100)
    >>> N = 0.05 * np.exp(0.5 * t) / (1 + (0.05/2.0) * (np.exp(0.5 * t) - 1))
    >>>
    >>> # Plot specific growth rate without fit
    >>> fig = plot_derivative_metric(t, N, metric="mu")
    >>>
    >>> # Plot with fitted model
    >>> fit_result = fit_non_parametric(t, N, umax_method="spline")
    >>> fig = plot_derivative_metric(
    ...     t, N,
    ...     metric="mu",
    ...     fit_result=fit_result,
    ...     phase_boundaries=(5, 15)
    ... )
    """
    from .inference import (
        compute_first_derivative,
        compute_instantaneous_mu,
        compute_sliding_window_growth_rate,
        smooth,
    )

    # Validate metric
    if metric not in ["dndt", "mu"]:
        raise ValueError(f"metric must be 'dndt' or 'mu', got '{metric}'")

    # Convert to numpy arrays
    t = np.asarray(t, dtype=float)
    N = np.asarray(N, dtype=float)

    # Remove non-finite and non-positive values (needed for mu calculation)
    mask = np.isfinite(t) & np.isfinite(N) & (N > 0)
    t = t[mask]
    N = N[mask]

    if len(t) < 3:
        return go.Figure()

    # Store full t range for x-axis
    x_range = [float(t.min()), float(t.max())]

    def hover_template(kind: str) -> str:
        # Must be a plain string: a tuple/list would be interpreted by plotly
        # as a per-point array of templates.
        return f"Time=%{{x:.2f}}<br>{metric_label} ({kind})=%{{y:.4f}}<extra></extra>"

    # Step 1: Calculate metric on raw N
    if metric == "dndt":
        t_metric_raw, metric_raw = compute_first_derivative(t, N)
        metric_label = "dN/dt"
        y_axis_title = "dN/dt"
        plot_title = title or "First Derivative (dN/dt)"
    else:  # mu
        t_metric_raw, metric_raw = compute_instantaneous_mu(t, N)
        metric_label = "μ"
        y_axis_title = "μ (h⁻¹)"
        plot_title = title or "Specific Growth Rate (μ)"

    # Step 2: Smooth the N
    y_smooth = smooth(N, sg_window, sg_poly)

    # Step 3: Calculate metric on smoothed N
    if metric == "dndt":
        t_metric_smooth, metric_smooth = compute_first_derivative(t, y_smooth)
    else:  # mu
        t_metric_smooth, metric_smooth = compute_instantaneous_mu(t, y_smooth)

    # Create figure
    fig = go.Figure()

    # Plot raw metric (light grey)
    fig.add_trace(
        go.Scatter(
            x=t_metric_raw,
            y=metric_raw,
            mode="lines",
            line=dict(width=raw_line_width, color="lightgrey"),
            hovertemplate=hover_template("raw"),
            showlegend=False,
            name="Raw",
        )
    )

    # Plot smoothed metric (pink/red)
    fig.add_trace(
        go.Scatter(
            x=t_metric_smooth,
            y=metric_smooth,
            mode="lines",
            line=dict(width=smoothed_line_width, color="#FF6692"),
            hovertemplate=hover_template("smoothed"),
            showlegend=False,
            name="Smoothed",
        )
    )

    # Step 4 & 5: Generate model metric and plot (if fit_result provided)
    if fit_result is not None:
        model_type = fit_result.get("model_type", "")
        params = fit_result.get("params", {})

        metric_model = None
        t_model = None

        # Get the fitted N range
        fit_t_min = params.get("fit_t_min")
        fit_t_max = params.get("fit_t_max")

        # Filter to fitted range if available
        if fit_t_min is not None and fit_t_max is not None:
            fit_mask = (t >= fit_t_min) & (t <= fit_t_max)
            t_model = t[fit_mask]
            y_model_raw = N[fit_mask]
            y_model_smooth = y_smooth[fit_mask]
        else:
            # Use full range if fit bounds not available
            t_model = t
            y_model_raw = N
            y_model_smooth = y_smooth

        # NOTE(review): the time arrays returned by the compute_* helpers are
        # discarded below and t_model is used as the x-axis; this presumes the
        # helpers return one value per input time point - confirm.
        if len(t_model) >= 2:
            if model_type == "sliding_window":
                window_points = params.get("window_points", 15)

                if metric == "dndt":
                    # For dN/dt, differentiate the smoothed data
                    _, metric_model = compute_first_derivative(t_model, y_model_smooth)
                else:  # mu
                    # For μ, use sliding window on raw N
                    _, metric_model = compute_sliding_window_growth_rate(
                        t_model, y_model_raw, window_points=window_points
                    )

            elif model_type in get_all_parametric_models():
                # Evaluate the model on the fitted range, then derive metric
                y_model = evaluate_parametric_model(t_model, model_type, params)

                if metric == "dndt":
                    _, metric_model = compute_first_derivative(t_model, y_model)
                else:  # mu
                    _, metric_model = compute_instantaneous_mu(t_model, y_model)

            elif model_type == "spline":
                # Best effort: if the spline cannot be rebuilt or evaluated,
                # the fitted trace is simply omitted.
                try:
                    spline = spline_from_params(params)

                    if metric == "dndt":
                        # Spline is fitted to log(y), so exp(spline(t)) gives y
                        y_log_model = spline(t_model)
                        y_model = np.exp(y_log_model)
                        _, metric_model = compute_first_derivative(t_model, y_model)
                    else:  # mu
                        # μ = d(ln(y))/dt, which is the derivative of the spline
                        metric_model = spline.derivative()(t_model)
                except Exception:
                    # If spline reconstruction fails, skip model trace
                    pass

        # Plot model metric if available
        if (
            metric_model is not None
            and t_model is not None
            and np.isfinite(metric_model).any()
        ):
            fig.add_trace(
                go.Scatter(
                    x=t_model,
                    y=metric_model,
                    mode="lines",
                    line=dict(width=fitted_line_width, color="#8dcde0"),
                    hovertemplate=hover_template("fitted"),
                    showlegend=False,
                    name="Fitted",
                )
            )

    # Add phase boundary annotations if provided
    if phase_boundaries is not None and len(phase_boundaries) == 2:
        exp_start, exp_end = phase_boundaries
        if exp_start is not None and exp_end is not None:
            if np.isfinite(exp_start) and np.isfinite(exp_end):
                fig = add_exponential_phase(fig, exp_start, exp_end)

    # Update layout
    fig.update_layout(
        title=plot_title,
        height=400,
        showlegend=False,
        plot_bgcolor="white",
        paper_bgcolor="white",
        margin=dict(l=40, r=20, t=60, b=40),
        template="plotly_white",
    )

    fig.update_xaxes(showgrid=False, title="Time (hours)", range=x_range)
    fig.update_yaxes(showgrid=False, title=y_axis_title)

    return fig
def plot_growth_stats_comparison(
    stats_dict: Dict[str, Dict[str, Any]],
    title: str = "Growth Statistics Comparison",
    metric_order: Optional[list] = None,
) -> go.Figure:
    """
    Create a multi-panel bar chart comparing growth statistics across methods.

    One bar-chart panel is drawn per metric, three panels per row, with one
    bar per method. Values that cannot be converted to numbers are plotted
    as missing.

    Parameters
    ----------
    stats_dict : dict
        Dictionary mapping method names to their growth statistics
        dictionaries. Each stats dict should contain keys like 'mu_max',
        'doubling_time', etc.
    title : str, optional
        Overall title for the figure (default: "Growth Statistics Comparison")
    metric_order : list, optional
        List of metric keys to plot in specific order. If None (or empty),
        the default metrics present in ``stats_dict`` are used. A requested
        metric that no method provides gets an empty panel.

    Returns
    -------
    plotly.graph_objects.Figure
        Plotly figure with subplots showing each metric comparison. When
        there is no metric to plot, an empty titled figure is returned.

    Examples
    --------
    >>> # Compare multiple fitting methods
    >>> stats_dict = {
    ...     'logistic': stats_logistic,
    ...     'gompertz': stats_gompertz,
    ...     'spline': stats_spline
    ... }
    >>> fig = plot_growth_stats_comparison(
    ...     stats_dict,
    ...     title="Model Comparison"
    ... )
    >>> fig.show()
    """
    import pandas as pd
    from plotly.subplots import make_subplots

    # Rows are methods, columns are statistics.
    df = pd.DataFrame(stats_dict).T

    default_metrics = [
        "mu_max",
        "intrinsic_growth_rate",
        "doubling_time",
        "time_at_umax",
        "exp_phase_start",
        "exp_phase_end",
        "model_rmse",
    ]
    metrics = list(metric_order or [m for m in default_metrics if m in df.columns])

    # make_subplots() rejects a zero-row grid, so bail out with an empty
    # figure when there is nothing to compare.
    if not metrics:
        empty_fig = go.Figure()
        empty_fig.update_layout(title=title, template="plotly_white")
        return empty_fig

    # Coerce each requested metric to numeric; a metric absent from every
    # stats dict becomes an all-NaN column instead of raising KeyError.
    numeric_columns = {}
    for m in metrics:
        if m in df.columns:
            numeric_columns[m] = pd.to_numeric(df[m], errors="coerce")
        else:
            numeric_columns[m] = pd.Series(np.nan, index=df.index, dtype=float)

    n_metrics = len(metrics)
    n_cols = 3
    n_rows = int(np.ceil(n_metrics / n_cols))

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[m.replace("_", " ").title() for m in metrics],
        horizontal_spacing=0.08,
        vertical_spacing=0.15,
    )

    method_names = list(df.index)

    for i, metric in enumerate(metrics):
        # Fill the grid row by row (plotly subplot indices are 1-based).
        row = i // n_cols + 1
        col = i % n_cols + 1
        fig.add_trace(
            go.Bar(
                x=method_names,
                y=numeric_columns[metric].tolist(),
                showlegend=False,
                marker=dict(line=dict(color="black", width=1)),
            ),
            row=row,
            col=col,
        )

    fig.update_layout(
        title=title,
        height=max(420, 320 * n_rows),
        width=1200,
        bargap=0.25,
        template="plotly_white",
    )
    return fig