ljw20180420 commited on
Commit
09948e1
·
verified ·
1 Parent(s): 8d75840

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +40 -0
pipeline.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ from torch.distributions import Categorical
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ class CRISPRDiffuserPipeline(DiffusionPipeline):
7
+ def __init__(self, unet, scheduler):
8
+ super().__init__()
9
+
10
+ self.register_modules(unet=unet, scheduler=scheduler)
11
+ self.stationary_sampler1 = Categorical(probs=unet.stationary_sampler1_probs)
12
+ self.stationary_sampler2 = Categorical(probs=unet.stationary_sampler2_probs)
13
+
14
+ @torch.no_grad()
15
+ def __call__(self, batch, batch_size=1, record_path=False):
16
+ x1t = self.stationary_sampler1.sample(torch.Size([batch_size]))
17
+ x2t = self.stationary_sampler2.sample(torch.Size([batch_size]))
18
+ t = self.scheduler.step_to_time(torch.tensor([self.scheduler.config.num_train_timesteps]))
19
+ if record_path:
20
+ x1ts, x2ts, ts = [x1t], [x2t], [t]
21
+ for timestep in tqdm(self.scheduler.timesteps):
22
+ if timestep >= t:
23
+ continue
24
+ p_theta_0_logit = self.unet(
25
+ {
26
+ "x1t": x1t.to(self.unet.device),
27
+ "x2t": x2t.to(self.unet.device),
28
+ "t": t.to(self.unet.device)
29
+ },
30
+ batch["condition"].to(self.unet.device).expand(batch_size, -1, -1, -1)
31
+ )["p_theta_0_logit"].cpu()
32
+ # the scheduler automatically set t = timestep
33
+ x1t, x2t, t = self.scheduler.step(p_theta_0_logit, x1t, x2t, t, self.stationary_sampler1, self.stationary_sampler2)
34
+ if record_path:
35
+ x1ts.append(x1t)
36
+ x2ts.append(x2t)
37
+ ts.append(t)
38
+ if record_path:
39
+ return x1ts, x2ts, ts
40
+ return x1t, x2t