gokaygokay committed
Commit 9ed763c · 1 Parent(s): a3aef65
Files changed (1)
  1. app.py +28 -71
app.py CHANGED
@@ -18,7 +18,6 @@ from PIL import Image
 from trellis.pipelines import TrellisImageTo3DPipeline
 from trellis.representations import Gaussian, MeshExtractResult
 from trellis.utils import render_utils, postprocessing_utils
-from contextlib import contextmanager
 
 huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 # Constants
@@ -84,23 +83,6 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
 def get_seed(randomize_seed: bool, seed: int) -> int:
     return np.random.randint(0, MAX_SEED) if randomize_seed else seed
 
-
-# Example class-based or function-based context manager
-@contextmanager
-def pipeline_on_gpu(pipeline, device="cuda"):
-    """
-    Context manager that places the pipeline on GPU at enter,
-    then on exit puts it to CPU to free VRAM.
-    """
-    # Move pipeline from CPU to GPU (if needed)
-    pipeline.to(device)
-    try:
-        yield pipeline
-    finally:
-        # Move pipeline back to CPU and clear CUDA cache
-        pipeline.to("cpu")
-        torch.cuda.empty_cache()
-
 @spaces.GPU
 def generate_flux_image(
     prompt: str,
@@ -113,25 +95,20 @@ def generate_flux_image(
     lora_scale: float,
     progress: gr.Progress = gr.Progress(track_tqdm=True),
 ) -> Image.Image:
-    """Generate image using Flux pipeline only on GPU during generation."""
+    """Generate image using Flux pipeline"""
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator(device=device).manual_seed(seed)
 
-    # Use the context manager to keep the pipeline on GPU just while generating.
-    with pipeline_on_gpu(flux_pipeline, device=device) as gpu_pipeline:
-        result = gpu_pipeline(
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            width=width,
-            height=height,
-            generator=generator,
-            joint_attention_kwargs={"scale": lora_scale},
-        )
-
-    # Once we leave the context manager, the pipeline is moved back to CPU.
-    image = result.images[0]
+    image = flux_pipeline(
+        prompt=prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+    ).images[0]
 
     # Save the generated image
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -152,48 +129,28 @@ def image_to_3d(
     slat_sampling_steps: int,
     req: gr.Request,
 ) -> Tuple[dict, str]:
-    # Clear CUDA cache before starting
-    torch.cuda.empty_cache()
-
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
-
-    try:
-        with pipeline_on_gpu(trellis_pipeline, device=device) as gpu_pipeline:
-            outputs = gpu_pipeline.run(
-                image,
-                seed=seed,
-                formats=["gaussian", "mesh"],
-                preprocess_image=False,
-                sparse_structure_sampler_params={
-                    "steps": ss_sampling_steps,
-                    "cfg_strength": ss_guidance_strength,
-                },
-                slat_sampler_params={
-                    "steps": slat_sampling_steps,
-                    "cfg_strength": slat_guidance_strength,
-                },
-            )
-
-            # Create video while model is still on GPU
-            video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
-            video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
-            video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
-
-            # Pack state while tensors are still on GPU
-            state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
-
-    except Exception as e:
-        # Ensure cleanup on error
-        torch.cuda.empty_cache()
-        raise e
-
-    # Save video after GPU operations are complete
+    outputs = trellis_pipeline.run(
+        image,
+        seed=seed,
+        formats=["gaussian", "mesh"],
+        preprocess_image=False,
+        sparse_structure_sampler_params={
+            "steps": ss_sampling_steps,
+            "cfg_strength": ss_guidance_strength,
+        },
+        slat_sampler_params={
+            "steps": slat_sampling_steps,
+            "cfg_strength": slat_guidance_strength,
+        },
+    )
+    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+    video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
     video_path = os.path.join(user_dir, 'sample.mp4')
     imageio.mimsave(video_path, video, fps=15)
-
-    # Final cleanup
+    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
     torch.cuda.empty_cache()
-
     return state, video_path
 
 @spaces.GPU(duration=90)
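
For context, the new code on the + side calls the module-level flux_pipeline and trellis_pipeline directly and leaves device handling to the @spaces.GPU decorator. A minimal sketch of that pattern, assuming a ZeroGPU Space; the model id and function name below are illustrative and not taken from app.py:

import spaces
import torch
from diffusers import FluxPipeline

# Assumption: illustrative model id; app.py loads its own flux_pipeline (plus a LoRA) at startup.
flux_pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
flux_pipeline.to("cuda")  # with ZeroGPU, the GPU is actually attached per decorated call

@spaces.GPU
def generate(prompt: str, seed: int = 0):
    generator = torch.Generator(device="cuda").manual_seed(seed)
    # Call the resident pipeline directly; no per-call .to("cuda")/.to("cpu")
    # moves or torch.cuda.empty_cache() bookkeeping is needed here.
    return flux_pipeline(prompt=prompt, generator=generator).images[0]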
 