Source code for pymars.interactions

"""
Advanced interaction analysis for MARS
======================================

Tools for analyzing and visualizing variable interactions in MARS models.
"""

import numpy as np
from typing import List, Dict, Tuple, Optional, Set
from itertools import combinations
from .basis import BasisFunction


class InteractionAnalyzer:
    """
    Analyze interactions in a fitted MARS model.

    Parameters
    ----------
    model : MARS
        Fitted MARS model. The analyzer reads ``basis_functions_``,
        ``coefficients_`` and ``n_features_in_``; some methods additionally
        use ``predict``, ``standardize``, ``_x_mean`` and ``_x_std``.

    Raises
    ------
    ValueError
        If the model has not been fitted (``basis_functions_`` is missing
        or ``None``).
    """

    def __init__(self, model):
        # getattr so that a model lacking the attribute entirely raises the
        # documented ValueError rather than an AttributeError.
        if getattr(model, 'basis_functions_', None) is None:
            raise ValueError("Model must be fitted first")
        self.model = model

    def get_interaction_strength(self) -> Dict[Tuple[int, ...], float]:
        """
        Calculate the strength of each variable interaction.

        The strength of a variable set is the sum of the absolute
        coefficients of all basis functions over exactly that set. The
        constant (degree 0) basis function is ignored.

        Returns
        -------
        interactions : dict
            Maps a sorted variable-index tuple to its interaction strength.
        """
        interactions: Dict[Tuple[int, ...], float] = {}
        for basis, coef in zip(self.model.basis_functions_,
                               self.model.coefficients_):
            if basis.degree == 0:
                continue  # the intercept involves no variables
            key = tuple(sorted(basis.variables))
            interactions[key] = interactions.get(key, 0.0) + abs(coef)
        return interactions

    def get_pairwise_interactions(self) -> np.ndarray:
        """
        Get the matrix of pairwise (degree-2) interaction strengths.

        Returns
        -------
        matrix : array, shape (n_features, n_features)
            Symmetric matrix where ``[i, j]`` is the summed absolute
            coefficient of degree-2 basis functions over ``x_i * x_j``.
            A self-product (``x_i * x_i``) contributes once to ``[i, i]``.
        """
        n_features = self.model.n_features_in_
        matrix = np.zeros((n_features, n_features))
        for basis, coef in zip(self.model.basis_functions_,
                               self.model.coefficients_):
            if basis.degree != 2:
                continue
            i, j = sorted(basis.variables)[:2]
            strength = abs(coef)
            matrix[i, j] += strength
            # Mirror into the lower triangle, but never double-count the
            # diagonal for a product of a variable with itself.
            if i != j:
                matrix[j, i] += strength
        return matrix

    def rank_interactions(self, top_k: int = 10) -> List[Tuple]:
        """
        Rank variable sets by interaction strength.

        Parameters
        ----------
        top_k : int
            Number of top entries to return.

        Returns
        -------
        ranked : list of tuples
            ``(variables, strength)`` pairs, strongest first. Entries of
            equal strength keep their order of first appearance.
        """
        interactions = self.get_interaction_strength()
        ranked = sorted(interactions.items(), key=lambda item: item[1],
                        reverse=True)
        return ranked[:top_k]

    def find_pure_additive_effects(self) -> List[int]:
        """
        Find variables that appear only in additive (degree 1) terms.

        Returns
        -------
        variables : list of int
            Sorted indices of variables that occur in at least one degree-1
            basis function and in no basis function of degree > 1.
        """
        bases = self.model.basis_functions_
        interaction_vars = {
            var for basis in bases if basis.degree > 1
            for var in basis.variables
        }
        additive_only = {
            basis.variables[0] for basis in bases
            if basis.degree == 1 and basis.variables[0] not in interaction_vars
        }
        return sorted(additive_only)

    def decompose_prediction(self, x: np.ndarray) -> Dict[str, float]:
        """
        Decompose a single prediction into per-term contributions.

        Parameters
        ----------
        x : array-like, shape (n_features,) or (1, n_features)
            Single input vector (original, non-standardized scale). If a 2-D
            array with several rows is given, only the first row is
            decomposed.

        Returns
        -------
        contributions : dict
            Maps ``'constant'``, ``'x{i}'`` (additive terms) and
            ``'interaction_x{i}*x{j}...'`` keys to ``coef * basis(x)``
            summed over all basis functions of that kind.

        Notes
        -----
        If ``model.standardize`` is truthy, ``x`` is standardized with
        ``model._x_mean`` / ``model._x_std`` before evaluating the basis
        functions.
        """
        # asarray also accepts plain lists, which have no ``.ndim``.
        x = np.asarray(x, dtype=float)
        if x.ndim == 1:
            x = x.reshape(1, -1)

        if getattr(self.model, 'standardize', False):
            x_eval = (x - self.model._x_mean) / self.model._x_std
        else:
            x_eval = x

        contributions: Dict[str, float] = {'constant': 0.0}
        for basis, coef in zip(self.model.basis_functions_,
                               self.model.coefficients_):
            contrib = float(coef * basis.evaluate(x_eval)[0])
            if basis.degree == 0:
                key = 'constant'
            elif basis.degree == 1:
                key = f'x{basis.variables[0]}'
            else:
                var_str = 'x' + '*x'.join(map(str, sorted(basis.variables)))
                key = f'interaction_{var_str}'
            # Accumulate so that several basis functions of the same kind
            # (including more than one constant term) are all accounted for.
            contributions[key] = contributions.get(key, 0.0) + contrib
        return contributions

    def interaction_test(self, var1: int, var2: int, X: np.ndarray,
                         y: np.ndarray, max_terms: int = 20) -> float:
        """
        Test whether an interaction between two variables improves fit.

        Two fresh MARS models are fitted on ``(X, y)``: one with two-way
        interactions (``max_degree=2``) and one additive (``max_degree=1``).
        ``self.model`` is not used. R² is computed in-sample on the same
        ``(X, y)``.

        Parameters
        ----------
        var1, var2 : int
            Variable indices.
        X, y : arrays
            Data used both to fit and to score the models.
        max_terms : int, default 20
            ``max_terms`` passed to both fitted models.

        Returns
        -------
        improvement : float
            ``R²(interaction model) - R²(additive model)`` if the
            interaction model selected a degree-2 basis function over
            exactly ``{var1, var2}``; otherwise ``0.0``.
        """
        from .mars import MARS

        model_int = MARS(max_terms=max_terms, max_degree=2, verbose=False)
        model_int.fit(X, y)

        target = {var1, var2}
        has_interaction = any(
            basis.degree == 2 and set(basis.variables) == target
            for basis in model_int.basis_functions_
        )
        if not has_interaction:
            # The additive fit only matters when the pair was selected.
            return 0.0

        model_add = MARS(max_terms=max_terms, max_degree=1, verbose=False)
        model_add.fit(X, y)
        return model_int.score(X, y) - model_add.score(X, y)

    def hierarchical_interaction_map(self) -> Dict[int, Set[Tuple[int, ...]]]:
        """
        Group the variable sets of the basis functions by degree.

        Returns
        -------
        hierarchy : dict
            Maps each degree >= 1 to the set of sorted variable tuples of
            basis functions with that degree.
        """
        hierarchy: Dict[int, Set[Tuple[int, ...]]] = {}
        for basis in self.model.basis_functions_:
            if basis.degree == 0:
                continue
            hierarchy.setdefault(basis.degree, set()).add(
                tuple(sorted(basis.variables)))
        return hierarchy

    def compute_h_statistic(self, var1: int, var2: int,
                            X: np.ndarray) -> float:
        """
        Compute an H-statistic-style measure of interaction strength.

        This is a median-anchored approximation inspired by Friedman's
        H-statistic. Rather than partial-dependence functions, it uses model
        predictions with ``var1`` and/or ``var2`` fixed at their column
        medians (all other columns unchanged)::

            inter = f(x1, x2) - f(x1, m2) - f(m1, x2) + f(m1, m2)
            total = f(x1, x2) - f(m1, m2)
            H     = sum(inter**2) / sum(total**2), clipped to [0, 1]

        Parameters
        ----------
        var1, var2 : int
            Variable (column) indices.
        X : array, shape (n_samples, n_features)
            Input data. It is not modified.

        Returns
        -------
        h_stat : float
            0 means no interaction and 1 a pure interaction. Returns 0.0 when
            the total effect is numerically zero.
        """
        X = np.asarray(X)
        # Integer arrays would silently truncate a fractional median when it
        # is written into a column, so promote them to float first.
        if not np.issubdtype(X.dtype, np.floating):
            X = X.astype(float)

        med1 = np.median(X[:, var1])
        med2 = np.median(X[:, var2])

        f_12 = self.model.predict(X)

        X_1 = X.copy()              # vary only var1
        X_1[:, var2] = med2
        f_1 = self.model.predict(X_1)

        X_2 = X.copy()              # vary only var2
        X_2[:, var1] = med1
        f_2 = self.model.predict(X_2)

        X_0 = X_1.copy()            # both fixed
        X_0[:, var1] = med1
        f_0 = self.model.predict(X_0)

        interaction_effect = f_12 - f_1 - f_2 + f_0
        total_effect = f_12 - f_0

        h_denominator = np.sum(total_effect ** 2)
        if h_denominator < 1e-10:
            return 0.0
        h_stat = np.sum(interaction_effect ** 2) / h_denominator
        return float(np.clip(h_stat, 0.0, 1.0))
def analyze_interactions_full(model, X: np.ndarray, threshold: float = 0.01,
                              top_k: int = 20) -> Dict:
    """
    Run a comprehensive interaction analysis of a fitted MARS model.

    Parameters
    ----------
    model : MARS
        Fitted model.
    X : array
        Input data. Currently unused: every result is derived from the
        model's basis functions and coefficients. It is accepted for API
        compatibility.
    threshold : float, default 0.01
        Minimum strength for an interaction to count as significant.
    top_k : int, default 20
        Number of entries in ``'top_interactions'``. This list is not
        filtered by ``threshold``.

    Returns
    -------
    analysis : dict
        Keys: ``'all_interactions'``, ``'significant_interactions'``,
        ``'top_interactions'``, ``'pure_additive_vars'``,
        ``'pairwise_matrix'``, ``'hierarchy'``, ``'max_degree'`` (0 if
        there are no non-constant terms), ``'n_interactions'`` (count of
        significant entries) and ``'n_additive'``.

    Raises
    ------
    ValueError
        If the model has not been fitted.
    """
    analyzer = InteractionAnalyzer(model)

    interactions = analyzer.get_interaction_strength()
    significant = {
        variables: strength
        for variables, strength in interactions.items()
        if strength >= threshold
    }
    ranked = analyzer.rank_interactions(top_k=top_k)
    additive = analyzer.find_pure_additive_effects()
    pairwise = analyzer.get_pairwise_interactions()
    hierarchy = analyzer.hierarchical_interaction_map()

    return {
        'all_interactions': interactions,
        'significant_interactions': significant,
        'top_interactions': ranked,
        'pure_additive_vars': additive,
        'pairwise_matrix': pairwise,
        'hierarchy': hierarchy,
        'max_degree': max(hierarchy, default=0),
        'n_interactions': len(significant),
        'n_additive': len(additive),
    }