Model Classes

The core model classes perform statistical computations independently of AnnData. They follow a scikit-learn-style API with .fit() and .predict() methods, and are useful when you need direct control over the analysis process or are working with custom data formats.

DifferentialAbundance

class kompot.DifferentialAbundance(log_fold_change_threshold: float = 1.0, ptp_threshold: float = 0.05, n_landmarks: int | None = None, use_sample_variance: bool | None = None, eps: float = 1e-12, jit_compile: bool = False, density_predictor1: Any | None = None, density_predictor2: Any | None = None, variance_predictor1: Any | None = None, variance_predictor2: Any | None = None, random_state: int | None = None, batch_size: int | None = None)

Bases: object

Compute differential abundance between two conditions.

This class analyzes differences in cell density between two conditions (e.g., control vs. treatment) using density estimation and fold change analysis.

The analysis can be performed with synchronized parameters between conditions by setting sync_parameters=True in the fit method, which ensures consistent density estimation across both conditions.

log_density_condition1

Log density values for the first condition.

Type:

np.ndarray

log_density_condition2

Log density values for the second condition.

Type:

np.ndarray

log_fold_change

Log fold change between conditions (condition2 - condition1).

Type:

np.ndarray

log_fold_change_uncertainty

Uncertainty in the log fold change estimates.

Type:

np.ndarray

log_fold_change_zscore

Z-scores for the log fold changes.

Type:

np.ndarray

log_fold_change_ptp

PTP (Posterior Tail Probability) for the log fold changes. The PTP is a significance measure similar to a p-value.

Type:

np.ndarray

log_fold_change_direction

Direction of change (‘up’, ‘down’, or ‘neutral’) based on thresholds.

Type:

np.ndarray
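The attributes above are linked by simple array arithmetic. The following is a minimal sketch of how they could be derived from two log-density estimates, assuming each estimate comes with an independent Gaussian posterior variance and that the PTP is the one-sided Gaussian tail probability Φ(-|z|). Both the variance model and the PTP formula are assumptions for illustration, not kompot's documented implementation, and the function name fold_change_stats is hypothetical.

```python
import numpy as np
from math import erfc, sqrt


def fold_change_stats(log_density1, log_density2, var1, var2):
    """Hypothetical helper: LFC, uncertainty, z-score and an assumed PTP.

    Assumes independent Gaussian posteriors for the two log densities.
    """
    ld1 = np.asarray(log_density1, dtype=float)
    ld2 = np.asarray(log_density2, dtype=float)
    # Log fold change is condition 2 minus condition 1, as documented.
    lfc = ld2 - ld1
    # Variances of independent estimates add; the uncertainty is their square root.
    uncertainty = np.sqrt(np.asarray(var1, dtype=float) + np.asarray(var2, dtype=float))
    zscore = lfc / uncertainty
    # Assumed PTP: Gaussian tail mass beyond |z|, Phi(-|z|) = erfc(|z| / sqrt(2)) / 2.
    ptp = np.array([0.5 * erfc(abs(z) / sqrt(2.0)) for z in zscore.ravel()])
    return lfc, uncertainty, zscore, ptp.reshape(zscore.shape)


lfc, unc, z, ptp = fold_change_stats([0.0, 0.0], [1.0, -2.0], [0.5, 0.5], [0.5, 0.5])
# ptp ≈ [0.1587, 0.0228]
```

Under these assumptions a larger |z| always yields a smaller PTP, which is why the PTP behaves like a p-value.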

fit(X_condition1, X_condition2, sync_parameters=False, **density_kwargs)

Fit density estimators for both conditions, optionally with synchronized parameters.

predict(X_new)

Predict log density and log fold change for new points.

fit(X_condition1: ndarray, X_condition2: ndarray, landmarks: ndarray | None = None, ls_factor: float = 10.0, condition1_sample_indices: ndarray | None = None, condition2_sample_indices: ndarray | None = None, sample_estimator_ls: float | None = None, sync_parameters: bool = False, allow_single_condition_variance: bool = False, **density_kwargs)

Fit density estimators for both conditions.

This method only creates the estimators and does not compute fold changes. Call predict() to compute fold changes on any set of points.

Parameters:
  • X_condition1 (np.ndarray) – Cell states for the first condition. Shape (n_cells, n_features).

  • X_condition2 (np.ndarray) – Cell states for the second condition. Shape (n_cells, n_features).

  • landmarks (np.ndarray, optional) – Pre-computed landmarks to use. If provided, n_landmarks will be ignored. Shape (n_landmarks, n_features).

  • ls_factor (float, optional) – Factor by which the automatically inferred length scale is multiplied, by default 10.0. Only used when ls is not explicitly provided in density_kwargs.

  • condition1_sample_indices (np.ndarray, optional) – Sample indices for the first condition. Used for sample variance estimation. Unique values in this array define different sample groups.

  • condition2_sample_indices (np.ndarray, optional) – Sample indices for the second condition. Used for sample variance estimation. Unique values in this array define different sample groups.

  • sample_estimator_ls (float, optional) – Length scale for the sample-specific variance estimators. If None, the value of ls is used when it is given; otherwise the length scale is estimated. Default is None.

  • sync_parameters (bool, optional) – Whether to synchronize model parameters (d, mu, ls) between both conditions using the combined dataset. When True, parameters are computed once from the combined data to ensure models for both conditions use identical parameter values. This is especially important for consistent density estimation across conditions. Default is False.

  • **density_kwargs (dict) – Additional arguments to pass to the DensityEstimator.

Returns:

The fitted instance.

Return type:

self
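The sketch below illustrates what ls_factor and sync_parameters control, using only the length scale. It assumes a hypothetical inference heuristic (geometric mean of nearest-neighbour distances, multiplied by ls_factor); kompot delegates this step to its density estimator and may use a different rule, and it also synchronizes d and mu, which are not shown. The helper names infer_length_scale and length_scales are hypothetical.

```python
import numpy as np


def infer_length_scale(X, ls_factor=10.0):
    """Hypothetical heuristic: geometric mean of nearest-neighbour distances
    times ls_factor. Assumes no duplicate cells (a zero distance breaks the log)."""
    X = np.asarray(X, dtype=float)
    # Pairwise Euclidean distances (fine for small illustrative inputs).
    diff = X[:, None, :] - X[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)   # ignore self-distances
    nn = dist.min(axis=1)            # nearest-neighbour distance per cell
    return float(np.exp(np.mean(np.log(nn))) * ls_factor)


def length_scales(X1, X2, ls_factor=10.0, sync_parameters=False):
    """Return (ls_condition1, ls_condition2) under the hypothetical heuristic."""
    if sync_parameters:
        # One value computed from the combined cells, shared by both conditions.
        shared = infer_length_scale(np.vstack([X1, X2]), ls_factor)
        return shared, shared
    # Otherwise each condition gets its own estimate.
    return infer_length_scale(X1, ls_factor), infer_length_scale(X2, ls_factor)
```

Without synchronization, a sparser condition gets a longer length scale and therefore smoother density estimates, which can masquerade as a fold change; sharing one value avoids that asymmetry.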

predict(X_new: ndarray, log_fold_change_threshold: float | None = None, ptp_threshold: float | None = None, progress: bool = True) → Dict[str, ndarray]

Predict log density and log fold change for new points.

This method computes all fold changes and related metrics. It uses internal batching for efficient computation with large datasets.

Parameters:
  • X_new (np.ndarray) – New cell states to predict. Shape (n_cells, n_features).

  • log_fold_change_threshold (float, optional) – Threshold for considering a log fold change significant. If None, uses the threshold specified during initialization.

  • ptp_threshold (float, optional) – Threshold for considering a PTP (Posterior Tail Probability) significant. If None, uses the threshold specified during initialization.

  • progress (bool, optional) – Whether to show progress bars for operations, by default True.

Returns:

Dictionary containing the predictions:
  • ‘log_density_condition1’: Log density for condition 1
  • ‘log_density_condition2’: Log density for condition 2
  • ‘log_fold_change’: Log fold change between conditions
  • ‘log_fold_change_uncertainty’: Uncertainty in the log fold change
  • ‘log_fold_change_zscore’: Z-scores for the log fold change
  • ‘neg_log10_fold_change_ptp’: Negative log10 PTP (Posterior Tail Probability) for the log fold change
  • ‘log_fold_change_direction’: Direction of change (‘up’, ‘down’, or ‘neutral’)

Return type:

dict
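The ‘log_fold_change_direction’ entry combines both thresholds. The sketch below shows one plausible rule, assuming a cell is called ‘up’ or ‘down’ only when |LFC| exceeds log_fold_change_threshold and its PTP is below ptp_threshold, with the PTP compared on the negative-log10 scale used by ‘neg_log10_fold_change_ptp’. The rule and the helper name classify_direction are assumptions, not the documented algorithm; the defaults mirror the constructor defaults.

```python
import numpy as np


def classify_direction(log_fold_change, neg_log10_ptp,
                       log_fold_change_threshold=1.0, ptp_threshold=0.05):
    """Hypothetical helper: label each cell 'up', 'down', or 'neutral'."""
    lfc = np.asarray(log_fold_change, dtype=float)
    # PTP < ptp_threshold  is equivalent to  -log10(PTP) > -log10(ptp_threshold).
    significant = np.asarray(neg_log10_ptp, dtype=float) > -np.log10(ptp_threshold)
    direction = np.full(lfc.shape, "neutral", dtype=object)
    direction[significant & (lfc > log_fold_change_threshold)] = "up"
    direction[significant & (lfc < -log_fold_change_threshold)] = "down"
    return direction


labels = classify_direction([2.0, -1.5, 2.0, 0.2], [3.0, 2.0, 0.5, 4.0])
```

Requiring both conditions keeps large but uncertain changes, and confident but tiny changes, out of the ‘up’ and ‘down’ sets.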

DifferentialExpression

class kompot.DifferentialExpression(n_landmarks: int | None = None, use_sample_variance: bool | None = None, use_empirical_variance: bool = True, eps: float = 1e-08, jit_compile: bool = False, function_predictor1: Any | None = None, function_predictor2: Any | None = None, variance_predictor1: Any | None = None, variance_predictor2: Any | None = None, obs_variance_predictor1: Any | None = None, obs_variance_predictor2: Any | None = None, random_state: int | None = None, batch_size: int = 500, store_arrays_on_disk: bool | None = None, disk_storage_dir: str | None = None, max_memory_ratio: float = 0.8, model1: ExpressionModel | None = None, model2: ExpressionModel | None = None)

Bases: object

Compute differential expression between two conditions.

This class analyzes differences in gene expression between two conditions (e.g., control vs. treatment) using imputation, Mahalanobis distance, and log fold change analysis.

function_predictor1

Function predictor for condition 1.

Type:

Callable

function_predictor2

Function predictor for condition 2.

Type:

Callable

variance_predictor1

Variance predictor for condition 1. If provided, it is used in the uncertainty calculation.

Type:

Callable, optional

variance_predictor2

Variance predictor for condition 2. If provided, it is used in the uncertainty calculation.

Type:

Callable, optional

mahalanobis_distances

Mahalanobis distances for each gene.

Type:

np.ndarray

compute_fdr(null_mahalanobis, threshold=0.05, gene_names=None)

Compute false discovery rates (FDR) for the Mahalanobis distances from the most recent predict() call, using externally supplied null distances.

Parameters:
  • null_mahalanobis (np.ndarray) – Mahalanobis distances drawn from the null distribution, used as the reference for the observed distances.

  • threshold (float) – FDR threshold for the is_de column.

  • gene_names (list of str, optional) – Gene names for the DataFrame index.

Returns:

DataFrame with columns: mahalanobis, pvalue, local_fdr, tail_fdr, is_de.

Return type:

pd.DataFrame
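A minimal sketch of the tail-FDR part of this comparison follows. It assumes empirical p-values with an add-one correction, a Benjamini-Hochberg adjustment for tail_fdr, and is_de defined as tail_fdr < threshold. The local_fdr column is not reproduced because its estimator is not described here, and the sketch returns a dict of NumPy arrays rather than a DataFrame. These choices are assumptions, and empirical_fdr is a hypothetical name.

```python
import numpy as np


def empirical_fdr(mahalanobis, null_mahalanobis, threshold=0.05):
    """Hypothetical helper: empirical p-values plus Benjamini-Hochberg FDR."""
    d = np.asarray(mahalanobis, dtype=float)
    null = np.sort(np.asarray(null_mahalanobis, dtype=float))
    # Number of null distances >= each observed distance.
    n_ge = null.size - np.searchsorted(null, d, side="left")
    # Add-one correction keeps p-values strictly positive.
    pvalue = (n_ge + 1.0) / (null.size + 1.0)
    # Benjamini-Hochberg: p_(i) * m / i, made monotone from the largest p down.
    m = pvalue.size
    order = np.argsort(pvalue)
    ranked = pvalue[order] * m / np.arange(1, m + 1)
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    tail_fdr = np.empty(m)
    tail_fdr[order] = np.minimum(ranked, 1.0)
    return {"mahalanobis": d, "pvalue": pvalue,
            "tail_fdr": tail_fdr, "is_de": tail_fdr < threshold}


res = empirical_fdr([150.0, 50.0, 0.5], np.arange(1, 100, dtype=float))
```

With 99 null values, the smallest attainable p-value is 1/100, so the size of the null sample bounds how significant any gene can appear.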

compute_mahalanobis_distances(X: ndarray, fold_change=None, use_landmarks: bool = True, landmarks_override: ndarray | None = None, progress: bool = True) → ndarray

Compute Mahalanobis distances for each gene using efficient matrix preparation and batching.

Parameters:
  • X (np.ndarray) – Cell states. Shape (n_cells, n_features).

  • fold_change (np.ndarray, optional) – Pre-computed fold change matrix. If None, it is computed. Shape (n_cells, n_genes).

  • use_landmarks (bool, optional) – Whether to use landmarks for covariance calculation if available, by default True.

  • landmarks_override (np.ndarray, optional) – Explicitly provided landmarks to use instead of automatically detected ones, by default None.

  • progress (bool, optional) – Whether to show tqdm.auto progress bars during Mahalanobis distance computation. When True, displays progress bars for gene-wise operations. When False, progress bars are disabled. Default is True.

Returns:

Array of Mahalanobis distances for each gene.

Return type:

np.ndarray
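For each gene, the Mahalanobis distance measures how large the fold-change vector across cells (or landmarks) is relative to its covariance: d_g = sqrt(f_g^T Σ_g^{-1} f_g). The sketch below evaluates this through a Cholesky factorization, assuming the per-gene covariance matrices are already available with shape (n_points, n_points, n_genes). How kompot assembles those covariances (GP posterior, sample variance, landmarks) is not reproduced, and mahalanobis_per_gene is a hypothetical name.

```python
import numpy as np


def mahalanobis_per_gene(fold_change, covariances, jitter=1e-8):
    """Hypothetical helper: one Mahalanobis distance per gene.

    fold_change: (n_points, n_genes) fold changes at the evaluation points.
    covariances: (n_points, n_points, n_genes) covariance of those fold changes.
    """
    fold_change = np.asarray(fold_change, dtype=float)
    covariances = np.asarray(covariances, dtype=float)
    n_points, n_genes = fold_change.shape
    distances = np.empty(n_genes)
    for g in range(n_genes):
        # Small diagonal jitter keeps the Cholesky factorization stable.
        cov = covariances[:, :, g] + jitter * np.eye(n_points)
        L = np.linalg.cholesky(cov)  # cov = L @ L.T
        # Solving L z = f gives ||z||^2 = f^T cov^{-1} f without an explicit inverse.
        z = np.linalg.solve(L, fold_change[:, g])
        distances[g] = np.sqrt(z @ z)
    return distances
```

The covariance matrix grows quadratically with the number of points, which is one reason to evaluate on landmarks rather than on every cell.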

property empirical_variance_predictor1

Empirical (obs) variance predictor for condition 1.

property empirical_variance_predictor2

Empirical (obs) variance predictor for condition 2.

property expression_estimator_condition1

Mellon FunctionEstimator for condition 1.

property expression_estimator_condition2

Mellon FunctionEstimator for condition 2.

fit(X_condition1: ndarray, y_condition1: ndarray, X_condition2: ndarray, y_condition2: ndarray, sigma: float = 1.0, ls: float | None = None, ls_factor: float = 10.0, landmarks: ndarray | None = None, sample_estimator_ls: float | None = None, condition1_sample_indices: ndarray | None = None, condition2_sample_indices: ndarray | None = None, allow_single_condition_variance: bool = False, **function_kwargs)

Fit function estimators for both conditions.

This method only creates the estimators and does not compute fold changes. Call predict() to compute fold changes on any set of points.

Parameters:
  • X_condition1 (np.ndarray) – Cell states for the first condition. Shape (n_cells1, n_features).

  • y_condition1 (np.ndarray) – Gene expression values for the first condition. Shape (n_cells1, n_genes).

  • X_condition2 (np.ndarray) – Cell states for the second condition. Shape (n_cells2, n_features).

  • y_condition2 (np.ndarray) – Gene expression values for the second condition. Shape (n_cells2, n_genes).

  • sigma (float, optional) – Noise level for the function estimator, by default 1.0.

  • ls (float, optional) – Length scale for the GP kernel. If None, it will be estimated, by default None.

  • ls_factor (float, optional) – Factor by which the automatically inferred length scale is multiplied, by default 10.0. Only used when ls is None.

  • landmarks (np.ndarray, optional) – Pre-computed landmarks to use. If provided, n_landmarks will be ignored. Shape (n_landmarks, n_features).

  • sample_estimator_ls (float, optional) – Length scale for the sample-specific variance estimators. If None, the value of ls is used when it is given; otherwise the length scale is estimated. Default is None.

  • condition1_sample_indices (np.ndarray, optional) – Sample indices for the first condition. Used for sample variance estimation. Unique values in this array define different sample groups.

  • condition2_sample_indices (np.ndarray, optional) – Sample indices for the second condition. Used for sample variance estimation. Unique values in this array define different sample groups.

  • **function_kwargs (dict) – Additional arguments to pass to the FunctionEstimator.

Returns:

The fitted instance.

Return type:

self

property function_predictor1

Function predictor for condition 1.

property function_predictor2

Function predictor for condition 2.

predict(X_new: ndarray, compute_mahalanobis: bool = False, progress: bool = True, use_landmarks: bool = True, landmarks_override: ndarray | None = None) → Dict[str, ndarray]

Predict gene expression and differential metrics for new points.

This method computes fold changes and related metrics for the provided points. It uses internal batching for efficient computation with large datasets.

Parameters:
  • X_new (np.ndarray) – New cell states. Shape (n_cells, n_features).

  • compute_mahalanobis (bool, optional) – Whether to compute Mahalanobis distances. Because this can be computationally expensive, it is disabled by default. Default is False.

  • progress (bool, optional) – Whether to show tqdm.auto progress bars during computation. When True, displays progress bars for all batch processing operations including prediction, uncertainty computation, and Mahalanobis distance calculations. When False, all progress bars are disabled. Default is True.

  • use_landmarks (bool, optional) – Whether to use landmarks for Mahalanobis distance calculation if available, by default True. Setting to False will force computation using all provided points, which can be more accurate for small datasets or subsets.

  • landmarks_override (np.ndarray, optional) – Explicitly provided landmarks to use instead of the ones from the fitted model. Shape (n_landmarks, n_features). Used when custom landmarks are needed for a specific prediction, such as when analyzing a subset of data.

Returns:

Dictionary containing the predictions:
  • ‘condition1_smoothed’: Smoothed expression for condition 1
  • ‘condition2_smoothed’: Smoothed expression for condition 2
  • ‘condition1_std’: Posterior standard deviation for condition 1
  • ‘condition2_std’: Posterior standard deviation for condition 2
  • ‘fold_change’: Fold change between conditions
  • ‘mean_log_fold_change’: Mean log fold change across all cells
  • ‘mahalanobis_distances’: Mahalanobis distances (only if compute_mahalanobis is True)

Return type:

dict
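The ‘fold_change’ and ‘mean_log_fold_change’ entries can be related as in the sketch below, under two assumptions: the smoothed expression is already on a log scale, so a difference of smoothed values is a log fold change, and the sign convention is condition 2 minus condition 1, mirroring DifferentialAbundance. summarize_fold_change is a hypothetical helper, not part of kompot.

```python
import numpy as np


def summarize_fold_change(condition1_smoothed, condition2_smoothed):
    """Hypothetical helper, assuming log-scale expression inputs."""
    c1 = np.asarray(condition1_smoothed, dtype=float)  # (n_cells, n_genes)
    c2 = np.asarray(condition2_smoothed, dtype=float)  # (n_cells, n_genes)
    # Per-cell, per-gene fold change (condition 2 minus condition 1).
    fold_change = c2 - c1
    # Average over cells: one summary value per gene.
    mean_log_fold_change = fold_change.mean(axis=0)    # (n_genes,)
    return fold_change, mean_log_fold_change
```

The per-cell matrix keeps the cell-state resolution, while the per-gene mean is a compact summary for ranking genes.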

property variance_predictor1

Sample variance predictor for condition 1.

property variance_predictor2

Sample variance predictor for condition 2.

ExpressionModel

class kompot.ExpressionModel(n_landmarks: int | None = None, use_empirical_variance: bool = True, eps: float = 1e-08, random_state: int | None = None, batch_size: int = 500, store_arrays_on_disk: bool | None = None, disk_storage_dir: str | None = None, function_predictor: Any | None = None, obs_variance_predictor: Any | None = None, variance_predictor: Any | None = None)

Bases: object

Single-condition GP expression model.

Encapsulates one fitted Gaussian Process for gene expression, along with optional empirical (aleatoric) variance and sample variance components. Can be used standalone for imputation or paired inside DifferentialExpression for two-condition comparisons.

Parameters:
  • n_landmarks (int, optional) – Number of landmarks for the Nyström approximation.

  • use_empirical_variance (bool) – Whether to estimate per-gene empirical variance from GP residuals.

  • eps (float) – Small constant for numerical stability.

  • random_state (int, optional) – Random seed for landmark selection.

  • batch_size (int) – Batch size for prediction.

  • store_arrays_on_disk (bool, optional) – Whether to store large arrays on disk.

  • disk_storage_dir (str, optional) – Directory for disk-backed arrays.

  • function_predictor (callable, optional) – Pre-fitted mellon Predictor; when provided, GP fitting is skipped.

  • obs_variance_predictor (callable, optional) – Pre-fitted empirical (aleatoric) variance predictor. When provided, the internal empirical-variance computation is skipped during fit().

  • variance_predictor (callable, optional) – Pre-fitted sample variance predictor.

property cov_func

Covariance function of the GP.

covariance(X, diag=True, batch_size=None, progress=True)

GP posterior covariance (epistemic uncertainty).

Parameters:
  • X (np.ndarray) – Evaluation points.

  • diag (bool) – If True, return only the diagonal.

  • batch_size (int, optional) – Override instance batch size.

  • progress (bool) – Show progress bar.

fit(X: ndarray, y: ndarray, sigma: float = 1.0, ls: float | None = None, ls_factor: float = 10.0, landmarks: ndarray | None = None, sample_indices: ndarray | None = None, sample_estimator_ls: float | None = None, allow_single_condition_variance: bool = False, **function_kwargs) → ExpressionModel

Fit the expression GP on one condition.

Parameters:
  • X (np.ndarray) – Cell-state coordinates, shape (n_cells, n_features).

  • y (np.ndarray) – Expression matrix, shape (n_cells, n_genes).

  • sigma (float) – Noise level for the GP.

  • ls (float, optional) – Length scale. If None, estimated with ls_factor.

  • ls_factor (float) – Multiplier applied to the automatically inferred length scale.

  • landmarks (np.ndarray, optional) – Pre-computed landmarks (e.g. shared across conditions).

  • sample_indices (np.ndarray, optional) – Per-cell sample labels for biological-replicate variance.

  • sample_estimator_ls (float, optional) – Length scale override for the sample-variance estimator.

  • allow_single_condition_variance (bool) – Passed through to SampleVarianceEstimator.

  • **function_kwargs – Forwarded to mellon.FunctionEstimator.

Returns:

self, for chaining.

Return type:

ExpressionModel

property has_empirical_variance: bool

Whether empirical (obs) variance is available.

property has_sample_variance: bool

Whether sample variance is available.

property landmarks

Landmarks used by the GP, or None.

property ls

Length scale of the GP kernel (Matern component).

obs_variance(X, batch_size=None, progress=True)

Smoothed aleatoric noise, shape (n_points, n_genes).

Returns scalar 0 when empirical variance is disabled. When per-sample predictors are available (fitted with sample_indices), returns the mean of per-sample estimates to avoid double-counting between-sample variance.

predict(X, batch_size=None, progress=True)

Imputed expression, shape (n_points, n_genes).

Parameters:
  • X (np.ndarray) – Evaluation points.

  • batch_size (int, optional) – Override instance batch size.

  • progress (bool) – Show progress bar.

property predictor

The underlying mellon Predictor.

sample_variance(X, diag=True, batch_size=None, progress=True)

Biological-replicate variance.

Returns scalar 0 when no sample indices were provided.

property sigma

Noise level used during fitting.

std(X, batch_size=None, progress=True)

Square root of total variance plus eps.

total_variance(X, diag=True, batch_size=None, progress=True)

Sum of the GP posterior variance, obs_variance, and sample variance (diagonal).
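total_variance and std combine the three variance sources described above. The sketch below shows the diagonal case under stated assumptions: the GP posterior variance is shared across genes (shape (n_points,)), the empirical noise is the mean of per-sample estimates as described for obs_variance, and eps is added inside the square root (the wording of std leaves this ambiguous). total_variance and total_std here are hypothetical helpers, not the ExpressionModel methods.

```python
import numpy as np


def total_variance(gp_var, obs_var_per_sample=None, sample_var=0.0):
    """Hypothetical helper: diagonal total variance, shape (n_points, n_genes)."""
    # Epistemic GP variance, assumed identical for every gene: (n_points, 1).
    gp = np.asarray(gp_var, dtype=float)[:, None]
    # Aleatoric noise: mean of per-sample estimates, or scalar 0 when disabled.
    if obs_var_per_sample is None:
        obs = 0.0
    else:
        obs = np.mean(np.asarray(obs_var_per_sample, dtype=float), axis=0)
    # Biological-replicate variance: (n_points, n_genes) array, or scalar 0.
    return gp + obs + sample_var


def total_std(total, eps=1e-8):
    # Assumes eps is added inside the square root.
    return np.sqrt(total + eps)


t = total_variance([0.5, 1.0], [np.full((2, 3), 0.2), np.full((2, 3), 0.4)], sample_var=0.1)
```

Averaging the per-sample noise estimates, instead of pooling all cells, keeps between-sample differences out of the noise term so they are counted only once, in the sample variance.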

SampleVarianceEstimator

class kompot.SampleVarianceEstimator(eps: float = 1e-08, jit_compile: bool = True, estimator_type: str = 'function', store_arrays_on_disk: bool | None = None, disk_storage_dir: str | None = None, dask_num_workers: int | None = None)

Bases: object

Compute local sample variances of gene expression or cell density.

This class computes empirical variance by fitting a function estimator or density estimator to each group in the data and measuring the variance between their predictions. Bessel’s correction is applied so that the variance estimate is unbiased, which matters most when the number of samples is small.

group_predictors

Dictionary of prediction functions for each group.

Type:

Dict

estimator_type

Type of estimator used (‘function’ for gene expression, ‘density’ for cell density).

Type:

str

disk_storage

Storage manager for offloading large arrays to disk, if enabled.

Type:

DiskStorage, optional

n_groups

Number of unique groups found during fit. Must be at least 2 for variance calculation.

Type:

int

fit(X: ndarray, Y: ndarray = None, grouping_vector: ndarray = None, min_cells: int = 2, ls_factor: float = 10.0, estimator_kwargs: Dict = None)

Fit estimators for each group in the data and store only their predictors.

At least 2 groups with sufficient cells (>= min_cells) are required for variance calculation. If fewer than 2 valid groups are found, a ValueError will be raised.

Parameters:
  • X (np.ndarray) – Cell states. Shape (n_cells, n_features).

  • Y (np.ndarray, optional) – Gene expression values. Shape (n_cells, n_genes). Required for function estimator, not used for density estimator.

  • grouping_vector (np.ndarray) – Vector specifying which group each cell belongs to. Shape (n_cells,).

  • min_cells (int) – Minimum number of cells a group needs for an estimator to be trained on it. Default is 2. Groups with fewer cells are skipped.

  • ls_factor (float, optional) – Factor by which the automatically inferred length scale is multiplied, by default 10.0. Only used when ls is not explicitly provided in estimator_kwargs.

  • estimator_kwargs (Dict, optional) – Additional arguments to pass to the estimator constructor (FunctionEstimator or DensityEstimator).

Returns:

The fitted instance.

Return type:

self

Raises:

ValueError – If fewer than 2 groups have sufficient cells to compute variance.
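The validation rule described for fit() fits in a few lines. This sketch, with the hypothetical helper name valid_groups, counts cells per group, drops groups below min_cells, and raises ValueError when fewer than two groups remain. It mirrors the documented behaviour rather than kompot's actual code.

```python
import numpy as np


def valid_groups(grouping_vector, min_cells=2):
    """Hypothetical helper: groups that are large enough to get an estimator."""
    groups, counts = np.unique(np.asarray(grouping_vector), return_counts=True)
    # Groups with fewer than min_cells cells are skipped.
    kept = groups[counts >= min_cells]
    if kept.size < 2:
        raise ValueError(
            f"Need at least 2 groups with >= {min_cells} cells to compute "
            f"variance; found {kept.size}."
        )
    return kept


kept = valid_groups(["a", "a", "b", "b", "c"])  # group "c" has only 1 cell
```

Two groups is the minimum because a variance with Bessel’s correction divides by n_groups - 1.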

predict(X_new: ndarray, diag: bool = False, progress: bool = True) → ndarray

Predict empirical variance for new points using JAX.

This method computes the variance with Bessel’s correction (using n-1 instead of n in the denominator) to provide an unbiased estimate of the population variance. This correction is particularly important when the number of samples (groups) is small.

Parameters:
  • X_new (np.ndarray) – New cell states to predict. Shape (n_cells, n_features).

  • diag (bool, optional) – If True, compute only the variance at each cell state. If False (default), compute the full covariance matrix between all pairs of cells.

  • progress (bool, optional) – Whether to show a progress bar during covariance computation. Default True.

Returns:

If diag=True:

  • Function estimators: empirical variance for each new point, shape (n_cells, n_genes).
  • Density estimators: empirical variance for each new point, shape (n_cells, 1).

If diag=False:

  • Function estimators: full covariance matrix, shape (n_cells, n_cells, n_genes).
  • Density estimators: full covariance matrix, shape (n_cells, n_cells, 1).

Return type:

np.ndarray
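The Bessel-corrected variance between group predictions, in both output modes, can be sketched with NumPy. The sketch assumes the per-group predictions are already stacked into an array of shape (n_groups, n_cells, n_genes) (n_genes = 1 for density), matching the documented return shapes; kompot evaluates its stored group predictors with JAX instead, and sample_variance here is a hypothetical name.

```python
import numpy as np


def sample_variance(group_predictions, diag=False):
    """Hypothetical helper: variance across per-group predictions."""
    P = np.asarray(group_predictions, dtype=float)  # (n_groups, n_cells, n_genes)
    n_groups = P.shape[0]
    if n_groups < 2:
        raise ValueError("At least 2 groups are required to compute variance.")
    centered = P - P.mean(axis=0, keepdims=True)
    if diag:
        # Per-cell variance with Bessel's correction (divide by n_groups - 1).
        return (centered ** 2).sum(axis=0) / (n_groups - 1)
    # Covariance between every pair of cells, separately for each gene.
    return np.einsum("kin,kjn->ijn", centered, centered) / (n_groups - 1)


P = np.array([[[0.0], [1.0]], [[2.0], [3.0]]])  # 2 groups, 2 cells, 1 gene
var = sample_variance(P, diag=True)             # shape (2, 1)
cov = sample_variance(P)                        # shape (2, 2, 1)
```

The diagonal of the full covariance equals the diag=True result, so diag=True is just the cheaper special case.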