# Source code for pymars.plots

"""
Visualization tools for MARS models
===================================

Functions for plotting basis functions, model predictions, and diagnostics.
"""

import numpy as np
from typing import Optional, List, Tuple
import warnings

try:
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False
    warnings.warn("matplotlib not available. Plotting functions will not work.")


def plot_univariate_effects(model, X: np.ndarray, feature_idx: int,
                            n_points: int = 100,
                            ax: Optional["plt.Axes"] = None) -> "plt.Axes":
    """
    Plot the effect of a single feature on predictions

    The chosen feature is swept over its observed range, padded by 10% of
    that range on each side, while every other feature is held at its
    column median. The data points are overlaid as the model's
    predictions at the actual rows of ``X``.

    Parameters
    ----------
    model : MARS
        Fitted MARS model (only ``predict`` is used)
    X : array
        Reference data; supplies the sweep range and the medians used for
        the other features
    feature_idx : int
        Index of feature to plot
    n_points : int
        Number of points to evaluate along the sweep
    ax : matplotlib axes, optional
        Axes to plot on. A new figure of size (8, 5) is created when omitted.

    Returns
    -------
    ax : matplotlib axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for plotting")

    # Accept any array-like (e.g. nested lists); 2-D column indexing below
    # requires an ndarray.
    X = np.asarray(X)

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))

    # Sweep range: observed min/max padded so the curve extends past the data
    column = X[:, feature_idx]
    x_min, x_max = column.min(), column.max()
    margin = (x_max - x_min) * 0.1
    x_plot = np.linspace(x_min - margin, x_max + margin, n_points)

    # Hold all other features at their medians
    X_plot = np.tile(np.median(X, axis=0), (n_points, 1))
    X_plot[:, feature_idx] = x_plot

    y_plot = model.predict(X_plot)

    ax.plot(x_plot, y_plot, 'b-', linewidth=2, label='MARS prediction')
    # Predictions at the real rows: other features keep their actual values
    # here, so points may scatter around the median-conditioned curve.
    ax.scatter(column, model.predict(X), alpha=0.3, s=20, c='gray',
               label='Training data')
    ax.set_xlabel(f'x{feature_idx}', fontsize=12)
    ax.set_ylabel('Prediction', fontsize=12)
    ax.set_title(f'Effect of Feature {feature_idx}', fontsize=14,
                 fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return ax
def plot_bivariate_effect(model, X: np.ndarray, feature1: int, feature2: int,
                          n_points: int = 50, plot_type: str = 'contour',
                          ax: Optional["plt.Axes"] = None) -> "plt.Axes":
    """
    Plot interaction effect between two features

    Both features are swept over their observed ranges on an
    ``n_points`` x ``n_points`` grid while every other feature is held at
    its column median.

    Parameters
    ----------
    model : MARS
        Fitted MARS model (only ``predict`` is used)
    X : array
        Reference data
    feature1, feature2 : int
        Indices of features to plot
    n_points : int
        Grid resolution per axis (``n_points**2`` predictions are made)
    plot_type : str
        'contour' (filled contours plus colorbar) or 'surface' (3-D surface)
    ax : matplotlib axes, optional
        Axes to draw on. For 'surface' it must support ``plot_surface``
        (a ``projection='3d'`` axes). A new figure is created when omitted.

    Returns
    -------
    ax : matplotlib axes

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``plot_type`` is not 'contour' or 'surface'.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for plotting")
    # Reject a bad plot_type before spending n_points**2 model evaluations.
    if plot_type not in ('contour', 'surface'):
        raise ValueError("plot_type must be 'contour' or 'surface'")

    X = np.asarray(X)

    # Create grid over the observed range of each feature
    x1_grid = np.linspace(X[:, feature1].min(), X[:, feature1].max(), n_points)
    x2_grid = np.linspace(X[:, feature2].min(), X[:, feature2].max(), n_points)
    X1, X2 = np.meshgrid(x1_grid, x2_grid)

    # Hold the remaining features at their medians
    X_plot = np.tile(np.median(X, axis=0), (n_points ** 2, 1))
    X_plot[:, feature1] = X1.ravel()
    X_plot[:, feature2] = X2.ravel()

    y_pred = np.asarray(model.predict(X_plot)).reshape(n_points, n_points)

    title = f'Interaction: x{feature1} × x{feature2}'
    if plot_type == 'contour':
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 6))
        contour = ax.contourf(X1, X2, y_pred, levels=20, cmap='viridis')
        # Attach the colorbar to ax's own figure, not pyplot's "current"
        # figure, which can differ when the caller supplies ax.
        ax.figure.colorbar(contour, ax=ax, label='Prediction')
        ax.contour(X1, X2, y_pred, levels=10, colors='white', alpha=0.3,
                   linewidths=0.5)
        ax.set_xlabel(f'x{feature1}', fontsize=12)
        ax.set_ylabel(f'x{feature2}', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
    else:  # 'surface' (validated above)
        if ax is None:
            fig = plt.figure(figsize=(10, 8))
            ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(X1, X2, y_pred, cmap='viridis', alpha=0.8,
                               edgecolor='none')
        ax.figure.colorbar(surf, ax=ax, shrink=0.5, label='Prediction')
        ax.set_xlabel(f'x{feature1}', fontsize=10)
        ax.set_ylabel(f'x{feature2}', fontsize=10)
        ax.set_zlabel('Prediction', fontsize=10)
        ax.set_title(title, fontsize=12, fontweight='bold')

    return ax
def plot_basis_functions(model, X: np.ndarray, max_plot: int = 6,
                         figsize: Tuple[int, int] = (12, 8)) -> "plt.Figure":
    """
    Plot individual basis functions

    For each of the first ``max_plot`` basis functions, draws a histogram
    of ``basis.evaluate(X)`` titled with its coefficient and description,
    on a grid three columns wide. Unused grid cells are hidden.

    Parameters
    ----------
    model : MARS
        Fitted MARS model; ``basis_functions_`` and ``coefficients_`` are
        indexed in parallel
    X : array
        Data to evaluate on
    max_plot : int
        Maximum number of basis functions to plot
    figsize : tuple
        Figure size

    Returns
    -------
    fig : matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for plotting")

    n_basis = max(0, min(len(model.basis_functions_), max_plot))
    n_cols = 3
    # plt.subplots rejects nrows=0, so keep at least one (blank) row.
    n_rows = max(1, (n_basis + n_cols - 1) // n_cols)

    # squeeze=False always yields a 2-D array of Axes, so ravel() gives a
    # flat array of individual Axes for every grid shape. (With squeeze on,
    # a 1x3 grid came back as an array and was wrongly wrapped in a list.)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes[:n_basis]):
        basis = model.basis_functions_[i]
        coef = model.coefficients_[i]
        values = basis.evaluate(X)

        ax.hist(values, bins=30, edgecolor='black', alpha=0.7)
        ax.set_title(f'Basis {i}: coef={coef:.3f}\n{basis}', fontsize=9)
        ax.set_xlabel('Value')
        ax.set_ylabel('Frequency')
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for ax in axes[n_basis:]:
        ax.axis('off')

    plt.tight_layout()
    return fig
def plot_feature_importances(model, feature_names: Optional[List[str]] = None,
                             figsize: Tuple[int, int] = (8, 5),
                             threshold: float = 0.01) -> "plt.Figure":
    """
    Bar plot of feature importances

    Features are sorted by importance and drawn as horizontal bars with the
    most important feature at the top. Bars above ``threshold`` are green,
    the rest light gray, and a dashed line marks the threshold.

    Parameters
    ----------
    model : MARS
        Fitted model exposing ``feature_importances_``
    feature_names : list of str, optional
        Names for features; defaults to ``x0, x1, ...``
    figsize : tuple
        Figure size
    threshold : float
        Importance cut-off used for bar colouring and the reference line
        (default 0.01)

    Returns
    -------
    fig : matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for plotting")

    # asarray so a plain-list feature_importances_ still supports fancy indexing
    importances = np.asarray(model.feature_importances_)
    n_features = len(importances)

    if feature_names is None:
        feature_names = [f'x{i}' for i in range(n_features)]

    # Sort by importance, descending
    order = np.argsort(importances)[::-1]
    sorted_importances = importances[order]

    fig, ax = plt.subplots(figsize=figsize)
    colors = ['green' if imp > threshold else 'lightgray'
              for imp in sorted_importances]
    positions = range(n_features)
    ax.barh(positions, sorted_importances, color=colors, edgecolor='black')
    ax.set_yticks(positions)
    ax.set_yticklabels([feature_names[i] for i in order])
    # barh draws position 0 at the bottom; flip so the descending sort reads
    # top-down (most important feature first).
    ax.invert_yaxis()
    ax.set_xlabel('Importance', fontsize=12)
    ax.set_title('Feature Importances', fontsize=14, fontweight='bold')
    ax.axvline(threshold, color='red', linestyle='--', linewidth=1, alpha=0.5,
               label=f'Threshold ({threshold})')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='x')

    plt.tight_layout()
    return fig
def plot_predictions(model, X: np.ndarray, y: np.ndarray,
                     figsize: Tuple[int, int] = (8, 6)) -> "plt.Figure":
    """
    Scatter plot of predictions vs actual values

    Produces two side-by-side panels: predicted vs actual with a y = x
    reference line, and residuals (``y - prediction``) vs predicted with a
    zero reference line.

    Parameters
    ----------
    model : MARS
        Fitted model (only ``predict`` is used)
    X, y : arrays
        Data to predict on and the corresponding targets
    figsize : tuple
        Figure size

    Returns
    -------
    fig : matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for plotting")

    # Flatten both sides: a column-vector y of shape (n, 1) would otherwise
    # broadcast against a 1-D y_pred of shape (n,) into an (n, n) residual
    # matrix and break the residual scatter.
    y = np.asarray(y).ravel()
    y_pred = np.asarray(model.predict(X)).ravel()
    residuals = y - y_pred

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Predictions vs actual
    ax1.scatter(y, y_pred, alpha=0.5, s=20)
    lims = [min(y.min(), y_pred.min()), max(y.max(), y_pred.max())]
    ax1.plot(lims, lims, 'r--', linewidth=2, label='Perfect prediction')
    ax1.set_xlabel('Actual', fontsize=12)
    ax1.set_ylabel('Predicted', fontsize=12)
    ax1.set_title('Predictions vs Actual', fontsize=13, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Residuals
    ax2.scatter(y_pred, residuals, alpha=0.5, s=20)
    ax2.axhline(0, color='red', linestyle='--', linewidth=2)
    ax2.set_xlabel('Predicted', fontsize=12)
    ax2.set_ylabel('Residuals', fontsize=12)
    ax2.set_title('Residual Plot', fontsize=13, fontweight='bold')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
def plot_anova_summary(model, figsize: Tuple[int, int] = (10, 6)) -> "plt.Figure":
    """
    Summary plot of ANOVA decomposition

    Left panel: number of terms at each interaction order. Right panel:
    absolute coefficient of every term, grouped by interaction order.

    Parameters
    ----------
    model : MARS
        Fitted model. ``get_anova_decomposition()`` is assumed to return a
        mapping from interaction order to a list of basis functions that
        appear in ``model.basis_functions_`` (located with ``list.index``).
    figsize : tuple
        Figure size

    Returns
    -------
    fig : matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError("matplotlib required for plotting")

    anova = model.get_anova_decomposition()
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Count by degree
    degrees = sorted(anova.keys())
    counts = [len(anova[d]) for d in degrees]

    ax1.bar(degrees, counts, edgecolor='black', color='steelblue', alpha=0.7)
    ax1.set_xlabel('Interaction Order', fontsize=12)
    ax1.set_ylabel('Number of Terms', fontsize=12)
    ax1.set_title('Terms by Interaction Order', fontsize=13, fontweight='bold')
    ax1.set_xticks(degrees)
    ax1.grid(True, alpha=0.3, axis='y')

    # Coefficient magnitudes by degree; coefficients_ is indexed in parallel
    # with basis_functions_, so look up each term's position first.
    for degree in degrees:
        basis_list = anova[degree]
        indices = [model.basis_functions_.index(b) for b in basis_list]
        coefs = [abs(model.coefficients_[i]) for i in indices]

        if coefs:
            positions = [degree] * len(coefs)
            ax2.scatter(positions, coefs, s=50, alpha=0.6)

    ax2.set_xlabel('Interaction Order', fontsize=12)
    ax2.set_ylabel('|Coefficient|', fontsize=12)
    ax2.set_title('Coefficient Magnitudes', fontsize=13, fontweight='bold')
    ax2.set_xticks(degrees)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig