ljw20180420 commited on
Commit
98e362a
·
verified ·
1 Parent(s): 54938e6

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +30 -0
pipeline.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import DiffusionPipeline
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ class FOREcasTPipeline(DiffusionPipeline):
7
+ def __init__(self, FOREcasT_model, MAX_DEL_SIZE):
8
+ super().__init__()
9
+
10
+ self.register_modules(FOREcasT_model=FOREcasT_model)
11
+ self.MAX_DEL_SIZE = MAX_DEL_SIZE
12
+ self.lefts = np.concatenate([
13
+ np.arange(-DEL_SIZE, 1)
14
+ for DEL_SIZE in range(self.MAX_DEL_SIZE, -1, -1)
15
+ ] + [np.zeros(20, np.int64)])
16
+ self.rights = np.concatenate([
17
+ np.arange(0, DEL_SIZE + 1)
18
+ for DEL_SIZE in range(self.MAX_DEL_SIZE, -1, -1)
19
+ ] + [np.zeros(20, np.int64)])
20
+ self.inss = (self.MAX_DEL_SIZE + 2) * (self.MAX_DEL_SIZE + 1) // 2 * [""] + ["A", "C", "G", "T", "AA", "AC", "AG", "AT", "CA", "CC", "CG", "CT", "GA", "GC", "GG", "GT", "TA", "TC", "TG", "TT"]
21
+
22
+ @torch.no_grad()
23
+ def __call__(self, batch):
24
+ assert batch["feature"].shape[1] == len(self.lefts), "the possible mutation number of the input feature does not fit the pipeline"
25
+ return {
26
+ "proba": F.softmax(self.FOREcasT_model(batch["feature"].to(self.FOREcasT_model.device))["logit"], dim=-1),
27
+ "left": self.lefts,
28
+ "right": self.rights,
29
+ "ins_seq": self.inss
30
+ }