Quick Start Notebook
What This Notebook Covers¶
By the end of this notebook, you will have:
- loaded and checked a prepared spatial transcriptomics dataset
- preprocessed the data into a STEER-ready representation
- run the first training stage to learn a stable cell-state representation
- derived kinetic priors and used them in the second training stage for kinetic learning
- generated velocity-related outputs and spatial visualizations
This notebook is designed to clarify the main workflow, not to expose every optional feature.
Expected Input¶
The recommended input is an AnnData object with the following structure:
- adata.layers['spliced']: required
- adata.layers['unspliced']: required
- adata.obs_names and adata.var_names: required
- adata.obsm['X_spatial']: preferred for the spatial workflow used here
If your object only has adata.obsm['spatial'], the helper below can standardize it into X_spatial.
This quick start uses the demo dataset included in the repository.
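Before swapping in your own data, it can help to check the structure listed above in one place. The following is a minimal sketch, not part of STEER: `find_input_problems` is a hypothetical helper name, and plain dictionaries stand in for `adata.layers` and `adata.obsm`.

```python
from typing import Iterable, List, Mapping


def find_input_problems(
    layers: Mapping,
    obsm: Mapping,
    required_layers: Iterable[str] = ("spliced", "unspliced"),
) -> List[str]:
    """Return human-readable problems; an empty list means the inputs look usable."""
    problems = [f"missing layer: {name}" for name in required_layers if name not in layers]
    if "X_spatial" not in obsm:
        if "spatial" in obsm:
            problems.append("only obsm['spatial'] found; standardize it into obsm['X_spatial']")
        else:
            problems.append("no spatial coordinates in obsm")
    return problems


# Plain dicts stand in for adata.layers and adata.obsm in this toy example.
toy_layers = {"spliced": [[1, 0]], "unspliced": [[0, 2]]}
toy_obsm = {"spatial": [[0.0, 1.0]]}
print(find_input_problems(toy_layers, toy_obsm))
```

With a real AnnData object, the same function can be called as `find_input_problems(adata.layers, adata.obsm)`, since both attributes support the `in` operator.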
import os
import random
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
import seaborn as sns
import torch
import torch.backends.cudnn as cudnn
import steer
from steer.prior.prior import PriorInferenceManager
warnings.filterwarnings("ignore")
%matplotlib inline
print("=== STEER Quick Start Environment ===")
print(f"PyTorch version: {torch.__version__}")
print(f"Scanpy version: {sc.__version__}")
print(f"scVelo version: {scv.__version__}")
print(f"STEER version: {getattr(steer, '__version__', 'dev build')}")
print("====================================")
=== STEER Quick Start Environment ===
PyTorch version: 2.8.0+cu128
Scanpy version: 1.10.2
scVelo version: 0.3.3
STEER version: 2.2.2
====================================
Configuration¶
This section keeps only the small set of settings that users are most likely to change.
The main training behavior is controlled by training_mode:
- "full": use STEER's built-in default training schedule without overriding epoch or early-stopping arguments
- "fast": use the lightweight quickstart settings for a faster demo run
class Config:
    # Data
    data_name = "bilinear"
    input_dir = "../../tutorials/demo_data"
    result_dir = "../../tutorials/results"
    seed = 2026
    # Core model and graph settings
    expert = 3  # Temporary default; updated after preprocessing with steer.estimate_expert_number
    smooth_neigh = 30
    spatial_neighbors = 8
    npc = 30
    # Training profile
    training_mode = "fast"  # Options: "full" or "fast"
    finetune_epochs = 5000  # Used only when training_mode == "fast"
    # Advanced settings
    graph = "union"
    corr_mode = "u"
    neighbor_metric = "cosine"
    use_us = True
    use_filter = False
    # Prior-inference settings (recommended if expert number < 5)
    fine_method = "hierarchical"
    target_size = 300
    direction_base = "fine_cluster"
    # Prior-inference settings (recommended if expert number >= 5)
    # fine_method = "none"
    # target_size = 300
    # direction_base = "expert"
cfg = Config()
INPUT_FILE = os.path.join(cfg.input_dir, f"{cfg.data_name}.h5ad")
RESULT_PATH = os.path.join(cfg.result_dir, f"{cfg.data_name}_quickstart")
os.makedirs(RESULT_PATH, exist_ok=True)
def get_stage1_training_overrides(mode: str) -> dict:
    """Keyword overrides for the first training stage; empty means STEER defaults."""
    if mode == "full":
        return {}
    if mode == "fast":
        return {
            "expert_mode": "slim",
            "pretrain_epochs": 500,
            "cluster_epochs": 200,
        }
    raise ValueError(f"Unsupported training_mode: {mode}")

def get_stage2_training_overrides(mode: str, finetune_epochs: int) -> dict:
    """Keyword overrides for the kinetic-learning stage; empty means STEER defaults."""
    if mode == "full":
        return {}
    if mode == "fast":
        return {
            "expert_mode": "slim",
            "pretrain_epochs": 500,
            "cluster_epochs": 200,
            "velo_batch_size": 2048,
            "MIN_IMPRO": 0.001,
            "PATIENCE": 200,
            "num_epochs": finetune_epochs,
        }
    raise ValueError(f"Unsupported training_mode: {mode}")
print(f"Input file: {INPUT_FILE}")
print(f"Output directory: {RESULT_PATH}")
print(f"Training mode: {cfg.training_mode}")
Input file: ../../tutorials/demo_data/bilinear.h5ad
Output directory: ../../tutorials/results/bilinear_quickstart
Training mode: fast
How To Choose Key Parameters¶
A few parameters have the largest effect on model behavior:
- training_mode: choose "full" to use STEER's default training schedule, or "fast" to use the lightweight quickstart schedule for a faster run.
- smooth_neigh: the number of neighbors used for smoothing. If the data are noisy or of poor quality, increasing this value can help stabilize the signal. In practice, 30 or 100 often works well.
- spatial_neighbors: the number of neighbors used in the spatial graph. For relatively coarse-resolution data, 8 is often enough. For high-resolution data, increasing it to around 30 can improve robustness to local noise.
- use_filter: whether to filter genes before the kinetic-learning stage. False means the model keeps the full gene set after basic QC, which is often convenient for an initial run. If enabled, STEER uses its prior-inference strategy to retain genes that are more suitable for kinetic inference.
- expert: the number of kinetic experts. In this quick start, we keep a temporary default in cfg.expert and then estimate a better value from the processed embedding immediately after Basic Preprocessing.
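To make the effect of smooth_neigh more concrete, here is a simplified sketch of neighbor averaging on synthetic data. It is an illustration under stated assumptions, not STEER's smoothing code: `knn_smooth` is a hypothetical helper, neighbors are found by Euclidean distance on 2D coordinates, and the noise level is arbitrary.

```python
import numpy as np


def knn_smooth(values: np.ndarray, coords: np.ndarray, k: int) -> np.ndarray:
    """Replace each row of `values` with the mean over its k nearest rows (self included)."""
    # Pairwise squared Euclidean distances between all cells.
    diff = coords[:, None, :] - coords[None, :, :]
    dist2 = (diff ** 2).sum(axis=-1)
    # Indices of the k closest cells for every cell; the cell itself has distance 0.
    neighbors = np.argsort(dist2, axis=1)[:, :k]
    return values[neighbors].mean(axis=1)


rng = np.random.default_rng(0)
coords = rng.uniform(size=(200, 2))
signal = coords[:, :1]  # smooth ground truth: the x coordinate
noisy = signal + rng.normal(scale=0.5, size=signal.shape)

for k in (1, 10, 50):
    err = np.abs(knn_smooth(noisy, coords, k) - signal).mean()
    print(f"k={k:>2}  mean abs error={err:.3f}")
```

Larger neighborhoods average away more noise but also blur real local structure, which is why the value is worth revisiting for each dataset.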
Reproducibility And Device Setup¶
To keep the tutorial reproducible, we fix the random seed before loading data and training the model. The notebook will automatically use a GPU if one is available.
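def setup_seed(seed: int) -> None:
    # Seed every random number generator the workflow touches.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned ones.
    cudnn.deterministic = True
    cudnn.benchmark = False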
setup_seed(cfg.seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda:0
Load The Demo Dataset¶
We begin by loading the bundled demo dataset and checking whether the required layers and spatial coordinates are present. These checks are useful when you later replace the demo file with your own data.
adata = sc.read_h5ad(INPUT_FILE)
adata = steer.prepare_spatial_adata(
adata,
obs_copy_map={},
na_clear_keys=[],
celltype_key=None,
remove_celltypes=[],
)
print(adata)
print("\nAvailable layers:", list(adata.layers.keys()))
print("Available obsm keys:", list(adata.obsm.keys()))
print(f"Number of cells: {adata.n_obs}")
print(f"Number of genes: {adata.n_vars}")
required_layers = ["spliced", "unspliced"]
missing_layers = [layer for layer in required_layers if layer not in adata.layers]
if missing_layers:
    raise ValueError(f"Missing required layers: {missing_layers}")
if "X_spatial" not in adata.obsm:
    raise ValueError("Spatial coordinates were not standardized into adata.obsm['X_spatial'].")
AnnData object with n_obs × n_vars = 3000 × 500
obs: 'true_time'
var: 'true_alpha', 'true_beta', 'true_gamma', 'true_ton', 'true_toff', 'repressive_genes'
obsm: 'X_spatial', 'true_spatial_velocity', 'velocity_spatial'
layers: 'spliced', 'true_rho', 'unspliced'
Available layers: ['spliced', 'true_rho', 'unspliced']
Available obsm keys: ['X_spatial', 'true_spatial_velocity', 'velocity_spatial']
Number of cells: 3000
Number of genes: 500
Basic Preprocessing¶
scvelo.pp.filter_and_normalize prepares the input for downstream velocity modeling. After that, steer.preprocess_anndata_spatial builds the STEER-ready dataframe, adjacency matrix, and processed AnnData object used by the graph-based model.
For real data, we usually run scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000) for QC. Here the filtering thresholds are disabled because the demo dataset is simulated.
scv.pp.filter_and_normalize(adata, min_shared_counts=None, n_top_genes=None, enforce=True)
df, adjacency_matrix, adata = steer.preprocess_anndata_spatial(
adata,
npc=cfg.npc,
NUM_AD_NEIGH=30,
spatial_neighbors=cfg.spatial_neighbors,
SMOOTH_NEIGH=cfg.smooth_neigh,
moments_adj=True,
combine_mode=cfg.graph,
neighbor_metric=cfg.neighbor_metric,
spatial_key="X_spatial",
use_us=cfg.use_us,
)
print(df.head())
print("\nAdjacency matrix shape:", adjacency_matrix.shape)
Normalized count data: X, spliced, unspliced. Logarithmized X.
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/scvelo/preprocessing/utils.py:705: DeprecationWarning: `log1p` is deprecated since scVelo v0.3.0 and will be removed in a future version. Please use `log1p` from `scanpy.pp` instead. log1p(adata)
computing moments based on connectivities
finished (0:00:00) --> added
'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
cellID gene_name unsplice splice orig_unsplice orig_splice Mu \
0 0 0 0.120764 0.151685 0.0 0.283035 0.038810
1 1 0 0.224775 0.086827 0.0 0.044122 0.072236
2 2 0 0.412140 0.225812 0.0 0.093380 0.132449
3 3 0 0.529100 0.161109 0.0 0.000000 0.170036
4 4 0 0.512205 0.116279 0.0 0.024226 0.164607
Ms
0 0.012870
1 0.007367
2 0.019159
3 0.013669
4 0.009866
Adjacency matrix shape: (3000, 3000)
Estimate The Expert Number¶
Expert-number selection should use the processed representation from Basic Preprocessing. With the default quick-start setting use_us=True, the relevant embedding is stored in adata.obsm["X_pca_combined"].
The helper below keeps the workflow in Python, runs mclust through STEER, saves the entropy elbow figure, and returns a recommended expert value.
from IPython.display import Image, display
expert_result = steer.estimate_expert_number(
adata,
candidate_range=range(2, 6),
used_obsm="X_pca_combined",
save_path=RESULT_PATH,
)
cfg.expert = expert_result["recommended_expert"]
print(f"Selected expert number: {cfg.expert}")
if expert_result["figure_path"] is not None:
print(f"Entropy elbow figure: {expert_result['figure_path']}")
display(Image(filename=expert_result["figure_path"]))
fitting ... | | 0%
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/rpy2/robjects/numpy2ri.py:241: DeprecationWarning: The global conversion available with activate() is deprecated and will be removed in the next major release. Use a local converter.
warnings.warn('The global conversion available with activate() '
|======================================================================| 100% Selected expert number: 3 Entropy elbow figure: ../../tutorials/results/bilinear_quickstart/Cluster_Number_Selection.png
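For intuition about what an elbow on an entropy curve means, here is a generic sketch that picks the sharpest bend in a decreasing curve. This is an assumption for illustration only: the rule STEER applies inside steer.estimate_expert_number is not shown in this notebook, `elbow_index` is a hypothetical helper, and the entropy values are made up.

```python
from typing import Sequence


def elbow_index(values: Sequence[float]) -> int:
    """Index of the interior point with the largest second difference (sharpest bend)."""
    if len(values) < 3:
        raise ValueError("need at least three points to locate an elbow")
    bends = [
        values[i - 1] - 2 * values[i] + values[i + 1]
        for i in range(1, len(values) - 1)
    ]
    return 1 + bends.index(max(bends))


candidates = [2, 3, 4, 5]
toy_entropy = [0.90, 0.40, 0.35, 0.32]  # made-up numbers, not STEER output
print(candidates[elbow_index(toy_entropy)])
```

Whatever rule is used, it is worth opening the saved figure and confirming that the recommended value matches what you see in the curve.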
Prepare PyTorch Geometric Inputs¶
STEER trains on a graph representation of the dataset. This step packages the processed dataframe and adjacency matrix into the PyTorch Geometric format expected by the training functions.
dataset = steer.preload_datasets_all_genes_anndata(df=df, MODEL_MODE="pretrain", adata=adata)
pyg_data = steer.create_pyg_data(dataset, adjacency_matrix, normalize=True)
print(pyg_data)
Data(x=[3000, 1000], edge_index=[2, 143781], type_features=[3000, 500], orig_features=[3000, 1000], cell_ids=[3000], adj=[3000, 3000])
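The edge_index=[2, 143781] entry in the printed object is PyTorch Geometric's edge-list format: one column per directed edge, with source indices in the first row and target indices in the second. The sketch below shows how a dense adjacency matrix maps to that layout using NumPy only. `dense_to_edge_index` is a hypothetical helper, and steer.create_pyg_data may do more than this (for example, normalization).

```python
import numpy as np


def dense_to_edge_index(adjacency: np.ndarray) -> np.ndarray:
    """Return a (2, n_edges) array of [source, target] index pairs for nonzero entries."""
    src, dst = np.nonzero(adjacency)
    return np.vstack([src, dst])


# Toy 4-cell graph with edges 0-1, 1-2, 2-3, stored symmetrically.
toy_adj = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
])
edge_index = dense_to_edge_index(toy_adj)
print(edge_index.shape)  # (2, 6): each undirected edge appears in both directions
print(edge_index)
```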
Step 1: First-Stage Training¶
The first training stage learns an initial latent representation, expert assignments, and cell-state structure. This stage provides the geometric foundation used by the later kinetic-learning stage.
- In full mode, the notebook does not override training epochs or early-stopping controls, so STEER uses its internal defaults.
- In fast mode, the notebook applies the lightweight quickstart settings for a shorter demo run.
stage1_kwargs = dict(
device=device,
device2=device,
pyg_data=pyg_data,
MODEL_MODE="pretrain",
adata=adata,
NUM_LOSS_NEIGH=30,
corr_mode=cfg.corr_mode,
max_n_cluster=cfg.expert,
path=RESULT_PATH,
)
stage1_kwargs.update(get_stage1_training_overrides(cfg.training_mode))
result_adata = steer.model_training_share_neighbor_adata(**stage1_kwargs)
Epoch 0, Loss 11.95889663696289 Epoch 50, Loss 0.21063464879989624 Epoch 100, Loss 0.12496109306812286 Epoch 150, Loss 0.09942024946212769 Epoch 200, Loss 0.09262166917324066 Epoch 250, Loss 0.08339523524045944 Epoch 300, Loss 0.07576797902584076 Epoch 350, Loss 0.07269712537527084 Epoch 400, Loss 0.07057160139083862 Epoch 450, Loss 0.06858766078948975 Epoch 500, Loss 0.9779857397079468 Epoch 550, Loss 0.27658528089523315 Epoch 600, Loss 0.22063004970550537 Epoch 650, Loss 0.2154701054096222
Optional: Initial Clustering With mclust¶
If R and rpy2 are available, STEER can use mclust for an initial clustering step.
If that environment is not available, this notebook will continue without it.
try:
    result_adata = steer.mclust_R(result_adata, num_cluster=cfg.expert)
    print("Initial clustering with mclust completed.")
except Exception as e:
    print("mclust step skipped.")
    print(f"Reason: {e}")
torch.cuda.empty_cache()
fitting ... | | 0%
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/rpy2/robjects/numpy2ri.py:241: DeprecationWarning: The global conversion available with activate() is deprecated and will be removed in the next major release. Use a local converter.
warnings.warn('The global conversion available with activate() '
|======================================================================| 100% Initial clustering with mclust completed.
Step 2: Prior Inference¶
This stage derives kinetic priors from the representation learned above. It includes defining fine clusters, optionally filtering genes, and estimating convexity-based directional information.
prior_manager = PriorInferenceManager(result_adata, df, RESULT_PATH, seed=cfg.seed)
prior_manager.task1_define_fine_clusters(method=cfg.fine_method, target_size=cfg.target_size)
prior_manager.task2_filter_genes(based_on="expert", keep_ngene=1000, use_filter=cfg.use_filter)
prior_manager.task3_calc_convexity(based_on=cfg.direction_base)
result_adata, df_updated, fine_clus_vec_np = prior_manager.finalize_for_training()
print("Prior inference completed.")
--- Task 1: Generating Fine Clusters (Method: hierarchical) --- -> Hierarchical K-Means: Target size 300 -> Generated 11 hierarchical micro-clusters. --- Task 2: Gene Filtering Skipped (Calculated on expert_cluster for reference) --- --- Task 3: Calculating Direction (Based on: FINE_CLUSTER) --- Starting parallel processing with n_jobs=-1...
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers. [Parallel(n_jobs=-1)]: Done 32 tasks | elapsed: 9.0s [Parallel(n_jobs=-1)]: Done 194 tasks | elapsed: 15.8s [Parallel(n_jobs=-1)]: Done 346 out of 500 | elapsed: 16.9s remaining: 7.5s [Parallel(n_jobs=-1)]: Done 447 out of 500 | elapsed: 17.9s remaining: 2.1s [Parallel(n_jobs=-1)]: Done 500 out of 500 | elapsed: 18.6s finished
Aggregating results... Done. --- Finalizing: Restoring Expert labels & Preparing Fine Cluster Vector --- Prior inference completed.
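The log above reports hierarchical k-means with a target size of 300. One plausible reading of that idea is to keep bisecting clusters until each is no larger than the target. The sketch below implements that reading with a small 2-means routine in NumPy; it is an assumption for illustration, not PriorInferenceManager's implementation, and `two_means`, `split_until_small`, and the random embedding are all hypothetical.

```python
import numpy as np


def two_means(points: np.ndarray, rng: np.random.Generator, n_iter: int = 20) -> np.ndarray:
    """Split points into two groups with a few Lloyd iterations; returns 0/1 labels."""
    centers = points[rng.choice(len(points), size=2, replace=False)].astype(float)
    labels = np.zeros(len(points), dtype=int)
    for _ in range(n_iter):
        dist2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dist2.argmin(axis=1)
        for c in (0, 1):
            members = points[labels == c]
            if len(members):
                centers[c] = members.mean(axis=0)
    return labels


def split_until_small(points: np.ndarray, target_size: int, seed: int = 0) -> np.ndarray:
    """Bisect clusters until each holds at most target_size points; returns cluster labels."""
    rng = np.random.default_rng(seed)
    labels = np.full(len(points), -1, dtype=int)
    pending = [np.arange(len(points))]
    next_label = 0
    while pending:
        idx = pending.pop()
        if len(idx) > target_size:
            halves = two_means(points[idx], rng)
            if 0 < halves.sum() < len(idx):  # both sides are non-empty
                pending.append(idx[halves == 0])
                pending.append(idx[halves == 1])
                continue
        # Small enough (or impossible to split further): finalize as one cluster.
        labels[idx] = next_label
        next_label += 1
    return labels


rng = np.random.default_rng(2026)
toy_embedding = rng.normal(size=(3000, 8))  # stand-in for a learned embedding
fine_labels = split_until_small(toy_embedding, target_size=300)
sizes = np.bincount(fine_labels)
print(len(sizes), "clusters; largest has", sizes.max(), "cells")
```

The number of clusters produced this way depends on the data and the seed, so it will not necessarily match the 11 micro-clusters reported by STEER.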
Step 3: Build The Kinetic-Learning Dataset¶
After prior inference, we subset to velocity genes and transfer the learned outputs needed for the next training stage. We then rerun preprocessing on the dataset used for kinetic learning.
prior_adata = result_adata.copy()
raw_adata = sc.read_h5ad(INPUT_FILE)
raw_adata = steer.prepare_spatial_adata(
raw_adata,
obs_copy_map={},
na_clear_keys=[],
celltype_key=None,
remove_celltypes=[],
)
if "X_pca" in raw_adata.obsm:
del raw_adata.obsm["X_pca"]
scv.pp.filter_and_normalize(raw_adata, min_shared_counts=None, n_top_genes=None, enforce=True)
assert all(prior_adata.obs_names == raw_adata.obs_names), "Observation names are not aligned!"
assert all(prior_adata.var_names == raw_adata.var_names), "Variable names are not aligned!"
raw_adata.layers["pred_cell_type"] = prior_adata.layers["pred_cell_type"]
raw_adata.obsm["X_pre_embed"] = prior_adata.obsm["X_pre_embed"]
raw_adata.obs["pred_cluster"] = prior_adata.obs["pred_cluster"].astype(int)
velo_adata = raw_adata[:, prior_adata.var["is_velocity_gene"]].copy()
df_fine, adjacency_matrix_fine, velo_adata = steer.preprocess_anndata_spatial(
velo_adata,
npc=cfg.npc,
NUM_AD_NEIGH=30,
spatial_neighbors=cfg.spatial_neighbors,
SMOOTH_NEIGH=cfg.smooth_neigh,
moments_adj=True,
neighbor_metric=cfg.neighbor_metric,
combine_mode=cfg.graph,
spatial_key="X_spatial",
use_us=cfg.use_us,
)
dataset_fine = steer.preload_datasets_all_genes_anndata(df=df_fine, MODEL_MODE="whole", adata=velo_adata)
pyg_data_fine = steer.create_pyg_data(dataset_fine, adjacency_matrix_fine, normalize=True)
pyg_data_fine.fine_clus_vec = torch.tensor(fine_clus_vec_np, dtype=torch.long, device=device)
print(velo_adata)
print("Velocity-gene subset shape:", velo_adata.shape)
Normalized count data: X, spliced, unspliced. Logarithmized X.
/nvme/users/liuzhy/miniconda3/envs/steer_dev/lib/python3.10/site-packages/scvelo/preprocessing/utils.py:705: DeprecationWarning: `log1p` is deprecated since scVelo v0.3.0 and will be removed in a future version. Please use `log1p` from `scanpy.pp` instead. log1p(adata)
computing moments based on connectivities
finished (0:00:00) --> added
'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
AnnData object with n_obs × n_vars = 3000 × 500
obs: 'true_time', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts', 'pred_cluster'
var: 'true_alpha', 'true_beta', 'true_gamma', 'true_ton', 'true_toff', 'repressive_genes'
uns: 'log1p', 'neighbors'
obsm: 'X_spatial', 'true_spatial_velocity', 'velocity_spatial', 'X_pre_embed', 'X_pca_combined', 'X_pca_moments'
layers: 'spliced', 'true_rho', 'unspliced', 'pred_cell_type', 'Ms', 'Mu', 'scale_Mu', 'scale_Ms'
obsp: 'distances', 'connectivities'
Velocity-gene subset shape: (3000, 500)
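Copying layers and embeddings between two objects, as done above, is only safe when cells and genes are in the same order. When an alignment assertion fails, a message that names the first difference is easier to act on. Here is a small sketch of such a check; `first_mismatch` is a hypothetical helper and the cell names are toy values.

```python
from typing import Optional, Sequence


def first_mismatch(left: Sequence[str], right: Sequence[str]) -> Optional[str]:
    """Return None if both sequences match in order, else describe the first difference."""
    if len(left) != len(right):
        return f"length differs: {len(left)} vs {len(right)}"
    for i, (a, b) in enumerate(zip(left, right)):
        if a != b:
            return f"position {i}: {a!r} vs {b!r}"
    return None


prior_names = ["cell_0", "cell_1", "cell_2"]
raw_names = ["cell_0", "cell_2", "cell_1"]
print(first_mismatch(prior_names, raw_names))
```

The same function accepts `prior_adata.obs_names` and `raw_adata.obs_names` directly, because both support `len` and iteration.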
Step 4: Kinetic Learning With Guided Priors¶
This is the main kinetic-learning stage in STEER. The model now learns kinetic behavior using the prior information estimated above, which helps disentangle kinetic regimes and improve downstream interpretability.
- In full mode, the notebook leaves the training schedule unchanged and uses STEER defaults.
- In fast mode, it uses the quickstart overrides: expert_mode="slim", shorter pretraining/clustering, velo_batch_size=2048, and more aggressive stopping (PATIENCE=200, MIN_IMPRO=0.001).
stage2_kwargs = dict(
device=device,
device2=device,
pyg_data=pyg_data_fine,
MODEL_MODE="whole",
adata=velo_adata,
NUM_LOSS_NEIGH=30,
max_n_cluster=cfg.expert,
corr_mode=cfg.corr_mode,
path=RESULT_PATH,
)
stage2_kwargs.update(get_stage2_training_overrides(cfg.training_mode, cfg.finetune_epochs))
velo_adata = steer.model_training_share_neighbor_adata(**stage2_kwargs)
print("Kinetic-learning stage completed.")
Using Fine Cluster Vector for Correlation Loss. Epoch 0, Loss 11.95889663696289 Epoch 50, Loss 0.206576868891716 Epoch 100, Loss 0.12776394188404083 Epoch 150, Loss 0.10484662652015686 Epoch 200, Loss 0.09822432696819305 Epoch 250, Loss 0.08645522594451904 Epoch 300, Loss 0.08201576769351959 Epoch 350, Loss 0.07943274825811386 Epoch 400, Loss 0.0772695243358612 Epoch 450, Loss 0.0754842758178711 Epoch 500, Loss 0.9878708124160767 Epoch 550, Loss 0.28525209426879883 Epoch 600, Loss 0.22735220193862915 Epoch 650, Loss 0.2173364758491516 Epoch 700, Loss 0.21371614933013916 Epoch 750, Loss 1.163996696472168 GATE: 0.06181614100933075, Cluster: 0.1502506136894226, Time_cor: 0.9397410154342651, Time_smooth: 0.012188863009214401 Epoch 800, Loss 1.1581414937973022 GATE: 0.0610065832734108, Cluster: 0.15012121200561523, Time_cor: 0.9349130392074585, Time_smooth: 0.01210066583007574 Epoch 850, Loss 1.1542119979858398 GATE: 0.06045851856470108, Cluster: 0.14969772100448608, Time_cor: 0.9320486783981323, Time_smooth: 0.012007107958197594 Epoch 900, Loss 1.149121642112732 GATE: 0.05955719202756882, Cluster: 0.1487891674041748, Time_cor: 0.9288840293884277, Time_smooth: 0.011891219764947891 Epoch 950, Loss 1.3314369916915894 GATE: 0.058998532593250275, Cluster: 0.14852118492126465, Velo: 0.196233332157135, Time_cor: 0.9159371852874756, Time_smooth: 0.01174671296030283 Epoch 1000, Loss 1.3090060949325562 GATE: 0.05848994106054306, Cluster: 0.1483914852142334, Velo: 0.1879415512084961, Time_cor: 0.9025628566741943, Time_smooth: 0.011620225384831429 Epoch 1050, Loss 1.2884113788604736 GATE: 0.05808946117758751, Cluster: 0.1482553482055664, Velo: 0.18237382173538208, Time_cor: 0.8881785273551941, Time_smooth: 0.011514183133840561 Epoch 1100, Loss 1.3204329013824463 GATE: 0.07244014739990234, Cluster: 0.16101497411727905, Velo: 0.18328851461410522, Time_cor: 0.8914892077445984, Time_smooth: 0.012200173921883106 Epoch 1150, Loss 1.1595971584320068 GATE: 0.04624493420124054, Cluster: 
0.15119731426239014, Velo: 0.0720226913690567, Time_cor: 0.8778508901596069, Time_smooth: 0.012281244620680809 Epoch 1200, Loss 1.130675196647644 GATE: 0.044268593192100525, Cluster: 0.14930784702301025, Velo: 0.061543483287096024, Time_cor: 0.8632649183273315, Time_smooth: 0.012290384620428085 Epoch 1250, Loss 1.123887062072754 GATE: 0.04401750862598419, Cluster: 0.1490243673324585, Velo: 0.058712393045425415, Time_cor: 0.8598681688308716, Time_smooth: 0.012264602817595005 Epoch 1300, Loss 1.1060290336608887 GATE: 0.04307917505502701, Cluster: 0.14821666479110718, Velo: 0.056314632296562195, Time_cor: 0.8462601900100708, Time_smooth: 0.012158355675637722 Epoch 1350, Loss 1.0933388471603394 GATE: 0.04260875657200813, Cluster: 0.14791011810302734, Velo: 0.05481821671128273, Time_cor: 0.8359219431877136, Time_smooth: 0.012079795822501183 Epoch 1400, Loss 1.087620735168457 GATE: 0.042391419410705566, Cluster: 0.1477823257446289, Velo: 0.05409429594874382, Time_cor: 0.8313101530075073, Time_smooth: 0.012042495422065258 Epoch 1450, Loss 1.0871096849441528 GATE: 0.042351506650447845, Cluster: 0.1477588415145874, Velo: 0.054526153951883316, Time_cor: 0.830438494682312, Time_smooth: 0.01203470304608345 Epoch 1500, Loss 1.0741822719573975 GATE: 0.041764359921216965, Cluster: 0.1474255919456482, Velo: 0.054060496389865875, Time_cor: 0.8190149068832397, Time_smooth: 0.01191682554781437 Epoch 1550, Loss 1.0607815980911255 GATE: 0.04131277650594711, Cluster: 0.1472354531288147, Velo: 0.053054239600896835, Time_cor: 0.8073596954345703, Time_smooth: 0.011819465085864067 Epoch 1600, Loss 1.0486787557601929 GATE: 0.04101086035370827, Cluster: 0.1471463441848755, Velo: 0.05103940889239311, Time_cor: 0.7977465987205505, Time_smooth: 0.011735519394278526 Epoch 1650, Loss 1.0419728755950928 GATE: 0.040785275399684906, Cluster: 0.14709651470184326, Velo: 0.05206312984228134, Time_cor: 0.7903566360473633, Time_smooth: 0.01167125441133976 Epoch 1700, Loss 1.0360301733016968 GATE: 
0.04064088314771652, Cluster: 0.1470746397972107, Velo: 0.051118191331624985, Time_cor: 0.7855709195137024, Time_smooth: 0.011625513434410095 Epoch 1750, Loss 1.0329833030700684 GATE: 0.04055962711572647, Cluster: 0.1470600962638855, Velo: 0.050875190645456314, Time_cor: 0.7828885912895203, Time_smooth: 0.011599761433899403 Epoch 1800, Loss 1.0298640727996826 GATE: 0.04052554816007614, Cluster: 0.14705431461334229, Velo: 0.04892308637499809, Time_cor: 0.7817731499671936, Time_smooth: 0.011587963439524174 Epoch 1850, Loss 1.0309568643569946 GATE: 0.04051697254180908, Cluster: 0.14705288410186768, Velo: 0.05029568821191788, Time_cor: 0.7815061211585999, Time_smooth: 0.011585204862058163 Epoch 1900, Loss 1.0224909782409668 GATE: 0.04092542082071304, Cluster: 0.14713001251220703, Velo: 0.04788734018802643, Time_cor: 0.7748563289642334, Time_smooth: 0.01169195119291544 Epoch 1950, Loss 1.0106061697006226 GATE: 0.04013804346323013, Cluster: 0.14698201417922974, Velo: 0.049345917999744415, Time_cor: 0.762607753276825, Time_smooth: 0.01153241191059351 Epoch 2000, Loss 0.9922252297401428 GATE: 0.032689232379198074, Cluster: 0.1469213366508484, Velo: 0.047895632684230804, Time_cor: 0.753287672996521, Time_smooth: 0.01143142394721508 Epoch 2050, Loss 0.9825901985168457 GATE: 0.032434385269880295, Cluster: 0.14689236879348755, Velo: 0.04713989794254303, Time_cor: 0.7448050379753113, Time_smooth: 0.011318524368107319 Epoch 2100, Loss 0.9761142134666443 GATE: 0.03220432996749878, Cluster: 0.14687460660934448, Velo: 0.048839021474123, Time_cor: 0.7369674444198608, Time_smooth: 0.011228823103010654 Epoch 2150, Loss 0.9680013060569763 GATE: 0.032173365354537964, Cluster: 0.14683574438095093, Velo: 0.047264546155929565, Time_cor: 0.730591893196106, Time_smooth: 0.01113576628267765 Epoch 2200, Loss 0.9605516195297241 GATE: 0.03188781440258026, Cluster: 0.14680343866348267, Velo: 0.04612479358911514, Time_cor: 0.7246648073196411, Time_smooth: 0.011070708744227886 Epoch 2250, Loss 
0.9556068181991577 GATE: 0.03176131471991539, Cluster: 0.14678949117660522, Velo: 0.04592098668217659, Time_cor: 0.7201306223869324, Time_smooth: 0.011004367843270302 Epoch 2300, Loss 0.9515222311019897 GATE: 0.03166891634464264, Cluster: 0.14677715301513672, Velo: 0.04585229605436325, Time_cor: 0.7162793874740601, Time_smooth: 0.010944456793367863 Epoch 2350, Loss 0.9493870735168457 GATE: 0.03159583732485771, Cluster: 0.14676666259765625, Velo: 0.046799372881650925, Time_cor: 0.7133249640464783, Time_smooth: 0.01090025994926691 Epoch 2400, Loss 0.9455398321151733 GATE: 0.03154213726520538, Cluster: 0.14676082134246826, Velo: 0.045224402099847794, Time_cor: 0.7111465930938721, Time_smooth: 0.010865837335586548 Epoch 2450, Loss 0.9455854892730713 GATE: 0.03150124102830887, Cluster: 0.14675641059875488, Velo: 0.04685981571674347, Time_cor: 0.7096279263496399, Time_smooth: 0.010840123519301414 Epoch 2500, Loss 0.9444072842597961 GATE: 0.031474873423576355, Cluster: 0.14675360918045044, Velo: 0.04670054838061333, Time_cor: 0.7086542248725891, Time_smooth: 0.010824046097695827 Epoch 2550, Loss 0.9425204396247864 GATE: 0.031460732221603394, Cluster: 0.14675217866897583, Velo: 0.04538949579000473, Time_cor: 0.7081037163734436, Time_smooth: 0.010814286768436432 Epoch 2600, Loss 0.9410333633422852 GATE: 0.031454261392354965, Cluster: 0.14675045013427734, Velo: 0.04417264088988304, Time_cor: 0.7078462243080139, Time_smooth: 0.01080978661775589 Epoch 2650, Loss 0.9421994686126709 GATE: 0.031451717019081116, Cluster: 0.14675027132034302, Velo: 0.04544931277632713, Time_cor: 0.7077405452728271, Time_smooth: 0.010807638987898827 Epoch 2700, Loss 0.9440739750862122 GATE: 0.032667774707078934, Cluster: 0.14711689949035645, Velo: 0.04457830265164375, Time_cor: 0.7083759903907776, Time_smooth: 0.011334993876516819 Epoch 2750, Loss 0.9331907629966736 GATE: 0.03168310225009918, Cluster: 0.1469295620918274, Velo: 0.044948186725378036, Time_cor: 0.6985397934913635, Time_smooth: 
0.011090107262134552 Epoch 2800, Loss 0.9233309030532837 GATE: 0.031434185802936554, Cluster: 0.1469402313232422, Velo: 0.044561173766851425, Time_cor: 0.6894679665565491, Time_smooth: 0.010927312076091766 Epoch 2850, Loss 0.9147927761077881 GATE: 0.031200572848320007, Cluster: 0.14680826663970947, Velo: 0.0428389273583889, Time_cor: 0.6831514239311218, Time_smooth: 0.010793594643473625 Epoch 2900, Loss 0.9089385867118835 GATE: 0.03127501159906387, Cluster: 0.14677554368972778, Velo: 0.0428733304142952, Time_cor: 0.6773347854614258, Time_smooth: 0.010679889470338821 Epoch 2950, Loss 0.9042980670928955 GATE: 0.03127112612128258, Cluster: 0.1467674970626831, Velo: 0.04437128081917763, Time_cor: 0.6712539792060852, Time_smooth: 0.010634168982505798 Epoch 3000, Loss 0.8986451029777527 GATE: 0.03108879178762436, Cluster: 0.14674019813537598, Velo: 0.04407918080687523, Time_cor: 0.6661425232887268, Time_smooth: 0.010594426654279232 Epoch 3050, Loss 0.891643762588501 GATE: 0.030886992812156677, Cluster: 0.14673590660095215, Velo: 0.04320862516760826, Time_cor: 0.6603138446807861, Time_smooth: 0.010498415678739548 Epoch 3100, Loss 0.8881244659423828 GATE: 0.030776800587773323, Cluster: 0.14671015739440918, Velo: 0.044445112347602844, Time_cor: 0.6558070778846741, Time_smooth: 0.010385311208665371 Epoch 3150, Loss 0.8831722736358643 GATE: 0.030680399388074875, Cluster: 0.14669984579086304, Velo: 0.04388052970170975, Time_cor: 0.6516215205192566, Time_smooth: 0.010289977304637432 Epoch 3200, Loss 0.8773002624511719 GATE: 0.030665114521980286, Cluster: 0.14671188592910767, Velo: 0.041703399270772934, Time_cor: 0.6480000019073486, Time_smooth: 0.010219900868833065 Epoch 3250, Loss 0.8748617172241211 GATE: 0.030620533972978592, Cluster: 0.14670217037200928, Velo: 0.042977284640073776, Time_cor: 0.6443996429443359, Time_smooth: 0.010162119753658772 Epoch 3300, Loss 0.8722267150878906 GATE: 0.030460013076663017, Cluster: 0.14666295051574707, Velo: 0.043807461857795715, Time_cor: 
0.6412233114242554, Time_smooth: 0.010072917677462101 Epoch 3350, Loss 0.8684630990028381 GATE: 0.03048897534608841, Cluster: 0.14666688442230225, Velo: 0.042851611971855164, Time_cor: 0.6384283900260925, Time_smooth: 0.010027232579886913 Epoch 3400, Loss 0.8649924397468567 GATE: 0.030342595651745796, Cluster: 0.14634853601455688, Velo: 0.042394738644361496, Time_cor: 0.6359468698501587, Time_smooth: 0.009959714487195015 Epoch 3450, Loss 0.8616535067558289 GATE: 0.03021620772778988, Cluster: 0.14604580402374268, Velo: 0.04208618029952049, Time_cor: 0.6334227323532104, Time_smooth: 0.009882569313049316 Epoch 3500, Loss 0.8599284887313843 GATE: 0.030169254168868065, Cluster: 0.1455816626548767, Velo: 0.04284399002790451, Time_cor: 0.6315131783485413, Time_smooth: 0.009820420295000076 Epoch 3550, Loss 0.8525925278663635 GATE: 0.03016761876642704, Cluster: 0.13809210062026978, Velo: 0.04466725140810013, Time_cor: 0.6298826932907104, Time_smooth: 0.009782825596630573 Epoch 3600, Loss 0.8482539653778076 GATE: 0.0300874225795269, Cluster: 0.13805395364761353, Velo: 0.04209020361304283, Time_cor: 0.628295361995697, Time_smooth: 0.009727011434733868 Epoch 3650, Loss 0.8469855189323425 GATE: 0.030046746134757996, Cluster: 0.1380438208580017, Velo: 0.04220215603709221, Time_cor: 0.6270180940628052, Time_smooth: 0.00967473816126585 Epoch 3700, Loss 0.8445344567298889 GATE: 0.030008919537067413, Cluster: 0.13803136348724365, Velo: 0.040941618382930756, Time_cor: 0.6259233355522156, Time_smooth: 0.009629195556044579 Epoch 3750, Loss 0.8458107709884644 GATE: 0.02997562289237976, Cluster: 0.1380259394645691, Velo: 0.04320782050490379, Time_cor: 0.625011146068573, Time_smooth: 0.009590266272425652 Epoch 3800, Loss 0.8449268937110901 GATE: 0.02994590811431408, Cluster: 0.13801974058151245, Velo: 0.04313353821635246, Time_cor: 0.6242707371711731, Time_smooth: 0.009556976146996021 Epoch 3850, Loss 0.8420872092247009 GATE: 0.02992311678826809, Cluster: 0.13801246881484985, Velo: 
0.040952593088150024, Time_cor: 0.6236673593521118, Time_smooth: 0.009531664662063122 Epoch 3900, Loss 0.8432311415672302 GATE: 0.029908809810876846, Cluster: 0.13800621032714844, Velo: 0.042607299983501434, Time_cor: 0.6231988668441772, Time_smooth: 0.00950995646417141 Epoch 3950, Loss 0.8413946032524109 GATE: 0.029894351959228516, Cluster: 0.13800263404846191, Velo: 0.04117833077907562, Time_cor: 0.6228256225585938, Time_smooth: 0.009493631310760975 Epoch 4000, Loss 0.8417630791664124 GATE: 0.029884520918130875, Cluster: 0.13800007104873657, Velo: 0.041848570108413696, Time_cor: 0.622549295425415, Time_smooth: 0.009480660781264305 Epoch 4050, Loss 0.8413592576980591 GATE: 0.029877904802560806, Cluster: 0.1379990577697754, Velo: 0.04165920242667198, Time_cor: 0.6223522424697876, Time_smooth: 0.009470801800489426 Epoch 4100, Loss 0.8397413492202759 GATE: 0.029872644692659378, Cluster: 0.13799744844436646, Velo: 0.04018951207399368, Time_cor: 0.6222174763679504, Time_smooth: 0.00946428719907999 Epoch 4150, Loss 0.8404432535171509 GATE: 0.02986801788210869, Cluster: 0.1379980444908142, Velo: 0.04098602756857872, Time_cor: 0.6221311688423157, Time_smooth: 0.009459976106882095 Epoch 4200, Loss 0.8416957259178162 GATE: 0.029865887016057968, Cluster: 0.13799750804901123, Velo: 0.04229716211557388, Time_cor: 0.6220777034759521, Time_smooth: 0.009457466192543507 Early stopping triggered at epoch 4217. Best loss was 0.8400254249572754 Loading best model state from memory... Kinetic-learning stage completed.
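The run above ended through early stopping rather than by reaching num_epochs. The sketch below shows a common form of patience-based early stopping, assuming that PATIENCE counts consecutive epochs without sufficient gain and that MIN_IMPRO is the minimum loss decrease that counts as a gain. STEER's exact rule may differ, and `run_with_early_stopping` and the toy losses are hypothetical.

```python
from typing import Iterable, Tuple


def run_with_early_stopping(
    losses: Iterable[float], patience: int, min_improvement: float
) -> Tuple[int, float]:
    """Return (last_epoch_seen, best_loss) for a stream of per-epoch losses."""
    best_loss = float("inf")
    epochs_without_gain = 0
    epoch = -1
    for epoch, loss in enumerate(losses):
        if best_loss - loss > min_improvement:
            # Large enough improvement: record it and reset the counter.
            best_loss = loss
            epochs_without_gain = 0
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                break
    return epoch, best_loss


toy_losses = [1.0, 0.8, 0.7, 0.6995, 0.6993, 0.6992, 0.6991]
print(run_with_early_stopping(toy_losses, patience=3, min_improvement=0.001))
```

Under this rule, a smaller patience or a larger minimum improvement ends training sooner, which is the trade-off the fast profile makes.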
Step 5: Spatial Velocity Visualization¶
The kinetic-learning stage already returns a cleaned output object with normalized velocity layers. We first construct the velocity graph in the refined latent space and visualize the velocity field in spatial coordinates.
result_adata = velo_adata.copy()
sc.pp.neighbors(result_adata, use_rep="X_refine_embed", n_neighbors=100)
steer.velocity_graph(result_adata, vkey="pred_vs_norm", xkey="model_Ms")
print("obs columns:")
print(sorted(result_adata.obs.columns.tolist()))
print("\nobsm keys:")
print(sorted(result_adata.obsm.keys()))
print("\nlayers:")
print(sorted(result_adata.layers.keys()))
computing velocity graph (using 1/128 cores)
0%| | 0/3000 [00:00<?, ?cells/s]
finished (0:00:05) --> added
'pred_vs_norm_graph', sparse matrix with cosine correlations (adata.uns)
obs columns:
['Expert', 'Expert Weight', 'Pred Time', 'initial_size', 'initial_size_spliced', 'initial_size_unspliced', 'n_counts', 'pred_vs_norm_self_transition', 'pretrain_cluster', 'true_time']
obsm keys:
['X_alpha', 'X_beta', 'X_gamma', 'X_para', 'X_para_t', 'X_pca_combined', 'X_pca_moments', 'X_pre_embed', 'X_refine_embed', 'X_refine_embed_t', 'X_spatial', 'cluster_matrix', 'true_spatial_velocity', 'velocity_spatial']
layers:
['final_recon_s', 'final_recon_u', 'init_regulate_state', 'model_Ms', 'model_Mu', 'orig_s', 'orig_u', 'pred_time_layer', 'pred_vs', 'pred_vs_norm', 'pred_vu', 'pred_vu_norm', 'recon_alpha', 'recon_alpha_norm', 'recon_beta', 'recon_gamma', 'recon_gamma_norm', 'regulate_state', 'scale_Ms', 'scale_Mu', 'true_rho']
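The log describes the velocity graph as a sparse matrix of cosine correlations. The underlying idea is to score each neighbor by how well the displacement toward it agrees with the cell's velocity vector. The sketch below computes plain cosine similarities on a toy example; it is a simplification that omits details such as centering and sparsification, and `cosine_transition_scores` is a hypothetical helper rather than the function called above.

```python
import numpy as np


def cosine_transition_scores(x: np.ndarray, v: np.ndarray, neighbors: np.ndarray) -> np.ndarray:
    """scores[i, m] = cosine between v[i] and the displacement from cell i to its m-th neighbor."""
    displacement = x[neighbors] - x[:, None, :]  # (n_cells, k, n_genes)
    dots = np.einsum("ikg,ig->ik", displacement, v)
    norms = np.linalg.norm(displacement, axis=-1) * np.linalg.norm(v, axis=-1)[:, None]
    return dots / np.maximum(norms, 1e-12)  # guard against zero-length vectors


# Three cells on a line in a 2-gene space; every velocity points in the +x direction.
x = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
v = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
neighbors = np.array([[1, 2], [0, 2], [0, 1]])  # neighbor indices per cell
print(cosine_transition_scores(x, v, neighbors))
```

Scores near 1 mark neighbors that lie ahead of a cell along its velocity, and scores near -1 mark neighbors that lie behind it.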
scv.settings.figdir = RESULT_PATH
scv.set_figure_params(style="scvelo", dpi=150, figsize=(5, 4), transparent=True)
if "X_spatial" not in result_adata.obsm and "spatial" in result_adata.obsm:
    result_adata.obsm["X_spatial"] = result_adata.obsm["spatial"].astype(float)
elif "X_spatial" in result_adata.obsm:
    result_adata.obsm["X_spatial"] = result_adata.obsm["X_spatial"].astype(float)
scv.pl.velocity_embedding_stream(
result_adata,
basis="spatial",
vkey="pred_vs_norm",
color=["Expert", "Pred Time"],
title=["Expert Assignment", "Latent Time"],
show=True,
)
computing velocity embedding
finished (0:00:00) --> added
'pred_vs_norm_spatial', embedded velocity vectors (adata.obsm)
Optional: PCA/UMAP Of The Learned Latent Embedding¶
X_refine_embed is the learned latent representation returned by the kinetic-learning stage. If you want a 2D view of that latent space, compute it explicitly here.
Set USE_PCA_BEFORE_UMAP = True if you want to apply PCA before building the UMAP graph.
USE_PCA_BEFORE_UMAP = True
result_adata = steer.compute_embedding_umap(
result_adata,
embedding_key="X_refine_embed",
output_key="X_umap_refine_embed",
pca_output_key="X_pca_refine_embed",
neighbors_key="refine_embed_neighbors",
n_neighbors=100,
use_pca=USE_PCA_BEFORE_UMAP,
n_pcs=cfg.npc,
random_state=cfg.seed,
)
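As a reference for what a PCA projection of a latent embedding involves, here is a minimal NumPy sketch based on the singular value decomposition. It is not the code inside steer.compute_embedding_umap, whose internals are not shown here; `pca_project` and the random latent matrix are hypothetical.

```python
import numpy as np


def pca_project(embedding: np.ndarray, n_components: int) -> np.ndarray:
    """Project rows of `embedding` onto the top principal axes."""
    centered = embedding - embedding.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
toy_latent = rng.normal(size=(500, 16)) * np.linspace(3.0, 0.5, 16)  # unequal variances
coords_2d = pca_project(toy_latent, n_components=2)
print(coords_2d.shape)
print(coords_2d.var(axis=0))
```

The first component carries the most variance, so a 2D plot of the first two components gives a linear view of the dominant structure before any nonlinear UMAP step.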
Optional: Latent-Space Visualization¶
After computing X_pca_refine_embed/X_umap_refine_embed, you can inspect the learned latent structure in 2D and compare it with the spatial view above.
scv.pl.velocity_embedding_stream(
result_adata,
basis="X_pca_refine_embed",
vkey="pred_vs_norm",
color=["Expert", "Pred Time"],
title=["Expert Assignment", "Latent Time"],
show=True,
)
computing velocity embedding
finished (0:00:00) --> added
'pred_vs_norm_pca_refine_embed', embedded velocity vectors (adata.obsm)
Common Notes¶
- training_mode="full" uses STEER's default training schedule: the notebook does not override epoch or early-stopping arguments.
- training_mode="fast" applies the lightweight quickstart schedule and is better suited for demos or quick checks.
- If mclust is unavailable, the notebook can still run, but the optional clustering step will be skipped.
- If you encounter installation issues, check the package versions for PyTorch, PyG, torch-scatter, and torch-sparse first.
- If you want to use your own data, the easiest path is to replace the demo .h5ad with a file that already contains spliced, unspliced, and spatial coordinates.