thomasgauthier committed
Commit fc91aa0
1 Parent(s): bdf9962

ZeroGPU fixes

Files changed (4)
  1. app.py +6 -6
  2. gradio_interface.py +3 -2
  3. image_generator.py +10 -3
  4. model_loader.py +2 -0
app.py CHANGED
@@ -1,16 +1,16 @@
-import torch
 import spaces
+import torch
 from model_loader import load_model_and_processor
 from image_generator import process_and_generate
 from gradio_interface import create_gradio_interface
 
 if __name__ == "__main__":
-    # Set the model path
-    model_path = "deepseek-ai/Janus-1.3B"
+    import subprocess
+    subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+
 
-    # Load the model and processor
-    vl_gpt, vl_chat_processor = load_model_and_processor(model_path)
 
     # Create and launch the Gradio interface
-    demo = create_gradio_interface(vl_gpt, vl_chat_processor, process_and_generate)
+    demo = create_gradio_interface(process_and_generate)
     demo.launch(allowed_paths=["/"])
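The notable ZeroGPU detail here is that flash-attn is installed at runtime with its CUDA extension build skipped, since the Space's startup environment has no GPU toolchain available. Below is a minimal sketch of the same idea, not part of the commit: the SPACE_ID guard (an environment variable Hugging Face Spaces are expected to set) and the preserved parent environment are illustrative assumptions; the commit itself passes a fresh env dict, which replaces the child process environment rather than extending it.

```python
# Sketch only: runtime flash-attn install for a ZeroGPU Space, guarded so it
# does not run during local development.
import os
import subprocess

def install_flash_attn():
    if os.environ.get("SPACE_ID") is None:
        return  # local run: assume the environment is already set up
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        shell=True,
        check=True,
        # skip the CUDA extension build; the startup host has no GPU attached
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    )
```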
gradio_interface.py CHANGED
@@ -1,9 +1,10 @@
 import gradio as gr
 from PIL import Image
 
-def create_gradio_interface(vl_gpt, vl_chat_processor, process_and_generate):
+
+def create_gradio_interface(process_and_generate):
     def gradio_process_and_generate(input_image, prompt, num_images, cfg_weight):
-        return process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images, cfg_weight)
+        return process_and_generate(input_image, prompt, num_images, cfg_weight)
 
     explanation = """Janus 1.3B uses a differerent visual encoder for understanding and generation.
 
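With the model arguments removed, create_gradio_interface only needs the generation callback. The component layout is not part of this hunk, so the following is just a sketch of how such a callback is typically wired into Gradio; the specific components, labels, and value ranges are assumptions, not the file's actual UI.

```python
import gradio as gr

def create_gradio_interface(process_and_generate):
    def gradio_process_and_generate(input_image, prompt, num_images, cfg_weight):
        return process_and_generate(input_image, prompt, num_images, cfg_weight)

    # Hypothetical layout: only the callback signature above comes from the diff.
    with gr.Blocks() as demo:
        input_image = gr.Image(type="pil", label="Input image")
        prompt = gr.Textbox(label="Prompt")
        num_images = gr.Slider(1, 8, value=4, step=1, label="Number of images")
        cfg_weight = gr.Slider(1.0, 10.0, value=5.0, step=0.5, label="CFG weight")
        gallery = gr.Gallery(label="Generated images")
        gr.Button("Generate").click(
            gradio_process_and_generate,
            inputs=[input_image, prompt, num_images, cfg_weight],
            outputs=gallery,
        )
    return demo
```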
image_generator.py CHANGED
@@ -3,9 +3,10 @@ import PIL.Image
 import torch
 import numpy as np
 from janus.utils.io import load_pil_images
+from model_loader import load_model_and_processor
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from functools import lru_cache
-
+import spaces
 
 def prepare_classifier_free_guidance_input(input_embeds, vl_chat_processor, mmgpt, batch_size=16):
     uncond_input_ids = torch.full((1, input_embeds.shape[1]),
@@ -26,7 +27,6 @@ def prepare_classifier_free_guidance_input(input_embeds, vl_chat_processor, mmgpt, batch_size=16):
 
     return combined_input_embeds
 
-@spaces.GPU
 @torch.inference_mode()
 def generate(
     mmgpt: MultiModalityCausalLM,
@@ -83,7 +83,14 @@ def get_start_tag_embed(vl_gpt, vl_chat_processor):
         vl_chat_processor.tokenizer.encode(vl_chat_processor.image_start_tag, add_special_tokens=False, return_tensors="pt").to(vl_gpt.device)
     )
 
-def process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images=4, cfg_weight=5):
+@spaces.GPU
+def process_and_generate(input_image, prompt, num_images=4, cfg_weight=5):
+    # Set the model path
+    model_path = "deepseek-ai/Janus-1.3B"
+
+    # Load the model and processor
+    vl_gpt, vl_chat_processor = load_model_and_processor(model_path)
+
     start_tag_embed = get_start_tag_embed(vl_gpt, vl_chat_processor)
 
     nl = '\n'
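Loading the model inside process_and_generate is what makes the function work under @spaces.GPU, since ZeroGPU only attaches a GPU while a decorated call is executing, but it also means the checkpoint is reloaded on every request. A possible follow-up, sketched below under the assumption that the worker process stays alive between calls, is to memoize the load with functools.lru_cache (which the file already imports for get_start_tag_embed). Whether this plays well with ZeroGPU's device handling depends on where load_model_and_processor places the weights, so treat it as an idea rather than a drop-in change.

```python
from functools import lru_cache

import spaces
from model_loader import load_model_and_processor

@lru_cache(maxsize=1)
def get_model_and_processor(model_path: str):
    # Hypothetical helper: cached by path so repeated calls reuse the already
    # loaded weights instead of re-reading the checkpoint each time.
    return load_model_and_processor(model_path)

@spaces.GPU
def process_and_generate(input_image, prompt, num_images=4, cfg_weight=5):
    vl_gpt, vl_chat_processor = get_model_and_processor("deepseek-ai/Janus-1.3B")
    ...  # rest of the generation logic unchanged
```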
model_loader.py CHANGED
@@ -1,7 +1,9 @@
 import torch
 from transformers import AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
+import spaces
 
+@spaces.GPU
 def load_model_and_processor(model_path):
     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
     tokenizer = vl_chat_processor.tokenizer
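Decorating load_model_and_processor with @spaces.GPU means the initial load also runs inside a ZeroGPU allocation. If the download plus load takes longer than the default window, the decorator can be given an explicit duration; the snippet below sketches that option. The duration value and the warm_up helper are illustrative assumptions, not part of this commit.

```python
import spaces
from model_loader import load_model_and_processor

# Sketch: request a longer ZeroGPU window for calls that include the initial
# checkpoint download. duration is in seconds; 120 is an arbitrary choice.
@spaces.GPU(duration=120)
def warm_up():
    # hypothetical helper that pre-loads the weights once at startup
    return load_model_and_processor("deepseek-ai/Janus-1.3B")
```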