API Reference

ModelSpec dataclass

Specification for a model registered in TanML.

Attributes:
  • task (Task): Either 'classification' or 'regression'.
  • import_path (str): Fully qualified path to the estimator class.
  • defaults (dict[str, Any]): Default hyperparameters for the estimator.
  • ui_schema (dict[str, tuple[str, tuple[Any, ...] | None, str | None]]): Metadata for rendering parameter inputs in the UI.
  • aliases (dict[str, str]): Mapping of UI parameter names to estimator-specific names.

Source code in tanml/models/registry.py
@dataclass(frozen=True)
class ModelSpec:
    """
    Specification for a model registered in TanML.

    Attributes:
        task: Either 'classification' or 'regression'.
        import_path: Fully qualified path to the estimator class.
        defaults: Default hyperparameters for the estimator.
        ui_schema: Metadata for rendering parameter inputs in the UI.
        aliases: Mapping of UI parameter names to estimator-specific names.
    """
    task: Task
    import_path: str  # e.g., "sklearn.ensemble.RandomForestClassifier"
    defaults: dict[str, Any] = field(default_factory=dict)
    # UI schema: param -> (type, choices_or_None, help_or_None)
    ui_schema: dict[str, tuple[str, tuple[Any, ...] | None, str | None]] = field(
        default_factory=dict
    )
    aliases: dict[str, str] = field(default_factory=dict)  # optional param alias map

build_estimator(library, algo, params=None)

Factory function to create a scikit-learn compatible estimator instance.

This is the primary entry point for programmatic model creation in TanML. It resolves the requested model from the internal registry, applies sane defaults for the specific task, and overrides them with any user-provided parameters.

Supported Libraries
  • sklearn: LogisticRegression, RandomForestClassifier, SVC, LinearRegression, etc.
  • xgboost: XGBClassifier, XGBRegressor.
  • lightgbm: LGBMClassifier, LGBMRegressor.
  • catboost: CatBoostClassifier, CatBoostRegressor.
  • statsmodels: Logit, OLS.

Parameters:
  • library (str, required): The library containing the model (e.g., 'sklearn', 'xgboost').
  • algo (str, required): The specific algorithm class name (e.g., 'RandomForestClassifier').
  • params (dict[str, Any] | None, default None): Optional dictionary of hyperparameters to override the pre-configured TanML defaults.

Returns:
  An initialized estimator object. For most libraries, this is a standard scikit-learn compatible object. For statsmodels, it is a wrapped instance that supports the .fit() API.

Raises:
  • KeyError: If the library/algo combination is not in the registry.
  • ImportError: If the underlying library (e.g., xgboost) is not installed.

Example

Create a Random Forest with custom depth:

    >>> from tanml.models.registry import build_estimator
    >>> model = build_estimator(
    ...     library="sklearn",
    ...     algo="RandomForestClassifier",
    ...     params={"max_depth": 5}
    ... )
    >>> print(model.max_depth)
    5

Source code in tanml/models/registry.py
def build_estimator(library: str, algo: str, params: dict[str, Any] | None = None):
    """
    Factory function to create a scikit-learn compatible estimator instance.

    This is the primary entry point for programmatic model creation in TanML.
    It resolves the requested model from the internal registry, applies
    sane defaults for the specific task, and overrides them with any
    user-provided parameters.

    Supported Libraries:
        - `sklearn`: LogisticRegression, RandomForestClassifier, SVC, LinearRegression, etc.
        - `xgboost`: XGBClassifier, XGBRegressor.
        - `lightgbm`: LGBMClassifier, LGBMRegressor.
        - `catboost`: CatBoostClassifier, CatBoostRegressor.
        - `statsmodels`: Logit, OLS.

    Args:
        library: The library containing the model (e.g., 'sklearn', 'xgboost').
        algo: The specific algorithm class name (e.g., 'RandomForestClassifier').
        params: Optional dictionary of hyperparameters to override the
            pre-configured TanML defaults.

    Returns:
        An initialized estimator object. For most libraries, this is a
        standard scikit-learn compatible object. For `statsmodels`, it
        is a wrapped instance that supports the `.fit()` API.

    Raises:
        KeyError: If the library/algo combination is not in the registry.
        ImportError: If the underlying library (e.g., xgboost) is not installed.

    Example:
        Create a Random Forest with custom depth:
        >>> from tanml.models.registry import build_estimator
        >>> model = build_estimator(
        ...     library="sklearn",
        ...     algo="RandomForestClassifier",
        ...     params={"max_depth": 5}
        ... )
        >>> print(model.max_depth)
        5
    """
    spec = get_spec(library, algo)
    Cls = _lazy_import(spec.import_path)
    kwargs = dict(spec.defaults)
    if params:
        canon = {}
        for k, v in params.items():
            k2 = spec.aliases.get(k, k)
            canon[k2] = v
        kwargs.update({k: v for k, v in canon.items() if v is not None})
    return Cls(**kwargs)
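
The defaults/alias merge above can be reproduced in isolation. The `defaults`, `aliases`, and `params` values below are invented for illustration; the point is that aliases canonicalize UI parameter names, and only non-None user values override the defaults:

```python
# Hypothetical spec data, mirroring the merge logic in build_estimator.
defaults = {"n_estimators": 100, "max_depth": None}
aliases = {"trees": "n_estimators"}          # UI name -> estimator name
params = {"trees": 250, "max_depth": None}   # None values are ignored

canon = {aliases.get(k, k): v for k, v in params.items()}
kwargs = dict(defaults)
kwargs.update({k: v for k, v in canon.items() if v is not None})
print(kwargs)  # {'n_estimators': 250, 'max_depth': None}
```

The estimator class is then constructed with `Cls(**kwargs)`, so a None from the UI never clobbers a registry default.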

list_models(task=None)

List all registered models in the TanML ecosystem.

This function returns the metadata specifications for all models that the system is capable of building and validating.

Parameters:
  • task (Task | None, default None): Optional filter. Use 'classification' for classifiers, 'regression' for regressors, or None for the full registry.

Returns:
  A dict[tuple[str, str], ModelSpec], where:
  • Keys: (library_name, algorithm_name) tuples, e.g., ("sklearn", "RandomForestClassifier").
  • Values: a ModelSpec instance containing defaults and UI schema metadata.

Source code in tanml/models/registry.py
def list_models(task: Task | None = None) -> dict[tuple[str, str], ModelSpec]:
    """
    List all registered models in the TanML ecosystem.

    This function returns the metadata specifications for all models that the
    system is capable of building and validating.

    Args:
        task: Optional filter. Use 'classification' for classifiers,
            'regression' for regressors, or None for the full registry.

    Returns:
        A dictionary where:
            - Keys: (library_name, algorithm_name) tuples, e.g., ("sklearn", "RandomForestClassifier").
            - Values: A `ModelSpec` instance containing defaults and UI schema metadata.
    """
    if task:
        return {k: v for k, v in _REGISTRY.items() if v.task == task}
    return dict(_REGISTRY)

get_spec(library, algo)

Retrieve the ModelSpec for a specific library and algorithm.

Parameters:
  • library (str, required): e.g., 'sklearn', 'xgboost'.
  • algo (str, required): e.g., 'RandomForestClassifier'.

Returns:
  ModelSpec: the matched ModelSpec instance.

Raises:
  • KeyError: If the library/algorithm combination is not registered.

Source code in tanml/models/registry.py
def get_spec(library: str, algo: str) -> ModelSpec:
    """
    Retrieve the ModelSpec for a specific library and algorithm.

    Args:
        library: e.g., 'sklearn', 'xgboost'.
        algo: e.g., 'RandomForestClassifier'.

    Returns:
        The matched ModelSpec instance.

    Raises:
        KeyError: If the library/algorithm combination is not registered.
    """
    key = (library, algo)
    if key not in _REGISTRY:
        raise KeyError(f"Unknown model: {library}.{algo}")
    return _REGISTRY[key]

Feature drift analysis module for internal and external validation.

This module provides statistical tools to detect whether the distribution of machine learning features has changed between two points in time or between two datasets (e.g., Training vs. Serving).

Key Metrics
  • PSI (Population Stability Index): A single number indicating the magnitude of the shift.
  • KS Test (Kolmogorov-Smirnov): A non-parametric test to determine if two samples come from different distributions.
Example

    >>> import pandas as pd
    >>> from tanml.analysis.drift import analyze_drift
    >>> # Compare Training and Serving data
    >>> results = analyze_drift(train_df, serving_df)
    >>> for col, metrics in results.items():
    ...     if metrics["has_drift"]:
    ...         print(f"Drift detected in {col}: PSI={metrics['psi']:.3f}")

calculate_psi(expected, actual, bins=10)

Calculate Population Stability Index (PSI) between two distributions.

PSI measures how much a distribution has shifted. Thresholds:
  • PSI < 0.1: No significant shift
  • 0.1 <= PSI < 0.2: Moderate shift (investigate)
  • PSI >= 0.2: Large shift (action needed)

Parameters:
  • expected (Series, required): Expected/baseline distribution (e.g., training data).
  • actual (Series, required): Actual/new distribution (e.g., test data).
  • bins (int, default 10): Number of bins for discretization.

Returns:
  float: the PSI value.

Source code in tanml/analysis/drift.py
def calculate_psi(
    expected: pd.Series,
    actual: pd.Series,
    bins: int = 10,
) -> float:
    """
    Calculate Population Stability Index (PSI) between two distributions.

    PSI measures how much a distribution has shifted. Thresholds:
        - PSI < 0.1: No significant shift
        - 0.1 <= PSI < 0.2: Moderate shift (investigate)
        - PSI >= 0.2: Large shift (action needed)

    Args:
        expected: Expected/baseline distribution (e.g., training data)
        actual: Actual/new distribution (e.g., test data)
        bins: Number of bins for discretization

    Returns:
        PSI value (float)
    """
    # Handle edge cases
    expected = expected.dropna()
    actual = actual.dropna()

    if len(expected) == 0 or len(actual) == 0:
        return np.nan

    # Create bins from expected distribution
    try:
        _, bin_edges = np.histogram(expected, bins=bins)
    except ValueError:
        return np.nan

    # Calculate proportions in each bin
    expected_counts = np.histogram(expected, bins=bin_edges)[0]
    actual_counts = np.histogram(actual, bins=bin_edges)[0]

    # Convert to proportions (avoid division by zero)
    expected_pct = expected_counts / len(expected)
    actual_pct = actual_counts / len(actual)

    # Replace zeros with small value to avoid log(0)
    eps = 1e-8
    expected_pct = np.where(expected_pct == 0, eps, expected_pct)
    actual_pct = np.where(actual_pct == 0, eps, actual_pct)

    # Calculate PSI
    psi = np.sum((actual_pct - expected_pct) * np.log(actual_pct / expected_pct))

    return float(psi)
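
To see the metric in action without TanML installed, the same binning recipe can be run on synthetic data. A half-standard-deviation mean shift typically produces a PSI in the 0.2-0.3 range, well past the "investigate" threshold:

```python
import numpy as np

rng = np.random.default_rng(0)
baseline = rng.normal(0.0, 1.0, 5_000)  # training-time distribution
shifted = rng.normal(0.5, 1.0, 5_000)   # serving-time distribution, mean-shifted

# Same recipe as calculate_psi: bin on the baseline, compare proportions.
_, edges = np.histogram(baseline, bins=10)
expected_pct = np.histogram(baseline, bins=edges)[0] / len(baseline)
actual_pct = np.histogram(shifted, bins=edges)[0] / len(shifted)
eps = 1e-8
expected_pct = np.where(expected_pct == 0, eps, expected_pct)
actual_pct = np.where(actual_pct == 0, eps, actual_pct)
psi = float(np.sum((actual_pct - expected_pct) * np.log(actual_pct / expected_pct)))
print(psi > 0.1)  # True: clearly past the "no significant shift" band
```

Note that binning on the baseline's edges means new values outside the baseline's range fall out of the histogram entirely, which slightly understates extreme drift.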

calculate_ks(expected, actual)

Calculate Kolmogorov-Smirnov statistic between two distributions.

Parameters:
  • expected (Series, required): Expected/baseline distribution.
  • actual (Series, required): Actual/new distribution.

Returns:
  tuple[float, float]: (KS statistic, p-value).

Source code in tanml/analysis/drift.py
def calculate_ks(
    expected: pd.Series,
    actual: pd.Series,
) -> tuple[float, float]:
    """
    Calculate Kolmogorov-Smirnov statistic between two distributions.

    Args:
        expected: Expected/baseline distribution
        actual: Actual/new distribution

    Returns:
        Tuple of (KS statistic, p-value)
    """
    from scipy import stats

    expected = expected.dropna()
    actual = actual.dropna()

    if len(expected) == 0 or len(actual) == 0:
        return np.nan, np.nan

    try:
        ks_stat, p_value = stats.ks_2samp(expected, actual)
        return float(ks_stat), float(p_value)
    except Exception:
        return np.nan, np.nan
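
A quick sanity check against scipy directly (the same call the function wraps), on synthetic data:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
baseline = rng.normal(0.0, 1.0, 2_000)
shifted = rng.normal(0.5, 1.0, 2_000)

ks_stat, p_value = stats.ks_2samp(baseline, shifted)
# The theoretical KS distance between N(0,1) and N(0.5,1) is about 0.197,
# so at this sample size the shift is detected easily.
print(ks_stat > 0.1, p_value < 0.05)  # True True
```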

analyze_drift(train_df, test_df, numeric_cols=None, psi_threshold=0.1, ks_threshold=0.05)

Perform a comprehensive drift analysis on all continuous features.

This function iterates through the numeric columns, calculates both PSI and KS statistics, and flags features that exceed regulatory or statistical thresholds.

Parameters:
  • train_df (DataFrame, required): The baseline/reference dataset (e.g., historical training data).
  • test_df (DataFrame, required): The target dataset to check for drift (e.g., current production batch).
  • numeric_cols (list[str] | None, default None): List of columns to analyze. If None, all numeric columns common to both datasets will be checked.
  • psi_threshold (float, default 0.1): The PSI value above which drift is considered "moderate".
  • ks_threshold (float, default 0.05): The p-value below which the KS test is considered statistically significant (indicating a difference).

Returns:
  dict[str, dict[str, Any]]: a dictionary mapping column names to their drift metadata:
  • psi: The calculated Population Stability Index.
  • ks_statistic: The Kolmogorov-Smirnov distance.
  • ks_pvalue: The p-value from the KS test.
  • has_drift: Boolean flag, True if PSI >= psi_threshold.
  • drift_level: String enum ("none", "moderate", "severe", or "unknown" when PSI is NaN).

Source code in tanml/analysis/drift.py
def analyze_drift(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    numeric_cols: list[str] | None = None,
    psi_threshold: float = 0.1,
    ks_threshold: float = 0.05,
) -> dict[str, dict[str, Any]]:
    """
    Perform a comprehensive drift analysis on all continuous features.

    This function iterates through the numeric columns, calculates both
    PSI and KS statistics, and flags features that exceed regulatory or
    statistical thresholds.

    Args:
        train_df: The baseline/reference dataset (e.g., historical training data).
        test_df: The target dataset to check for drift (e.g., current production batch).
        numeric_cols: List of columns to analyze. If None, all numeric columns
            common to both datasets will be checked.
        psi_threshold: The PSI value above which drift is considered "moderate".
            Default is 0.1.
        ks_threshold: The p-value below which the KS test is considered
            statistically significant (identifying a difference). Default is 0.05.

    Returns:
        A dictionary mapping column names to their drift metadata:
            - `psi`: The calculated Population Stability Index.
            - `ks_statistic`: The Kolmogorov-Smirnov distance.
            - `ks_pvalue`: The p-value from the KS test.
            - `has_drift`: Boolean flag if PSI >= psi_threshold.
            - `drift_level`: String enum ("none", "moderate", "severe", or "unknown").
    """
    if numeric_cols is None:
        numeric_cols = train_df.select_dtypes(include=[np.number]).columns.tolist()

    # Get common columns
    common_cols = [c for c in numeric_cols if c in train_df.columns and c in test_df.columns]

    results = {}
    for col in common_cols:
        psi = calculate_psi(train_df[col], test_df[col])
        ks_stat, ks_pval = calculate_ks(train_df[col], test_df[col])

        # Determine drift level
        if np.isnan(psi):
            drift_level = "unknown"
            has_drift = False
        elif psi >= 0.2:
            drift_level = "severe"
            has_drift = True
        elif psi >= psi_threshold:
            drift_level = "moderate"
            has_drift = True
        else:
            drift_level = "none"
            has_drift = False

        results[col] = {
            "psi": psi,
            "ks_statistic": ks_stat,
            "ks_pvalue": ks_pval,
            "has_drift": has_drift,
            "drift_level": drift_level,
        }

    return results
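
The level assignment buried inside the loop can be isolated for clarity; this sketch mirrors the branching above:

```python
import numpy as np

def drift_level(psi: float, psi_threshold: float = 0.1) -> str:
    # Mirrors the classification logic in analyze_drift:
    # NaN -> unknown, >= 0.2 -> severe, >= threshold -> moderate, else none.
    if np.isnan(psi):
        return "unknown"
    if psi >= 0.2:
        return "severe"
    if psi >= psi_threshold:
        return "moderate"
    return "none"

for value in (float("nan"), 0.05, 0.15, 0.30):
    print(value, drift_level(value))
```

The "severe" cutoff is fixed at 0.2 (the conventional PSI red line) regardless of `psi_threshold`, which only moves the "moderate" boundary.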

Feature correlation and VIF analysis module.

Provides correlation matrix calculation and Variance Inflation Factor (VIF) for detecting multicollinearity.

Example

    >>> from tanml.analysis.correlation import calculate_vif, calculate_correlation_matrix
    >>> corr_matrix = calculate_correlation_matrix(df, method="pearson")
    >>> vif_results = calculate_vif(df, features=["age", "income", "score"])

calculate_correlation_matrix(df, method='pearson', numeric_only=True)

Calculate correlation matrix for numeric features.

Parameters:
  • df (DataFrame, required): Input DataFrame.
  • method (str, default 'pearson'): Correlation method ("pearson", "spearman", or "kendall").
  • numeric_only (bool, default True): Whether to include only numeric columns.

Returns:
  DataFrame: the correlation matrix.

Source code in tanml/analysis/correlation.py
def calculate_correlation_matrix(
    df: pd.DataFrame,
    method: str = "pearson",
    numeric_only: bool = True,
) -> pd.DataFrame:
    """
    Calculate correlation matrix for numeric features.

    Args:
        df: Input DataFrame
        method: Correlation method ("pearson", "spearman", or "kendall")
        numeric_only: Whether to include only numeric columns

    Returns:
        Correlation matrix as DataFrame
    """
    if numeric_only:
        df = df.select_dtypes(include=[np.number])

    return df.corr(method=method)
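
Since this is a thin wrapper over pandas' `DataFrame.corr`, its behavior is easy to preview on synthetic data (the column names here are invented):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
x = rng.normal(size=500)
df = pd.DataFrame({
    "x": x,
    "y": 2 * x + rng.normal(scale=0.1, size=500),  # nearly collinear with x
    "z": rng.normal(size=500),                     # independent noise
    "label": ["a"] * 500,                          # non-numeric: dropped
})

# numeric_only=True path: non-numeric columns are filtered before corr()
corr = df.select_dtypes(include=[np.number]).corr(method="pearson")
print(corr.loc["x", "y"] > 0.95, abs(corr.loc["x", "z"]) < 0.2)  # True True
```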

find_highly_correlated_pairs(corr_matrix, threshold=0.8)

Find pairs of features with high correlation.

Parameters:
  • corr_matrix (DataFrame, required): Correlation matrix.
  • threshold (float, default 0.8): Absolute correlation threshold.

Returns:
  list[dict[str, Any]]: list of dictionaries describing the correlated pairs.

Source code in tanml/analysis/correlation.py
def find_highly_correlated_pairs(
    corr_matrix: pd.DataFrame,
    threshold: float = 0.8,
) -> list[dict[str, Any]]:
    """
    Find pairs of features with high correlation.

    Args:
        corr_matrix: Correlation matrix
        threshold: Absolute correlation threshold

    Returns:
        List of dictionaries with correlated pairs
    """
    pairs = []
    cols = corr_matrix.columns.tolist()

    for i, col1 in enumerate(cols):
        for j, col2 in enumerate(cols):
            if i < j:  # Upper triangle only
                corr = corr_matrix.loc[col1, col2]
                if abs(corr) >= threshold:
                    pairs.append(
                        {
                            "feature_1": col1,
                            "feature_2": col2,
                            "correlation": float(corr),
                        }
                    )

    # Sort by absolute correlation (highest first)
    pairs.sort(key=lambda x: abs(x["correlation"]), reverse=True)
    return pairs
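
The upper-triangle scan can be demonstrated on a hand-built matrix (feature names invented); note that a strong negative correlation is flagged just like a positive one:

```python
import pandas as pd

cols = ["a", "b", "c"]
corr = pd.DataFrame(
    [[1.00, 0.95, 0.10],
     [0.95, 1.00, -0.85],
     [0.10, -0.85, 1.00]],
    index=cols, columns=cols,
)

# Upper-triangle scan with |corr| >= 0.8, as in find_highly_correlated_pairs
pairs = [
    {"feature_1": c1, "feature_2": c2, "correlation": float(corr.loc[c1, c2])}
    for i, c1 in enumerate(cols)
    for j, c2 in enumerate(cols)
    if i < j and abs(corr.loc[c1, c2]) >= 0.8
]
pairs.sort(key=lambda p: abs(p["correlation"]), reverse=True)
print([(p["feature_1"], p["feature_2"]) for p in pairs])  # [('a', 'b'), ('b', 'c')]
```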

calculate_vif(df, features=None, threshold=5.0)

Calculate Variance Inflation Factor (VIF) for features.

VIF measures how much the variance of a regression coefficient is inflated due to multicollinearity. Thresholds:
  • VIF < 5: Low multicollinearity
  • 5 <= VIF < 10: Moderate multicollinearity
  • VIF >= 10: High multicollinearity

Parameters:
  • df (DataFrame, required): Input DataFrame.
  • features (list[str] | None, default None): List of features to analyze (auto-detected if None).
  • threshold (float, default 5.0): VIF threshold for flagging.

Returns:
  dict[str, Any]: dictionary with VIF values and flagged features.

Source code in tanml/analysis/correlation.py
def calculate_vif(
    df: pd.DataFrame,
    features: list[str] | None = None,
    threshold: float = 5.0,
) -> dict[str, Any]:
    """
    Calculate Variance Inflation Factor (VIF) for features.

    VIF measures how much the variance of a regression coefficient is
    inflated due to multicollinearity. Thresholds:
        - VIF < 5: Low multicollinearity
        - 5 <= VIF < 10: Moderate multicollinearity
        - VIF >= 10: High multicollinearity

    Args:
        df: Input DataFrame
        features: List of features to analyze (auto-detected if None)
        threshold: VIF threshold for flagging

    Returns:
        Dictionary with VIF values and flagged features
    """
    from statsmodels.stats.outliers_influence import variance_inflation_factor

    # Get numeric columns
    if features is None:
        features = df.select_dtypes(include=[np.number]).columns.tolist()

    # Filter to existing columns
    features = [f for f in features if f in df.columns]

    if len(features) < 2:
        return {
            "vif_values": {},
            "high_vif_features": [],
            "status": "pass",
            "error": "Need at least 2 features for VIF calculation",
        }

    # Prepare data (drop NaN rows)
    X = df[features].dropna()

    if len(X) < len(features) + 1:
        return {
            "vif_values": {},
            "high_vif_features": [],
            "status": "unknown",
            "error": "Insufficient data for VIF calculation",
        }

    # Calculate VIF for each feature
    vif_values = {}
    import warnings
    for i, col in enumerate(features):
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                # Silence numpy divide-by-zero warnings during VIF calculation
                old_settings = np.seterr(divide="ignore", invalid="ignore")
                try:
                    vif = variance_inflation_factor(X.values, i)
                finally:
                    np.seterr(**old_settings)
            vif_values[col] = float(vif) if not np.isinf(vif) else float("inf")
        except Exception:
            vif_values[col] = np.nan

    # Find high VIF features
    high_vif = [
        col
        for col, vif in vif_values.items()
        if vif is not None and not np.isnan(vif) and vif >= threshold
    ]

    return {
        "vif_values": vif_values,
        "high_vif_features": high_vif,
        "high_vif_count": len(high_vif),
        "threshold": threshold,
        "status": "warning" if high_vif else "pass",
    }
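
The underlying definition, VIF_i = 1 / (1 - R²) from regressing feature i on the remaining features, can be verified with plain numpy. The helper below is a sketch of that definition (with an intercept added to the auxiliary regression), not the statsmodels implementation used above:

```python
import numpy as np

def vif(X: np.ndarray, i: int) -> float:
    # VIF_i = 1 / (1 - R^2), where R^2 comes from regressing column i
    # on the remaining columns (with an intercept term).
    y = X[:, i]
    others = np.delete(X, i, axis=1)
    A = np.column_stack([np.ones(len(y)), others])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    resid = y - A @ coef
    r2 = 1.0 - resid.var() / y.var()
    return 1.0 / (1.0 - r2)

rng = np.random.default_rng(0)
n = 1_000
x1 = rng.normal(size=n)
x2 = x1 + rng.normal(scale=0.3, size=n)  # strongly collinear with x1
x3 = rng.normal(size=n)                  # independent
X = np.column_stack([x1, x2, x3])

print(vif(X, 0) > 5.0, vif(X, 2) < 2.0)  # True True
```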

analyze_feature_relationships(df, features=None, corr_method='pearson', corr_threshold=0.8, vif_threshold=5.0)

Comprehensive feature relationship analysis.

Combines correlation and VIF analysis for a complete picture of feature relationships.

Parameters:
  • df (DataFrame, required): Input DataFrame.
  • features (list[str] | None, default None): List of features to analyze.
  • corr_method (str, default 'pearson'): Correlation method.
  • corr_threshold (float, default 0.8): Threshold for flagging high correlations.
  • vif_threshold (float, default 5.0): Threshold for flagging high VIF.

Returns:
  dict[str, Any]: combined analysis results.

Source code in tanml/analysis/correlation.py
def analyze_feature_relationships(
    df: pd.DataFrame,
    features: list[str] | None = None,
    corr_method: str = "pearson",
    corr_threshold: float = 0.8,
    vif_threshold: float = 5.0,
) -> dict[str, Any]:
    """
    Comprehensive feature relationship analysis.

    Combines correlation and VIF analysis for a complete picture
    of feature relationships.

    Args:
        df: Input DataFrame
        features: List of features to analyze
        corr_method: Correlation method
        corr_threshold: Threshold for flagging high correlations
        vif_threshold: Threshold for flagging high VIF

    Returns:
        Combined analysis results
    """
    if features is None:
        features = df.select_dtypes(include=[np.number]).columns.tolist()

    df_subset = df[features]

    # Correlation analysis
    corr_matrix = calculate_correlation_matrix(df_subset, method=corr_method)
    high_corr_pairs = find_highly_correlated_pairs(corr_matrix, threshold=corr_threshold)

    # VIF analysis
    vif_results = calculate_vif(df_subset, features=features, threshold=vif_threshold)

    # Determine overall status
    has_high_corr = len(high_corr_pairs) > 0
    has_high_vif = len(vif_results.get("high_vif_features", [])) > 0

    if has_high_corr or has_high_vif:
        status = "warning"
    else:
        status = "pass"

    return {
        "correlation_matrix": corr_matrix.to_dict(),
        "high_correlation_pairs": high_corr_pairs,
        "vif": vif_results,
        "status": status,
        "summary": {
            "n_features": len(features),
            "n_high_corr_pairs": len(high_corr_pairs),
            "n_high_vif_features": len(vif_results.get("high_vif_features", [])),
        },
    }

Input cluster coverage analysis module.

Analyzes whether test data falls within the same input space as training data using clustering techniques.

Example

    >>> from tanml.analysis.clustering import analyze_cluster_coverage
    >>> coverage = analyze_cluster_coverage(
    ...     X_train=train_features,
    ...     X_test=test_features,
    ...     n_clusters=5,
    ... )
    >>> print(f"Coverage: {coverage['coverage_pct']:.1f}%")

analyze_cluster_coverage(X_train, X_test, n_clusters=5, max_k=10, auto_select_k=False)

Analyze how well test data is covered by training data clusters.

This check identifies whether test samples fall into regions of the input space that were seen during training.

Parameters:
  • X_train (DataFrame, required): Training features.
  • X_test (DataFrame, required): Test features.
  • n_clusters (int, default 5): Number of clusters (if auto_select_k=False).
  • max_k (int, default 10): Maximum clusters to try (if auto_select_k=True).
  • auto_select_k (bool, default False): Whether to auto-select the optimal k using the elbow method.

Returns:
  dict[str, Any]: dictionary with:
  • coverage_pct: Percentage of test samples in training clusters.
  • cluster_distribution: Test samples per cluster.
  • uncovered_indices: Indices of uncovered test samples.
  • n_clusters: Actual number of clusters used.
  • train_pca / test_pca / cluster_centers_pca: 2D PCA coordinates for visualization.

Source code in tanml/analysis/clustering.py
def analyze_cluster_coverage(
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    n_clusters: int = 5,
    max_k: int = 10,
    auto_select_k: bool = False,
) -> dict[str, Any]:
    """
    Analyze how well test data is covered by training data clusters.

    This check identifies whether test samples fall into regions of
    the input space that were seen during training.

    Args:
        X_train: Training features
        X_test: Test features
        n_clusters: Number of clusters (if auto_select_k=False)
        max_k: Maximum clusters to try (if auto_select_k=True)
        auto_select_k: Whether to auto-select optimal k using elbow method

    Returns:
        Dictionary with:
            - coverage_pct: Percentage of test samples in training clusters
            - cluster_distribution: Test samples per cluster
            - uncovered_indices: Indices of uncovered test samples
            - n_clusters: Actual number of clusters used
            - train_pca / test_pca / cluster_centers_pca: 2D PCA coordinates for visualization
    """
    from sklearn.cluster import KMeans
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler

    # Get common numeric columns
    numeric_train = X_train.select_dtypes(include=[np.number])
    numeric_test = X_test.select_dtypes(include=[np.number])
    common_cols = list(set(numeric_train.columns) & set(numeric_test.columns))

    if not common_cols:
        return {
            "coverage_pct": 0.0,
            "cluster_distribution": {},
            "uncovered_indices": [],
            "n_clusters": 0,
            "error": "No common numeric columns found",
        }

    X_train_subset = numeric_train[common_cols].dropna()
    X_test_subset = numeric_test[common_cols].dropna()

    if len(X_train_subset) < n_clusters or len(X_test_subset) == 0:
        return {
            "coverage_pct": 0.0,
            "cluster_distribution": {},
            "uncovered_indices": list(range(len(X_test_subset))),
            "n_clusters": 0,
            "error": "Insufficient data for clustering",
        }

    # Standardize features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_subset)
    X_test_scaled = scaler.transform(X_test_subset)

    # Auto-select k using elbow method if requested
    if auto_select_k:
        n_clusters = _select_optimal_k(X_train_scaled, max_k)

    # Fit KMeans on training data
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    train_labels = kmeans.fit_predict(X_train_scaled)
    test_labels = kmeans.predict(X_test_scaled)

    # Calculate distances to nearest cluster center
    train_distances = kmeans.transform(X_train_scaled).min(axis=1)
    test_distances = kmeans.transform(X_test_scaled).min(axis=1)

    # Define coverage threshold as max training distance (with buffer)
    threshold = np.percentile(train_distances, 95) * 1.5

    # Identify uncovered test samples
    uncovered_mask = test_distances > threshold
    uncovered_indices = np.where(uncovered_mask)[0].tolist()

    coverage_pct = 100 * (1 - uncovered_mask.mean())

    # Cluster distribution
    cluster_dist = {}
    for i in range(n_clusters):
        train_count = (train_labels == i).sum()
        test_count = (test_labels == i).sum()
        cluster_dist[i] = {
            "train_count": int(train_count),
            "test_count": int(test_count),
            "train_pct": float(100 * train_count / len(train_labels)),
            "test_pct": float(100 * test_count / len(test_labels)),
        }

    # PCA for visualization
    pca = PCA(n_components=2)
    train_pca = pca.fit_transform(X_train_scaled)
    test_pca = pca.transform(X_test_scaled)

    return {
        "coverage_pct": float(coverage_pct),
        "cluster_distribution": cluster_dist,
        "uncovered_indices": uncovered_indices,
        "uncovered_count": len(uncovered_indices),
        "n_clusters": n_clusters,
        "train_labels": train_labels.tolist(),
        "test_labels": test_labels.tolist(),
        "train_pca": train_pca.tolist(),
        "test_pca": test_pca.tolist(),
        "cluster_centers_pca": pca.transform(kmeans.cluster_centers_).tolist(),
        "status": "pass" if coverage_pct >= 90 else ("warning" if coverage_pct >= 70 else "fail"),
    }
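
The coverage rule can be exercised end to end on synthetic data. This is a condensed sketch of the same steps (scaler, KMeans, distance threshold), not a call into TanML, with half of the test rows deliberately placed far outside the training cloud:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_train = rng.normal(0.0, 1.0, size=(500, 3))
# Half the test rows sit far outside the training cloud
X_test = np.vstack([
    rng.normal(0.0, 1.0, size=(100, 3)),
    rng.normal(8.0, 1.0, size=(100, 3)),
])

scaler = StandardScaler()
Xtr = scaler.fit_transform(X_train)
Xte = scaler.transform(X_test)

km = KMeans(n_clusters=5, random_state=42, n_init=10).fit(Xtr)
train_d = km.transform(Xtr).min(axis=1)  # distance to nearest center
test_d = km.transform(Xte).min(axis=1)

# Same rule as above: 95th-percentile training distance, with a 1.5x buffer
threshold = np.percentile(train_d, 95) * 1.5
coverage_pct = 100 * (1 - (test_d > threshold).mean())
print(round(coverage_pct))  # roughly 50: the shifted half is flagged as uncovered
```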

StressTestCheck

Task-aware stress test
  • Classification: accuracy, auc, delta_accuracy, delta_auc
  • Regression: rmse, r2, delta_rmse, delta_r2

For each numeric feature, perturb a random subset of rows by (1 ± epsilon).
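
The perturbation itself is simple to picture. Here is a standalone sketch (the column name is invented) of scaling a random 20% of one numeric column by 1 + epsilon, mirroring the `_perturb_scaled` helper below:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
X = pd.DataFrame({"income": rng.normal(50_000, 10_000, 100)})

epsilon, perturb_fraction = 0.01, 0.2
Xp = X.copy(deep=True)
k = max(1, int(perturb_fraction * len(Xp)))          # 20 rows
idx = rng.choice(Xp.index, size=k, replace=False)
vals = Xp.loc[idx, "income"].to_numpy(dtype="float64")
Xp.loc[idx, "income"] = vals * (1.0 + epsilon)       # scale up by 1%

print(int((Xp["income"] != X["income"]).sum()))  # 20
```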

Source code in tanml/checks/stress_test.py
class StressTestCheck:
    """
    Task-aware stress test:
      - Classification: accuracy, auc, delta_accuracy, delta_auc
      - Regression:     rmse, r2,  delta_rmse,     delta_r2

    For each numeric feature, perturb a random subset of rows by (1 ± epsilon).
    """

    def __init__(
        self,
        model,
        X,
        y,
        epsilon: float = 0.01,
        perturb_fraction: float = 0.2,
        random_state: int = 42,
    ):
        self.model = model
        self.X = pd.DataFrame(X, columns=getattr(X, "columns", None))
        self.y = np.asarray(y)
        self.epsilon = float(epsilon)
        self.perturb_fraction = float(perturb_fraction)
        self.rng = np.random.default_rng(int(random_state))

        # 🔧 Cast ALL numeric columns to float once to avoid int64→float assignment warnings
        num_cols = [
            c
            for c in self.X.columns
            if is_numeric_dtype(self.X[c]) and not is_bool_dtype(self.X[c])
        ]
        if num_cols:
            self.X[num_cols] = self.X[num_cols].astype("float64")

    def _numeric_cols(self) -> list[str]:
        return [
            c
            for c in self.X.columns
            if is_numeric_dtype(self.X[c]) and not is_bool_dtype(self.X[c])
        ]

    def _perturb_scaled(self, X: pd.DataFrame, col: str, sign: int) -> pd.DataFrame:
        """Scale a random subset of column 'col' by (1 + sign*epsilon)."""
        Xp = X.copy(deep=True)
        if Xp.empty:
            return Xp
        n = len(Xp)
        k = max(1, int(self.perturb_fraction * n))
        idx = self.rng.choice(Xp.index, size=k, replace=False)
        factor = 1.0 + sign * self.epsilon

        # Use a float numpy view for assignment — no dtype warnings
        vals = Xp.loc[idx, col].to_numpy(dtype="float64", copy=False)
        Xp.loc[idx, col] = vals * float(factor)
        return Xp

    def run(self):
        task_type = _infer_task_type(self.model, self.y)
        results: list[dict[str, Any]] = []

        # ---------- Baseline ----------
        if task_type == "regression":
            y_pred_base = np.ravel(self.model.predict(self.X))
            rmse_base, r2_base = _reg_metrics(self.y, y_pred_base)
        else:
            y_score_base = _scores_for_classification(self.model, self.X)
            # If scores are probabilities/decision values, binarize them; else fall back to model.predict
            try:
                y_pred_base = _bin_pred_from_score(y_score_base)
            except Exception:
                y_pred_base = np.ravel(self.model.predict(self.X))
            acc_base, auc_base = _cls_metrics(self.y, y_score_base, y_pred_base)

        # ---------- Per-feature perturbations ----------
        for col in self._numeric_cols():
            for sign, lab in [
                (+1, f"+{round(self.epsilon * 100, 2)}%"),
                (-1, f"-{round(self.epsilon * 100, 2)}%"),
            ]:
                try:
                    Xp = self._perturb_scaled(self.X, col, sign)

                    if task_type == "regression":
                        y_pred_p = np.ravel(self.model.predict(Xp))
                        rmse_p, r2_p = _reg_metrics(self.y, y_pred_p)
                        results.append(
                            {
                                "feature": col,
                                "perturbation": lab,
                                "rmse": round(rmse_p, 4),
                                "r2": round(r2_p, 4),
                                "delta_rmse": round(rmse_p - rmse_base, 4),
                                "delta_r2": round(r2_p - r2_base, 4),
                            }
                        )
                    else:
                        y_score_p = _scores_for_classification(self.model, Xp)
                        try:
                            y_pred_p = _bin_pred_from_score(y_score_p)
                        except Exception:
                            y_pred_p = np.ravel(self.model.predict(Xp))
                        acc_p, auc_p = _cls_metrics(self.y, y_score_p, y_pred_p)
                        results.append(
                            {
                                "feature": col,
                                "perturbation": lab,
                                "accuracy": round(acc_p, 4),
                                "auc": round(auc_p, 4) if auc_p == auc_p else np.nan,
                                "delta_accuracy": round(acc_p - acc_base, 4),
                                "delta_auc": round((auc_p - auc_base), 4)
                                if (auc_base == auc_base and auc_p == auc_p)
                                else np.nan,
                            }
                        )

                # Robust error row in either mode
                except Exception as e:
                    if task_type == "regression":
                        results.append(
                            {
                                "feature": col,
                                "perturbation": lab,
                                "rmse": "error",
                                "r2": "error",
                                "delta_rmse": f"Error: {e}",
                                "delta_r2": f"Error: {e}",
                            }
                        )
                    else:
                        results.append(
                            {
                                "feature": col,
                                "perturbation": lab,
                                "accuracy": "error",
                                "auc": "error",
                                "delta_accuracy": f"Error: {e}",
                                "delta_auc": f"Error: {e}",
                            }
                        )

        return pd.DataFrame(results)
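The per-feature scaling logic above can be exercised in isolation: pick a random subset of rows, scale one numeric column by (1 ± epsilon), and leave everything else untouched. The helper below is a hypothetical standalone sketch of that step (the function name and defaults are illustrative, not part of TanML's API):

```python
import numpy as np
import pandas as pd


def perturb_scaled(X: pd.DataFrame, col: str, epsilon: float, sign: int,
                   perturb_fraction: float = 0.2, rng=None) -> pd.DataFrame:
    """Scale a random subset of `col` by (1 + sign*epsilon); other columns are untouched."""
    rng = rng if rng is not None else np.random.default_rng(42)
    Xp = X.copy(deep=True)
    k = max(1, int(perturb_fraction * len(Xp)))          # at least one row is perturbed
    idx = rng.choice(Xp.index, size=k, replace=False)
    Xp.loc[idx, col] = Xp.loc[idx, col].to_numpy(dtype="float64") * (1.0 + sign * epsilon)
    return Xp


data_rng = np.random.default_rng(0)
X = pd.DataFrame({"x1": data_rng.normal(size=100), "x2": data_rng.normal(size=100)})
Xp = perturb_scaled(X, "x1", epsilon=0.01, sign=+1, rng=np.random.default_rng(0))

changed = int((Xp["x1"] != X["x1"]).sum())
print(changed)  # 20 rows perturbed (0.2 * 100)
```

With `perturb_fraction=0.2` on 100 rows, exactly 20 values of `x1` are rescaled while `x2` is left bit-identical, which is what lets the check attribute any metric shift to a single feature.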

SHAPCheck

Bases: BaseCheck

SHAP for regression + binary classification (no multiclass).

Config under rule_config["explainability"]["shap"]:

- enabled: bool (default True; also checked under rule_config["SHAPCheck"]["enabled"])
- task: "auto" | "classification" | "regression" (default "auto")
- algorithm: "auto" | "tree" | "linear" | "kernel" | "permutation" (default "auto")
- model_output: "auto" | "raw" | "log_odds" | "probability" (tree-only hint; default "auto")
- background_strategy: "sample" | "kmeans" (default "sample")
- background_sample_size: int (default 100)
- test_sample_size: int (default 200)
- max_display: int (default 20)
- seed: int (default 42)
- out_dir: str, optional (preferred save folder)
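These keys are read from an ordinary nested dict. A minimal sketch of how the check resolves them, with unspecified keys falling back to the documented defaults (the snippet mirrors the lookups in `run()` but is illustrative, not TanML's parser):

```python
# Hypothetical rule_config overriding two of the documented SHAP keys
rule_config = {
    "explainability": {
        "shap": {
            "algorithm": "tree",
            "background_sample_size": 50,
            # everything else falls back to the documented defaults
        }
    }
}

cfg = (rule_config.get("explainability") or {}).get("shap", {})
algorithm = (cfg.get("algorithm") or "auto").lower()
bg_n = int(cfg.get("background_sample_size", 100))
test_n = int(cfg.get("test_sample_size", 200))   # not overridden -> default

print(algorithm, bg_n, test_n)  # tree 50 200
```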

Source code in tanml/checks/explainability/shap_check.py
class SHAPCheck(BaseCheck):
    """
    SHAP for regression + binary classification (no multiclass).

    Config under rule_config["explainability"]["shap"]:
      - enabled: bool (default True, also checked under rule_config["SHAPCheck"]["enabled"])
      - task: "auto" | "classification" | "regression"          (default "auto")
      - algorithm: "auto" | "tree" | "linear" | "kernel" | "permutation" (default "auto")
      - model_output: "auto" | "raw" | "log_odds" | "probability"  (tree-only hint; default "auto")
      - background_strategy: "sample" | "kmeans"                 (default "sample")
      - background_sample_size: int                              (default 100)
      - test_sample_size: int                                    (default 200)
      - max_display: int                                         (default 20)
      - seed: int                                                (default 42)
      - out_dir: str (optional)                                  (preferred save folder)
    """

    def __init__(self, model, X_train, X_test, y_train, y_test, rule_config=None, cleaned_df=None):
        super().__init__(model, X_train, X_test, y_train, y_test, rule_config)
        self.cleaned_df = cleaned_df

    # -------------------------- helpers --------------------------

    @staticmethod
    def _to_df(X, names=None):
        if isinstance(X, pd.DataFrame):
            return X
        if sp.issparse(X):
            df = pd.DataFrame.sparse.from_spmatrix(X)
        else:
            df = pd.DataFrame(np.asarray(X))
        if names is not None and len(names) == df.shape[1]:
            df.columns = list(names)
        return df

    @staticmethod
    def _task(y, forced="auto"):
        if forced and forced != "auto":
            return forced
        try:
            yv = (
                y.iloc[:, 0]
                if isinstance(y, pd.DataFrame)
                else (y if isinstance(y, pd.Series) else pd.Series(y))
            )
            uniq = pd.Series(yv).dropna().unique()
            return "classification" if len(uniq) <= 2 else "regression"
        except Exception:
            return "regression"

    @staticmethod
    def _pos_cls_idx(model, X_one):
        """Return index of the positive class (1/True if available, else max-label) for binary classification."""
        try:
            if hasattr(model, "classes_") and len(model.classes_) == 2:
                classes = list(model.classes_)
                for pos in (1, True):
                    if pos in classes:
                        return classes.index(pos)
                return classes.index(max(classes))
            if hasattr(model, "predict_proba"):
                proba = model.predict_proba(X_one)
                return 1 if proba.shape[1] == 2 else 0
        except Exception:
            pass
        return 1

    @staticmethod
    def _looks_like_tree(m):
        mod = type(m).__module__.lower()
        name = type(m).__name__.lower()
        return (
            "xgboost" in mod
            or "lightgbm" in mod
            or "catboost" in mod
            or "sklearn.ensemble" in mod
            or "sklearn.tree" in mod
            or "randomforest" in name
            or "gradientboost" in name
            or "extratrees" in name
            or "decisiontree" in name
        )

    @staticmethod
    def _looks_like_linear(m):
        mod = type(m).__module__.lower()
        name = type(m).__name__.lower()
        return (
            "sklearn.linear_model" in mod
            or "logistic" in name
            or "linear" in name
            or "ridge" in name
            or "lasso" in name
            or "elastic" in name
        )

    def _predict_fn(self, is_cls: bool, pos_idx: int | None):
        """
        Vectorized prediction function for permutation/kernel explainers.
        Returns positive-class probability when classification is detected and predict_proba is available.
        """
        if is_cls and hasattr(self.model, "predict_proba"):

            def f(X):
                p = self.model.predict_proba(X)
                i = 1 if (p.ndim == 2 and p.shape[1] == 2) else (pos_idx or 0)
                return p[:, i]

            return f
        return self.model.predict

    def _explainer(self, algorithm, background, model_output_hint, is_cls, pos_idx):
        """
        Choose fastest viable explainer:
          - tree → TreeExplainer (with interventional perturbation)
          - linear → LinearExplainer
          - default/auto → PermutationExplainer (avoid slow Kernel, unless explicitly requested)
        """
        m = self.model
        alg = (algorithm or "auto").lower()

        # Prefer fast paths in auto
        if alg == "tree" or (alg == "auto" and self._looks_like_tree(m)):
            mo = "raw" if model_output_hint == "auto" else model_output_hint
            try:
                expl = shap.TreeExplainer(
                    m, data=background, feature_perturbation="interventional", model_output=mo
                )
            except Exception:
                # Fallback for models with categorical splits (XGBoost/CatBoost)
                # which fail with interventional perturbation + background data
                expl = shap.TreeExplainer(
                    m, feature_perturbation="tree_path_dependent", model_output=mo
                )
            return expl, "tree"

        if alg == "linear" or (alg == "auto" and self._looks_like_linear(m)):
            return shap.LinearExplainer(m, background), "linear"

        if alg in {"permutation", "auto"}:
            fn = self._predict_fn(is_cls, pos_idx)
            return shap.explainers.Permutation(fn, background, max_evals=2000), "perm"

        # Only use Kernel if explicitly requested
        if alg == "kernel":
            fn = self._predict_fn(is_cls, pos_idx)
            return shap.KernelExplainer(fn, background), "kernel"

        # Fallback (should not hit)
        fn = self._predict_fn(is_cls, pos_idx)
        return shap.explainers.Permutation(fn, background, max_evals=2000), "perm"

    # ---------------------------- main ----------------------------

    def run(self):
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        out = {}
        try:
            warnings.filterwarnings("ignore", category=UserWarning)

            # -------- read config (note: under explainability.shap) ----------
            exp_cfg = (self.config or {}).get("explainability", {}) or {}
            cfg = exp_cfg.get("shap", {}) if isinstance(exp_cfg, dict) else {}
            seed = int(cfg.get("seed", 42))
            bg_n = int(cfg.get("background_sample_size", 100))
            test_n = int(cfg.get("test_sample_size", 200))
            task_forced = (cfg.get("task") or "auto").lower()
            algorithm = (cfg.get("algorithm") or "auto").lower()
            bg_strategy = (cfg.get("background_strategy") or "sample").lower()
            model_output_hint = (cfg.get("model_output") or "auto").lower()
            max_display = int(cfg.get("max_display", 20))

            # -------- resolve output directory + timestamp ----------
            out_dir_opt = cfg.get("out_dir")
            options_dir = ((self.config or {}).get("options") or {}).get("save_artifacts_dir")
            # prefer explicit shap.out_dir, then global artifacts dir, then local fallback
            outdir = Path(
                out_dir_opt
                or options_dir
                or (Path(__file__).resolve().parents[2] / "tmp_report_assets")
            )
            outdir.mkdir(parents=True, exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")

            # -------- materialize dataframes & coerce dtypes ----------
            feature_names = (
                list(self.X_train.columns) if isinstance(self.X_train, pd.DataFrame) else None
            )
            Xtr = self._to_df(self.X_train, feature_names)
            Xte = self._to_df(self.X_test, feature_names)
            if Xtr.empty or Xte.empty:
                raise ValueError("Empty X_train or X_test for SHAP.")

            # Avoid slow implicit object->float conversions during plotting
            Xtr = _safe_numeric_cast_df(Xtr)
            Xte = _safe_numeric_cast_df(Xte)

            # -------- task resolution & sanity for binary classification ----------
            task = self._task(self.y_train, forced=task_forced)
            is_cls = task == "classification"
            if is_cls:
                yv = (
                    self.y_train
                    if isinstance(self.y_train, (pd.Series, pd.DataFrame))
                    else pd.Series(self.y_train)
                )
                if len(pd.Series(yv).dropna().unique()) > 2:
                    raise ValueError("Binary classification only: y_train has >2 classes.")

            # positive-class index hint (used for permutation/kernel predict function)
            pos_idx_hint = self._pos_cls_idx(self.model, Xte.iloc[:1].values)

            # -------- background selection ----------
            if bg_strategy == "kmeans" and len(Xtr) > bg_n and not sp.issparse(self.X_train):
                background = shap.kmeans(Xtr, bg_n, seed=seed)
            else:
                background = shap.utils.sample(Xtr, bg_n, random_state=seed)

            # -------- slice test rows to explain ----------
            Xs = Xte.head(test_n)

            # -------- choose explainer & compute SHAP once ----------
            explainer, kind = self._explainer(
                algorithm, background, model_output_hint, is_cls, pos_idx_hint
            )
            if kind == "tree":
                sv = explainer(Xs, check_additivity=False)  # big speedup, visually identical plots
            else:
                sv = explainer(Xs)

            bg_shape = background.shape if hasattr(background, "shape") else None
            print(
                f"SHAP explainer={type(explainer).__name__} kind={kind} Xs={Xs.shape} "
                f"bg={'kmeans' if bg_shape is None else bg_shape}"
            )

            # -------- squeeze to 2-D for binary cls (if needed) ----------
            vals = sv.values
            if hasattr(vals, "ndim") and vals.ndim == 3:
                pos_idx = self._pos_cls_idx(self.model, Xs.iloc[:1].values)
                sv.values = vals[:, :, pos_idx]
                if isinstance(sv.base_values, np.ndarray) and sv.base_values.ndim == 2:
                    sv.base_values = sv.base_values[:, pos_idx]
            else:
                pos_idx = (
                    None
                    if task == "regression"
                    else self._pos_cls_idx(self.model, Xs.iloc[:1].values)
                )

            # -------- Python 3.13 fix: ensure SHAP values are proper numpy arrays ----------
            # SHAP beeswarm plot fails in Python 3.13 if values are Python floats instead of numpy floats
            if not isinstance(sv.values, np.ndarray):
                sv.values = np.asarray(sv.values, dtype=np.float64)
            elif sv.values.dtype == object:
                sv.values = sv.values.astype(np.float64)

            if hasattr(sv, "base_values") and sv.base_values is not None:
                if not isinstance(sv.base_values, np.ndarray):
                    sv.base_values = np.asarray(sv.base_values, dtype=np.float64)
                elif hasattr(sv.base_values, "dtype") and sv.base_values.dtype == object:
                    sv.base_values = sv.base_values.astype(np.float64)

            # -------- save plots ----------
            segment = "global"

            beeswarm_path = outdir / f"shap_beeswarm_{segment}_{ts}.png"
            plt.figure(figsize=(9, 6))
            shap.plots.beeswarm(sv, max_display=max_display, show=False)
            plt.tight_layout()
            plt.savefig(beeswarm_path, bbox_inches="tight", dpi=120, transparent=False)
            plt.close()

            bar_path = outdir / f"shap_bar_{segment}_{ts}.png"
            plt.figure(figsize=(9, 6))
            shap.plots.bar(sv, max_display=max_display, show=False)
            plt.tight_layout()
            plt.savefig(bar_path, bbox_inches="tight", dpi=120, transparent=False)
            plt.close()

            # -------- top features ----------
            # mean absolute SHAP across rows
            mean_abs = np.abs(sv.values).mean(axis=0)
            idx = np.argsort(mean_abs)[::-1][:max_display]
            cols = list(Xs.columns)
            top = [
                {
                    "feature": cols[i] if i < len(cols) else f"f{i}",
                    "mean_abs_shap": float(mean_abs[i]),
                }
                for i in idx
            ]
            top_list_pairs = [
                [d["feature"], d["mean_abs_shap"]] for d in top
            ]  # compat with old report builder

            out.update(
                {
                    "status": "ok",
                    "task": task,
                    "positive_class_index": pos_idx if task == "classification" else None,
                    # new + old keys (backward-compatible)
                    "plots": {"beeswarm": str(beeswarm_path), "bar": str(bar_path)},
                    "images": {"beeswarm": str(beeswarm_path), "bar": str(bar_path)},
                    "top_features": top,
                    "shap_top_features": top_list_pairs,
                }
            )
            print(f"✅ SHAP saved: {beeswarm_path}, {bar_path}")

        except Exception:
            out["status"] = "error: " + traceback.format_exc()
        return out
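The "top features" ranking at the end of `run()` is simply the mean of absolute SHAP values per column, sorted descending. A self-contained sketch with a stand-in values matrix (no `shap` dependency; the array and feature names are made up for illustration):

```python
import numpy as np

# Stand-in for sv.values: rows = explained samples, columns = features
values = np.array([
    [ 0.5, -0.1, 0.0],
    [-0.4,  0.2, 0.0],
    [ 0.6, -0.3, 0.1],
])
cols = ["age", "income", "tenure"]

mean_abs = np.abs(values).mean(axis=0)   # per-feature mean |SHAP|
order = np.argsort(mean_abs)[::-1]       # most important first
top = [{"feature": cols[i], "mean_abs_shap": float(mean_abs[i])} for i in order]

print(top[0])  # {'feature': 'age', 'mean_abs_shap': 0.5}
```

Taking the absolute value before averaging is what makes the ranking direction-agnostic: a feature that pushes predictions strongly in both directions still ranks high, even though its raw SHAP values average near zero.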