1- import anndata
21import numpy as np
32import zarr
3+ from anndata import AnnData
44
55from .constants import cell_categories_attrs
66
77
8- def write_gene_counts (path , adata : anndata . AnnData ):
8+ def write_gene_counts (path : str , adata : AnnData ) -> None :
99 counts = adata .layers ["counts" ]
1010
1111 feature_keys = list (adata .var_names ) + ["Total transcripts" ]
@@ -54,7 +54,9 @@ def write_gene_counts(path, adata: anndata.AnnData):
5454 cells_group .array ("indptr" , indptr , dtype = "uint32" , chunks = indptr .shape )
5555
5656
57- def add_group (root : zarr .Group , index : int , values : np .ndarray , categories : list [str ]):
57+ def _write_categorical_column (
58+ root : zarr .Group , index : int , values : np .ndarray , categories : list [str ]
59+ ) -> None :
5860 group = root .create_group (index )
5961 values_indices = [np .where (values == cat )[0 ] for cat in categories ]
6062 values_cum_len = np .cumsum ([len (indices ) for indices in values_indices ])
@@ -66,23 +68,23 @@ def add_group(root: zarr.Group, index: int, values: np.ndarray, categories: list
6668 group .array ("indptr" , indptr , dtype = "uint32" , chunks = (len (indptr ),))
6769
6870
69- def write_cell_categories (path : str , adata : anndata . AnnData ):
70- categorical_columns = [
71- name for name , cat in adata . obs . dtypes . items () if cat == "category"
72- ]
71+ def write_cell_categories (path : str , adata : AnnData ) -> None :
72+ cat_columns = [name for name , cat in adata . obs . dtypes . items () if cat == "category" ]
73+
74+ print ( f"Saving { len ( cat_columns ) } cell categories: { ', ' . join ( cat_columns ) } " )
7375
7476 ATTRS = cell_categories_attrs ()
75- ATTRS ["number_groupings" ] = len (categorical_columns )
77+ ATTRS ["number_groupings" ] = len (cat_columns )
7678
7779 with zarr .ZipStore (path , mode = "w" ) as store :
7880 g = zarr .group (store = store )
7981 cell_groups = g .create_group ("cell_groups" )
8082
81- for i , name in enumerate (categorical_columns ):
83+ for i , name in enumerate (cat_columns ):
8284 categories = list (adata .obs [name ].cat .categories )
8385 ATTRS ["grouping_names" ].append (name )
8486 ATTRS ["group_names" ].append (categories )
8587
86- add_group (cell_groups , i , adata .obs [name ], categories )
88+ _write_categorical_column (cell_groups , i , adata .obs [name ], categories )
8789
8890 cell_groups .attrs .put (ATTRS )
0 commit comments