Model API
These notes cover the stable model-facing objects most relevant to users extending the package in Python.
GMMModuleVAE
Defined in src/bsvae/models/gmvae.py.
GMMModuleVAE(
    n_features: int,
    n_latent: int,
    n_modules: int,
    hidden_dims: list[int] | None = None,
    dropout: float = 0.1,
    use_batch_norm: bool = True,
    sigma_min: float = 0.3,
    normalize_input: bool = False,
)
The model is trained on feature profiles: each training example is one feature's values across samples. As a result, n_features here is the profile length, which equals the number of samples in the expression matrix.
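Because the meaning of n_features is easy to get backwards, here is a minimal sketch of deriving the constructor arguments from a toy matrix. It assumes the matrix is laid out with one row per feature and one column per sample; the values, n_latent=2, and n_modules=2 are made up for illustration, and the import path in the trailing comment is inferred from the file location rather than confirmed.

```python
# Hypothetical toy expression matrix (assumed layout: rows = features, columns = samples).
expression = [
    [0.1, 0.4, 0.3, 0.9],
    [1.2, 0.8, 1.1, 0.7],
    [0.0, 0.2, 0.1, 0.3],
]

n_profiles = len(expression)     # number of training examples, one per feature
n_samples = len(expression[0])   # length of each profile

# Keyword arguments mirroring the documented signature and defaults.
model_kwargs = dict(
    n_features=n_samples,   # profile length, NOT the number of rows
    n_latent=2,             # illustrative choice
    n_modules=2,            # illustrative choice
    hidden_dims=None,
    dropout=0.1,
    use_batch_norm=True,
    sigma_min=0.3,
    normalize_input=False,
)

# With the package installed, construction would presumably look like this
# (import path assumed from src/bsvae/models/gmvae.py):
#   from bsvae.models.gmvae import GMMModuleVAE
#   model = GMMModuleVAE(**model_kwargs)
```

In this sketch the model would see three training examples of length four, so n_features is 4.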
Common Methods
encode(x) -> (mu, logvar)
forward(x) -> recon_x, mu, logvar, z, gamma
get_gamma(x) -> gamma
get_hard_assignments(x) -> argmax(gamma)
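The relationship between get_gamma and get_hard_assignments can be illustrated without the package. The sketch below assumes gamma holds one row of soft module weights per input profile and that the argmax is taken over the module axis; hard_assignments is a hypothetical stand-in written in plain Python, not the library's implementation.

```python
def hard_assignments(gamma):
    """Return, for each row, the index of the module with the largest weight."""
    return [max(range(len(row)), key=row.__getitem__) for row in gamma]


# Made-up soft assignments for 3 profiles over 3 modules.
gamma = [
    [0.70, 0.20, 0.10],
    [0.10, 0.30, 0.60],
    [0.25, 0.50, 0.25],
]

assignments = hard_assignments(gamma)
print(assignments)  # [0, 2, 1]
```

Each profile is assigned to exactly one module, so the result has one integer per row of gamma.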
Tensor Shapes
x: (batch, n_samples)
mu, logvar, z: (batch, n_latent)
gamma: (batch, n_modules)
Losses
src/bsvae/models/losses.py exposes the main training losses:
GMMVAELoss
WarmupLoss
GMMVAELoss combines reconstruction with GMM-aware KL and optional auxiliary losses such as separation, balance, hierarchical consistency, and correlation-preservation terms.
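As a rough picture of how a composite objective of this kind is commonly assembled, the sketch below sums a reconstruction term, a weighted KL term, and any enabled auxiliary terms. The function combine_losses, its parameter names, and all weights are hypothetical; they are not taken from losses.py, and the actual GMMVAELoss may weight or schedule its terms differently.

```python
def combine_losses(recon, kl, aux=None, kl_weight=1.0, aux_weights=None):
    """Weighted sum of loss terms; auxiliary terms without a weight are skipped."""
    aux = aux or {}
    aux_weights = aux_weights or {}
    total = recon + kl_weight * kl
    for name, value in aux.items():
        # A missing weight defaults to 0.0, which disables that optional term.
        total += aux_weights.get(name, 0.0) * value
    return total


# Made-up scalar values standing in for per-batch loss terms.
total = combine_losses(
    recon=2.0,
    kl=0.5,
    aux={"separation": 0.2, "balance": 0.1, "correlation": 0.4},
    aux_weights={"separation": 0.5, "balance": 1.0},  # "correlation" left disabled
)
```

Here the correlation term contributes nothing because no weight is supplied for it, which mirrors the idea of auxiliary losses being optional.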