Model Classes
The core model classes perform statistical computations independently of
AnnData. They follow a scikit-learn-style API with .fit() and
.predict() methods, and are useful when you need direct control over the
analysis process or are working with custom data formats.
DifferentialAbundance
- class kompot.DifferentialAbundance(log_fold_change_threshold: float = 1.0, ptp_threshold: float = 0.05, n_landmarks: int | None = None, use_sample_variance: bool | None = None, eps: float = 1e-12, jit_compile: bool = False, density_predictor1: Any | None = None, density_predictor2: Any | None = None, variance_predictor1: Any | None = None, variance_predictor2: Any | None = None, random_state: int | None = None, batch_size: int | None = None)
Bases: object
Compute differential abundance between two conditions.
This class analyzes the differences in cell density between two conditions (e.g., control versus treatment) using density estimation and fold change analysis.
The analysis can be performed with synchronized parameters between conditions by setting sync_parameters=True in the fit method, which ensures consistent density estimation across both conditions.
- log_density_condition1
Log density values for the first condition.
- Type:
np.ndarray
- log_density_condition2
Log density values for the second condition.
- Type:
np.ndarray
- log_fold_change
Log fold change between conditions (condition2 - condition1).
- Type:
np.ndarray
- log_fold_change_uncertainty
Uncertainty in the log fold change estimates.
- Type:
np.ndarray
- log_fold_change_zscore
Z-scores for the log fold changes.
- Type:
np.ndarray
- log_fold_change_ptp
PTP (Posterior Tail Probability) for the log fold changes. The PTP is a significance measure similar to a p-value.
- Type:
np.ndarray
- log_fold_change_direction
Direction of change (‘up’, ‘down’, or ‘neutral’) based on thresholds.
- Type:
np.ndarray
- fit(X_condition1, X_condition2, sync_parameters=False, **density_kwargs)
Fit density estimators for both conditions, optionally with synchronized parameters.
- predict(X_new)
Predict log density and log fold change for new points.
- fit(X_condition1: ndarray, X_condition2: ndarray, landmarks: ndarray | None = None, ls_factor: float = 10.0, condition1_sample_indices: ndarray | None = None, condition2_sample_indices: ndarray | None = None, sample_estimator_ls: float | None = None, sync_parameters: bool = False, allow_single_condition_variance: bool = False, **density_kwargs)
Fit density estimators for both conditions.
This method only creates the estimators and does not compute fold changes. Call predict() to compute fold changes on any set of points.
- Parameters:
X_condition1 (np.ndarray) – Cell states for the first condition. Shape (n_cells, n_features).
X_condition2 (np.ndarray) – Cell states for the second condition. Shape (n_cells, n_features).
landmarks (np.ndarray, optional) – Pre-computed landmarks to use. If provided, n_landmarks will be ignored. Shape (n_landmarks, n_features).
ls_factor (float, optional) – Multiplication factor to apply to length scale when it’s automatically inferred, by default 10.0. Only used when ls is not explicitly provided in density_kwargs.
condition1_sample_indices (np.ndarray, optional) – Sample indices for first condition. Used for sample variance estimation. Unique values in this array define different sample groups.
condition2_sample_indices (np.ndarray, optional) – Sample indices for second condition. Used for sample variance estimation. Unique values in this array define different sample groups.
sample_estimator_ls (float, optional) – Length scale for the sample-specific variance estimators. If None, will use the same value as ls or it will be estimated, by default None.
sync_parameters (bool, optional) – Whether to synchronize model parameters (d, mu, ls) between both conditions using the combined dataset. When True, parameters are computed once from the combined data to ensure models for both conditions use identical parameter values. This is especially important for consistent density estimation across conditions. Default is False.
**density_kwargs (dict) – Additional arguments to pass to the DensityEstimator.
- Returns:
The fitted instance.
- Return type:
self
- predict(X_new: ndarray, log_fold_change_threshold: float | None = None, ptp_threshold: float | None = None, progress: bool = True) → Dict[str, ndarray]
Predict log density and log fold change for new points.
This method computes all fold changes and related metrics. It uses internal batching for efficient computation with large datasets.
- Parameters:
X_new (np.ndarray) – New cell states to predict. Shape (n_cells, n_features).
log_fold_change_threshold (float, optional) – Threshold for considering a log fold change significant. If None, uses the threshold specified during initialization.
ptp_threshold (float, optional) – Threshold for considering a PTP (Posterior Tail Probability) significant. If None, uses the threshold specified during initialization.
progress (bool, optional) – Whether to show progress bars for operations, by default True.
- Returns:
Dictionary containing the predictions:
- ‘log_density_condition1’: Log density for condition 1
- ‘log_density_condition2’: Log density for condition 2
- ‘log_fold_change’: Log fold change between conditions
- ‘log_fold_change_uncertainty’: Uncertainty in the log fold change
- ‘log_fold_change_zscore’: Z-scores for the log fold change
- ‘neg_log10_fold_change_ptp’: Negative log10 PTP (Posterior Tail Probability) for the log fold change
- ‘log_fold_change_direction’: Direction of change (‘up’, ‘down’, or ‘neutral’)
- Return type:
dict
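The relationship between the returned quantities can be illustrated with a minimal NumPy sketch. Only the sign convention (condition2 - condition1) and the two thresholds come from the documentation above. The z-score formula, the two-sided Gaussian tail probability, and the labelling rule are assumptions made for illustration, and `classify_direction` is a hypothetical helper, not part of kompot.

```python
import math
import numpy as np

def classify_direction(log_fold_change, uncertainty,
                       log_fold_change_threshold=1.0, ptp_threshold=0.05):
    """Label each point 'up', 'down', or 'neutral' (illustrative rule only)."""
    lfc = np.asarray(log_fold_change, dtype=float)
    sd = np.asarray(uncertainty, dtype=float)
    zscore = lfc / sd  # assumed definition of the z-score
    # Assumed two-sided Gaussian tail probability of each z-score.
    ptp = np.array([math.erfc(abs(z) / math.sqrt(2.0)) for z in zscore])
    significant = ptp < ptp_threshold
    direction = np.full(lfc.shape, "neutral", dtype=object)
    direction[significant & (lfc > log_fold_change_threshold)] = "up"
    direction[significant & (lfc < -log_fold_change_threshold)] = "down"
    return zscore, ptp, direction

log_density_condition1 = np.array([-2.0, -1.0, -3.0])
log_density_condition2 = np.array([0.5, -1.1, -5.5])
# Documented convention: condition2 - condition1.
lfc = log_density_condition2 - log_density_condition1
zscore, ptp, direction = classify_direction(lfc, np.array([0.5, 0.5, 0.5]))
```

A point is labelled only when it passes both the effect-size and the significance threshold; everything else stays 'neutral'.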
DifferentialExpression
- class kompot.DifferentialExpression(n_landmarks: int | None = None, use_sample_variance: bool | None = None, use_empirical_variance: bool = True, eps: float = 1e-08, jit_compile: bool = False, function_predictor1: Any | None = None, function_predictor2: Any | None = None, variance_predictor1: Any | None = None, variance_predictor2: Any | None = None, obs_variance_predictor1: Any | None = None, obs_variance_predictor2: Any | None = None, random_state: int | None = None, batch_size: int = 500, store_arrays_on_disk: bool | None = None, disk_storage_dir: str | None = None, max_memory_ratio: float = 0.8, model1: ExpressionModel | None = None, model2: ExpressionModel | None = None)
Bases: object
Compute differential expression between two conditions.
This class analyzes the differences in gene expression between two conditions (e.g., control versus treatment) using imputation, Mahalanobis distance, and log fold change analysis.
- function_predictor1
Function predictor for condition 1.
- Type:
Callable
- function_predictor2
Function predictor for condition 2.
- Type:
Callable
- variance_predictor1
Variance predictor for condition 1. If provided, will be used for uncertainty calculation.
- Type:
Callable, optional
- variance_predictor2
Variance predictor for condition 2. If provided, will be used for uncertainty calculation.
- Type:
Callable, optional
- mahalanobis_distances
Mahalanobis distances for each gene.
- Type:
np.ndarray
- compute_fdr(null_mahalanobis, threshold=0.05, gene_names=None)
Compute FDR for the last predict() using external null distances.
- Parameters:
null_mahalanobis (np.ndarray) – Null Mahalanobis distances.
threshold (float) – FDR threshold for the is_de column.
gene_names (list of str, optional) – Gene names for the DataFrame index.
- Returns:
DataFrame with columns: mahalanobis, pvalue, local_fdr, tail_fdr, is_de.
- Return type:
pd.DataFrame
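The documentation does not state how the FDR columns are derived from the null distances. As a rough illustration of the general idea, the sketch below computes empirical p-values against a null sample and applies a Benjamini-Hochberg adjustment. This is a generic tail-area procedure chosen for illustration, not necessarily what compute_fdr() does, and `empirical_fdr` is a hypothetical helper.

```python
import numpy as np

def empirical_fdr(observed, null, threshold=0.05):
    """Empirical p-values against a null sample, then Benjamini-Hochberg."""
    observed = np.asarray(observed, dtype=float)
    null_sorted = np.sort(np.asarray(null, dtype=float))
    # Count null distances >= each observed distance; the +1 pseudo-count
    # keeps every p-value strictly positive.
    n_ge = null_sorted.size - np.searchsorted(null_sorted, observed, side="left")
    pvalue = (n_ge + 1.0) / (null_sorted.size + 1.0)
    # Benjamini-Hochberg step-up adjustment.
    order = np.argsort(pvalue)
    ranked = pvalue[order] * pvalue.size / np.arange(1, pvalue.size + 1)
    adjusted = np.minimum.accumulate(ranked[::-1])[::-1]
    fdr = np.empty_like(adjusted)
    fdr[order] = np.minimum(adjusted, 1.0)
    return pvalue, fdr, fdr <= threshold

null_distances = np.linspace(0.0, 1.0, 99)   # stand-in for null Mahalanobis distances
observed = np.array([5.0, 0.505, 0.105])     # stand-in for per-gene distances
pvalue, fdr, is_de = empirical_fdr(observed, null_distances)
```

Here only the first gene, whose distance exceeds every null value, falls below the 0.05 threshold.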
- compute_mahalanobis_distances(X: ndarray, fold_change=None, use_landmarks: bool = True, landmarks_override: ndarray | None = None, progress: bool = True) → ndarray
Compute Mahalanobis distances for each gene using efficient matrix preparation and batching.
- Parameters:
X (np.ndarray) – Cell states. Shape (n_cells, n_features).
fold_change (np.ndarray, optional) – Pre-computed fold change matrix. If None, will compute it. Shape (n_cells, n_genes).
use_landmarks (bool, optional) – Whether to use landmarks for covariance calculation if available, by default True.
landmarks_override (np.ndarray, optional) – Explicitly provided landmarks to use instead of automatically detected ones, by default None.
progress (bool, optional) – Whether to show tqdm.auto progress bars during Mahalanobis distance computation. When True, displays progress bars for gene-wise operations. When False, progress bars are disabled. Default is True.
- Returns:
Array of Mahalanobis distances for each gene.
- Return type:
np.ndarray
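For readers unfamiliar with the metric, a per-gene Mahalanobis distance can be sketched as follows. The sketch assumes a single covariance matrix over cells shared by all genes and a small diagonal regularizer; how kompot builds the covariance (landmarks, batching, any normalization) is not shown here, and `mahalanobis_per_gene` is a hypothetical name.

```python
import numpy as np

def mahalanobis_per_gene(fold_change, covariance, eps=1e-8):
    """Return sqrt(f^T C^{-1} f) for every gene column f.

    fold_change: (n_cells, n_genes); covariance: (n_cells, n_cells).
    """
    fold_change = np.asarray(fold_change, dtype=float)
    n_cells = fold_change.shape[0]
    # Small diagonal term keeps the linear solve numerically stable.
    cov = np.asarray(covariance, dtype=float) + eps * np.eye(n_cells)
    solved = np.linalg.solve(cov, fold_change)  # C^{-1} f for all genes at once
    return np.sqrt(np.sum(fold_change * solved, axis=0))

fold_change = np.array([[2.0, 0.0],
                        [0.0, 4.0]])
covariance = 4.0 * np.eye(2)
distances = mahalanobis_per_gene(fold_change, covariance)
```

Solving the linear system once for all genes avoids forming an explicit inverse.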
- property empirical_variance_predictor1
Empirical (obs) variance predictor for condition 1.
- property empirical_variance_predictor2
Empirical (obs) variance predictor for condition 2.
- property expression_estimator_condition1
Mellon FunctionEstimator for condition 1.
- property expression_estimator_condition2
Mellon FunctionEstimator for condition 2.
- fit(X_condition1: ndarray, y_condition1: ndarray, X_condition2: ndarray, y_condition2: ndarray, sigma: float = 1.0, ls: float | None = None, ls_factor: float = 10.0, landmarks: ndarray | None = None, sample_estimator_ls: float | None = None, condition1_sample_indices: ndarray | None = None, condition2_sample_indices: ndarray | None = None, allow_single_condition_variance: bool = False, **function_kwargs)
Fit function estimators for both conditions.
This method only creates the estimators and does not compute fold changes. Call predict() to compute fold changes on any set of points.
- Parameters:
X_condition1 (np.ndarray) – Cell states for the first condition. Shape (n_cells1, n_features).
y_condition1 (np.ndarray) – Gene expression values for the first condition. Shape (n_cells1, n_genes).
X_condition2 (np.ndarray) – Cell states for the second condition. Shape (n_cells2, n_features).
y_condition2 (np.ndarray) – Gene expression values for the second condition. Shape (n_cells2, n_genes).
sigma (float, optional) – Noise level for function estimator, by default 1.0.
ls (float, optional) – Length scale for the GP kernel. If None, it will be estimated, by default None.
ls_factor (float, optional) – Multiplication factor to apply to length scale when it’s automatically inferred, by default 10.0. Only used when ls is None.
landmarks (np.ndarray, optional) – Pre-computed landmarks to use. If provided, n_landmarks will be ignored. Shape (n_landmarks, n_features).
sample_estimator_ls (float, optional) – Length scale for the sample-specific variance estimators. If None, will use the same value as ls or it will be estimated, by default None.
condition1_sample_indices (np.ndarray, optional) – Sample indices for first condition. Used for sample variance estimation. Unique values in this array define different sample groups.
condition2_sample_indices (np.ndarray, optional) – Sample indices for second condition. Used for sample variance estimation. Unique values in this array define different sample groups.
**function_kwargs (dict) – Additional arguments to pass to the FunctionEstimator.
- Returns:
The fitted instance.
- Return type:
self
- property function_predictor1
Function predictor for condition 1.
- property function_predictor2
Function predictor for condition 2.
- predict(X_new: ndarray, compute_mahalanobis: bool = False, progress: bool = True, use_landmarks: bool = True, landmarks_override: ndarray | None = None) → Dict[str, ndarray]
Predict gene expression and differential metrics for new points.
This method computes fold changes and related metrics for the provided points. It uses internal batching for efficient computation with large datasets.
- Parameters:
X_new (np.ndarray) – New cell states. Shape (n_cells, n_features).
compute_mahalanobis (bool, optional) – Whether to compute Mahalanobis distances. This can be computationally expensive, so it’s optional in the predict method. Default is False.
progress (bool, optional) – Whether to show tqdm.auto progress bars during computation. When True, displays progress bars for all batch processing operations including prediction, uncertainty computation, and Mahalanobis distance calculations. When False, all progress bars are disabled. Default is True.
use_landmarks (bool, optional) – Whether to use landmarks for Mahalanobis distance calculation if available, by default True. Setting to False will force computation using all provided points, which can be more accurate for small datasets or subsets.
landmarks_override (np.ndarray, optional) – Explicitly provided landmarks to use instead of the ones from the fitted model. Shape (n_landmarks, n_features). Used when custom landmarks are needed for a specific prediction, such as when analyzing a subset of data.
- Returns:
Dictionary containing the predictions:
- ‘condition1_smoothed’: Smoothed expression for condition 1
- ‘condition2_smoothed’: Smoothed expression for condition 2
- ‘condition1_std’: Posterior standard deviation for condition 1
- ‘condition2_std’: Posterior standard deviation for condition 2
- ‘fold_change’: Fold change between conditions
- ‘mean_log_fold_change’: Mean log fold change across all cells
- ‘mahalanobis_distances’: Only if compute_mahalanobis is True
- Return type:
dict
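How the fold-change entries relate to the smoothed expression can be sketched with plain NumPy. The sketch assumes that expression is on a log scale, so a difference is a log fold change, and that the orientation is condition 2 minus condition 1. Neither assumption is stated for this method in the documentation above, and the arrays are made-up values.

```python
import numpy as np

# Made-up smoothed expression, shape (n_cells, n_genes).
condition1_smoothed = np.array([[1.0, 2.0],
                                [1.5, 2.0],
                                [2.0, 2.0]])
condition2_smoothed = np.array([[2.0, 2.0],
                                [2.5, 1.0],
                                [3.0, 3.0]])

# Assumed orientation: condition 2 minus condition 1, per cell and gene.
fold_change = condition2_smoothed - condition1_smoothed
# One value per gene: the average over all cells.
mean_log_fold_change = fold_change.mean(axis=0)
```

The second gene shows why a per-gene mean alone can hide changes: its cell-level differences cancel to zero, which is the kind of case a distance over all cells is meant to catch.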
- property variance_predictor1
Sample variance predictor for condition 1.
- property variance_predictor2
Sample variance predictor for condition 2.
ExpressionModel
- class kompot.ExpressionModel(n_landmarks: int | None = None, use_empirical_variance: bool = True, eps: float = 1e-08, random_state: int | None = None, batch_size: int = 500, store_arrays_on_disk: bool | None = None, disk_storage_dir: str | None = None, function_predictor: Any | None = None, obs_variance_predictor: Any | None = None, variance_predictor: Any | None = None)
Bases: object
Single-condition GP expression model.
Encapsulates one fitted Gaussian Process for gene expression, along with optional empirical (aleatoric) variance and sample variance components. Can be used standalone for imputation or paired inside DifferentialExpression for two-condition comparisons.
- Parameters:
n_landmarks (int, optional) – Number of landmarks for Nystrom approximation.
use_empirical_variance (bool) – Whether to estimate per-gene empirical variance from GP residuals.
eps (float) – Small constant for numerical stability.
random_state (int, optional) – Random seed for landmark selection.
batch_size (int) – Batch size for prediction.
store_arrays_on_disk (bool, optional) – Whether to store large arrays on disk.
disk_storage_dir (str, optional) – Directory for disk-backed arrays.
function_predictor (callable, optional) – Pre-fitted mellon Predictor (skips fit()).
obs_variance_predictor (callable, optional) – Pre-fitted empirical (aleatoric) variance predictor. When provided, the internal empirical-variance computation is skipped during fit().
variance_predictor (callable, optional) – Pre-fitted sample variance predictor.
- property cov_func
Covariance function of the GP.
- covariance(X, diag=True, batch_size=None, progress=True)
GP posterior covariance (epistemic uncertainty).
- Parameters:
X (np.ndarray) – Evaluation points.
diag (bool) – If True return diagonal only.
batch_size (int, optional) – Override instance batch size.
progress (bool) – Show progress bar.
- fit(X: ndarray, y: ndarray, sigma: float = 1.0, ls: float | None = None, ls_factor: float = 10.0, landmarks: ndarray | None = None, sample_indices: ndarray | None = None, sample_estimator_ls: float | None = None, allow_single_condition_variance: bool = False, **function_kwargs) → ExpressionModel
Fit the expression GP on one condition.
- Parameters:
X (np.ndarray) – Cell-state coordinates, shape (n_cells, n_features).
y (np.ndarray) – Expression matrix, shape (n_cells, n_genes).
sigma (float) – Noise level for the GP.
ls (float, optional) – Length scale. If None, estimated with ls_factor.
ls_factor (float) – Multiplier applied to the automatically inferred length scale.
landmarks (np.ndarray, optional) – Pre-computed landmarks (e.g. shared across conditions).
sample_indices (np.ndarray, optional) – Per-cell sample labels for biological-replicate variance.
sample_estimator_ls (float, optional) – Length scale override for the sample-variance estimator.
allow_single_condition_variance (bool) – Passed through to SampleVarianceEstimator.
**function_kwargs – Forwarded to mellon.FunctionEstimator.
- Returns:
self, for chaining.
- Return type:
ExpressionModel
- property has_empirical_variance: bool
Whether empirical (obs) variance is available.
- property has_sample_variance: bool
Whether sample variance is available.
- property landmarks
Landmarks used by the GP, or None.
- property ls
Length scale of the GP kernel (Matern component).
- obs_variance(X, batch_size=None, progress=True)
Smoothed aleatoric noise, shape (n_points, n_genes).
Returns scalar 0 when empirical variance is disabled. When per-sample predictors are available (fitted with sample_indices), returns the mean of per-sample estimates to avoid double-counting between-sample variance.
- predict(X, batch_size=None, progress=True)
Imputed expression, shape (n_points, n_genes).
- Parameters:
X (np.ndarray) – Evaluation points.
batch_size (int, optional) – Override instance batch size.
progress (bool) – Show progress bar.
- property predictor
The underlying mellon Predictor.
- sample_variance(X, diag=True, batch_size=None, progress=True)
Biological-replicate variance.
Returns scalar 0 when no sample indices were provided.
- property sigma
Noise level used during fitting.
- std(X, batch_size=None, progress=True)
Square root of total variance plus eps.
- total_variance(X, diag=True, batch_size=None, progress=True)
Sum of GP posterior, obs_variance, and sample variance (diagonal).
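The way the three variance components combine can be sketched in a few lines. The sketch assumes that eps is added inside the square root (the phrase "square root of total variance plus eps" is ambiguous on this point) and uses made-up diagonal variances; `total_std` is a hypothetical helper, not a kompot function.

```python
import numpy as np

def total_std(gp_variance, obs_variance=0.0, sample_variance=0.0, eps=1e-8):
    """Sum the variance components and return (total, std)."""
    # Scalar 0 broadcasts cleanly, mirroring the documented behaviour of
    # obs_variance() and sample_variance() when a component is unavailable.
    total = np.asarray(gp_variance, dtype=float) + obs_variance + sample_variance
    return total, np.sqrt(total + eps)  # assumption: eps inside the root

gp_variance = np.array([[0.04, 0.09]])    # epistemic (GP posterior) part
obs_variance = np.array([[0.05, 0.0]])    # aleatoric part
total, std = total_std(gp_variance, obs_variance, sample_variance=0.0)
```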
SampleVariance
- class kompot.SampleVarianceEstimator(eps: float = 1e-08, jit_compile: bool = True, estimator_type: str = 'function', store_arrays_on_disk: bool | None = None, disk_storage_dir: str | None = None, dask_num_workers: int | None = None)
Bases: object
Compute local sample variances of gene expressions or density.
This class manages the computation of empirical variance by fitting function estimators or density estimators for each group in the data and computing the variance between their predictions. Bessel’s correction is applied to the variance calculation to ensure unbiased estimation, especially important when the number of samples is small.
- group_predictors
Dictionary of prediction functions for each group.
- Type:
Dict
- estimator_type
Type of estimator used (‘function’ for gene expression, ‘density’ for cell density).
- Type:
str
- disk_storage
Storage manager for offloading large arrays to disk, if enabled.
- Type:
DiskStorage, optional
- n_groups
Number of unique groups found during fit. Must be at least 2 for variance calculation.
- Type:
int
- fit(X: ndarray, Y: ndarray = None, grouping_vector: ndarray = None, min_cells: int = 2, ls_factor: float = 10.0, estimator_kwargs: Dict = None)
Fit estimators for each group in the data and store only their predictors.
At least 2 groups with sufficient cells (>= min_cells) are required for variance calculation. If fewer than 2 valid groups are found, a ValueError will be raised.
- Parameters:
X (np.ndarray) – Cell states. Shape (n_cells, n_features).
Y (np.ndarray, optional) – Gene expression values. Shape (n_cells, n_genes). Required for function estimator, not used for density estimator.
grouping_vector (np.ndarray) – Vector specifying which group each cell belongs to. Shape (n_cells,).
min_cells (int) – Minimum number of cells for group to train an estimator. Default is 2. Groups with fewer cells will be skipped.
ls_factor (float, optional) – Multiplication factor to apply to length scale when it’s automatically inferred, by default 10.0. Only used when ls is not explicitly provided in estimator_kwargs.
estimator_kwargs (Dict, optional) – Additional arguments to pass to the estimator constructor (FunctionEstimator or DensityEstimator).
- Returns:
The fitted instance.
- Return type:
self
- Raises:
ValueError – If fewer than 2 groups have sufficient cells to compute variance.
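The group-validity rule described above (skip groups with fewer than min_cells cells, require at least two remaining groups) can be sketched as follows. This illustrates the rule only and is not the library's implementation; `valid_groups` is a hypothetical helper.

```python
import numpy as np

def valid_groups(grouping_vector, min_cells=2):
    """Return the group labels that have at least `min_cells` cells."""
    groups, counts = np.unique(np.asarray(grouping_vector), return_counts=True)
    kept = groups[counts >= min_cells]
    if kept.size < 2:
        raise ValueError(
            f"Need at least 2 groups with >= {min_cells} cells, found {kept.size}."
        )
    return kept

# Sample 's3' has a single cell and is skipped; 's1' and 's2' remain.
grouping_vector = np.array(["s1", "s1", "s2", "s2", "s2", "s3"])
kept = valid_groups(grouping_vector)
```

Checking group sizes before fitting makes the failure explicit instead of producing an undefined variance from a single group.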
- predict(X_new: ndarray, diag: bool = False, progress: bool = True) → ndarray
Predict empirical variance for new points using JAX.
This method computes the variance with Bessel’s correction (using n-1 instead of n in the denominator) to provide an unbiased estimate of the population variance. This correction is particularly important when the number of samples (groups) is small.
- Parameters:
X_new (np.ndarray) – New cell states to predict. Shape (n_cells, n_features).
diag (bool, optional) – If True, compute only the variance for each cell state. If False (the default), compute the full covariance matrix between all pairs of cells.
progress (bool, optional) – Whether to show a progress bar during covariance computation. Default True.
- Returns:
- If diag=True:
For function estimators: Empirical variance for each new point. Shape (n_cells, n_genes). For density estimators: Empirical variance for each new point. Shape (n_cells, 1).
- If diag=False:
For function estimators: Full covariance matrix. Shape (n_cells, n_cells, n_genes). For density estimators: Full covariance matrix. Shape (n_cells, n_cells, 1).
- Return type:
np.ndarray
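The Bessel-corrected variance across groups, and the two output shapes selected by diag, can be illustrated with NumPy. The sketch assumes the per-group predictions are already stacked into one array; in the library they would come from the stored group predictors. The values are made up.

```python
import numpy as np

# Made-up per-group predictions, shape (n_groups, n_cells, n_genes).
group_predictions = np.array([
    [[1.0, 2.0], [3.0, 4.0]],
    [[3.0, 2.0], [5.0, 4.0]],
    [[2.0, 2.0], [4.0, 7.0]],
])
n_groups = group_predictions.shape[0]

# diag=True analogue: per-cell variance across groups, ddof=1 applies
# Bessel's correction (divide by n_groups - 1). Shape (n_cells, n_genes).
variance = np.var(group_predictions, axis=0, ddof=1)

# diag=False analogue: covariance between every pair of cells for each gene.
# Shape (n_cells, n_cells, n_genes).
centered = group_predictions - group_predictions.mean(axis=0)
covariance = np.einsum("gik,gjk->ijk", centered, centered) / (n_groups - 1)
```

The diagonal of the full covariance reproduces the per-cell variance, so diag=True is the cheaper choice when cross-cell terms are not needed.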