File size: 2,715 Bytes
3910ab9
a3b02b5
 
3910ab9
a3b02b5
70cd2c0
a3b02b5
70cd2c0
691a03c
a3b02b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691a03c
dc0bd2a
 
a3b02b5
 
 
 
dc0bd2a
70cd2c0
 
a3b02b5
 
 
 
 
 
 
 
 
 
 
 
 
 
3910ab9
a3b02b5
 
691a03c
a3b02b5
 
691a03c
a3b02b5
3910ab9
a3b02b5
 
 
 
 
 
 
3910ab9
 
 
70cd2c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

# iSEEEK
A universal approach for integrating super large-scale single-cell transcriptomes by exploring gene rankings

## A simple pipeline for single-cell analysis
```python
import torch
import gzip
import re
from tqdm import tqdm
import numpy as np
import scanpy as sc
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast, BertForMaskedLM 

class LineDataset(Dataset):
    """Map-style dataset over a sequence of text lines.

    Every '-' and '.' in a line is replaced by '_' at the moment the line
    is fetched; the stored ``lines`` sequence itself is never modified.
    """

    def __init__(self, lines):
        # Compiled once so each item lookup only pays for the substitution.
        self.regex = re.compile(r'\-|\.')
        self.lines = lines

    def __len__(self):
        """Return the number of stored lines."""
        return len(self.lines)

    def __getitem__(self, i):
        """Return line ``i`` with every '-' and '.' replaced by '_'."""
        raw_line = self.lines[i]
        return self.regex.sub('_', raw_line)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Cap PyTorch's intra-op CPU thread pool at two threads.
torch.set_num_threads(2)

# Tokenizer and weights both come from the same Hugging Face repository.
tokenizer = PreTrainedTokenizerFast.from_pretrained("TJMUCH/transcriptome-iseeek")

# Only the `.bert` encoder of the masked-LM checkpoint is kept; it is moved
# to the chosen device and switched to evaluation mode.
model = BertForMaskedLM.from_pretrained("TJMUCH/transcriptome-iseeek").bert.to(device)
model.eval()


## Data deposited in https://huggingface.co/TJMUCH/transcriptome-iseeek/tree/main
# Each archive is read inside a `with` block so its gzip handle is closed as
# soon as the lines are in memory; iterating a bare `gzip.open(...)` inside a
# comprehension leaves the handle open until garbage collection.
# Lines are read as bytes, stripped of surrounding whitespace, then decoded
# with the default (UTF-8) codec.
with gzip.open("pbmc_ranking.txt.gz") as ranking_file:
    lines = [s.strip().decode() for s in ranking_file]
with gzip.open("pbmc_label.txt.gz") as label_file:
    labels = [s.strip().decode() for s in label_file]

# NumPy array of label strings, one per line of the label file.
labels = np.asarray(labels)


ds = LineDataset(lines)
dl = DataLoader(ds, batch_size=80)

# One embedding vector per input line, accumulated batch by batch.
features = []

for text_batch in tqdm(dl, total=len(dl)):
    # Tokenize the batch: truncate to at most 128 tokens and pad so the
    # batch forms a rectangular tensor.
    encoded = tokenizer(
        text_batch,
        max_length=128,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    # The model's inputs must live on the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        output = model(**encoded)
        # Hidden state at sequence position 0 of every item in the batch
        # (presumably the [CLS] token; confirm against the tokenizer).
        first_token_states = output.last_hidden_state[:, 0, :]

    features += first_token_states.tolist()

# Nested Python lists -> 2-D array with one row per input line.
features = np.stack(features)

# Wrap the embedding matrix in an AnnData object and attach the labels as a
# categorical observation column.
adata = sc.AnnData(features)
adata.obs['celltype'] = labels
adata.obs['celltype'] = adata.obs['celltype'].astype("category")

# Build the neighbour graph directly on the embeddings (use_rep='X'), then
# compute the UMAP layout and the Leiden clustering on that graph.
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)
sc.tl.leiden(adata)

# Plot the UMAP once per colouring key; `save` asks scanpy to also write the
# figure to disk.
sc.pl.umap(
    adata,
    color=['celltype', 'leiden'],
    save="UMAP",
)

```

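## Extract token representations
This snippet continues from the pipeline above and reuses its `lines`, `tokenizer`, `model`, `device` and `dl` objects.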
```python

cell_counts = len(lines)

# One row per input line and one column per tokenizer vocabulary entry.
# Entry (r, t) holds the L2 norm of the final hidden state of token id t in
# line r; columns of tokens absent from the line stay 0.
x = np.zeros((cell_counts, len(tokenizer)), dtype=np.float16)

# Row of `x` that the next line is written to. It must be initialised before
# the loop, otherwise the first assignment below raises NameError.
counter = 0

for a in tqdm(dl, total=len(dl)):
    batch = tokenizer(a, max_length=128, truncation=True,
               padding=True, return_tensors="pt")

    for k, v in batch.items():
        batch[k] = v.to(device)

    with torch.no_grad():
        out = model(**batch)

    # Index of the last attended position of every sequence.
    # NOTE(review): this assumes right-side padding and that the first and
    # last attended tokens are special tokens to be skipped; confirm against
    # the tokenizer configuration.
    eos_idxs = (batch.attention_mask.sum(dim=1) - 1).tolist()

    # Norm of every token's hidden vector, computed for the whole batch in
    # one call instead of one `.norm().item()` round-trip per token.
    token_norms = out.last_hidden_state.norm(dim=-1).cpu().numpy()
    input_ids = batch.input_ids.tolist()

    for i, eos in enumerate(eos_idxs):
        # Positions 1 .. eos-1: everything between the first token and the
        # last attended token.
        idxs = input_ids[i][1:eos]
        x[counter, idxs] = token_norms[i, 1:eos]
        counter += 1
```