cocktailpeanut committed on
Commit
0cb2e17
·
1 Parent(s): 2ad01cf
Files changed (1) hide show
  1. gradio_demo/app.py +39 -35
gradio_demo/app.py CHANGED
@@ -63,41 +63,6 @@ else:
63
  device = "cpu"
64
  torch_dtype = torch.float32
65
 
66
- # Load pretrained models.
67
- print("Initializing pipeline...")
68
- pipe = InstantIRPipeline.from_pretrained(
69
- sdxl_repo_id,
70
- torch_dtype=torch_dtype,
71
- )
72
-
73
- # Image prompt projector.
74
- print("Loading LQ-Adapter...")
75
- load_adapter_to_pipe(
76
- pipe,
77
- f"{instantir_path}/adapter.pt",
78
- dinov2_repo_id,
79
- )
80
-
81
- # Prepare previewer
82
- lora_alpha = pipe.prepare_previewers(instantir_path)
83
- print(f"use lora alpha {lora_alpha}")
84
- lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
85
- print(f"use lora alpha {lora_alpha}")
86
- pipe.to(device=device, dtype=torch_dtype)
87
- pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
88
- lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
89
-
90
- # Load weights.
91
- print("Loading checkpoint...")
92
- aggregator_state_dict = torch.load(
93
- f"{instantir_path}/aggregator.pt",
94
- map_location="cpu"
95
- )
96
- pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
97
- pipe.aggregator.to(device=device, dtype=torch_dtype)
98
-
99
- print("******loaded")
100
-
101
  MAX_SEED = np.iinfo(np.int32).max
102
  MAX_IMAGE_SIZE = 1024
103
 
@@ -129,6 +94,44 @@ def show_final_preview(preview_row):
129
  def instantir_restore(
130
  lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
131
  creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if creative_restoration:
133
  if "lcm" not in pipe.unet.active_adapters():
134
  pipe.unet.set_adapter('lcm')
@@ -177,6 +180,7 @@ def instantir_restore(
177
  for i, preview_img in enumerate(out[1]):
178
  preview_img.append(f"preview_{i}")
179
 
 
180
  gc.collect()
181
  print(f"TORCH={torch}")
182
  if torch.cuda.is_available():
 
63
  device = "cpu"
64
  torch_dtype = torch.float32
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  MAX_SEED = np.iinfo(np.int32).max
67
  MAX_IMAGE_SIZE = 1024
68
 
 
94
  def instantir_restore(
95
  lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
96
  creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0, progress=gr.Progress(track_tqdm=True)):
97
+
98
+
99
+
100
+ # Load pretrained models.
101
+ print("Initializing pipeline...")
102
+ pipe = InstantIRPipeline.from_pretrained(
103
+ sdxl_repo_id,
104
+ torch_dtype=torch_dtype,
105
+ )
106
+
107
+ # Image prompt projector.
108
+ print("Loading LQ-Adapter...")
109
+ load_adapter_to_pipe(
110
+ pipe,
111
+ f"{instantir_path}/adapter.pt",
112
+ dinov2_repo_id,
113
+ )
114
+
115
+ # Prepare previewer
116
+ lora_alpha = pipe.prepare_previewers(instantir_path)
117
+ print(f"use lora alpha {lora_alpha}")
118
+ lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
119
+ print(f"use lora alpha {lora_alpha}")
120
+ pipe.to(device=device, dtype=torch_dtype)
121
+ pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
122
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
123
+
124
+ # Load weights.
125
+ print("Loading checkpoint...")
126
+ aggregator_state_dict = torch.load(
127
+ f"{instantir_path}/aggregator.pt",
128
+ map_location="cpu"
129
+ )
130
+ pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
131
+ pipe.aggregator.to(device=device, dtype=torch_dtype)
132
+
133
+ print("******loaded")
134
+
135
  if creative_restoration:
136
  if "lcm" not in pipe.unet.active_adapters():
137
  pipe.unet.set_adapter('lcm')
 
180
  for i, preview_img in enumerate(out[1]):
181
  preview_img.append(f"preview_{i}")
182
 
183
+ del pipe
184
  gc.collect()
185
  print(f"TORCH={torch}")
186
  if torch.cuda.is_available():