gauri-sharan committed
Commit 3ef82d2 · verified · 1 parent: e016842

Update app.py

Files changed (1): app.py (+13 -12)
app.py CHANGED
@@ -6,16 +6,20 @@ import torch
 from PIL import Image
 import os
 import traceback
-import spaces  # Ensure import for GPU management
+import spaces
 
-# Load the Byaldi and Qwen2-VL models without using .cuda()
-rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
+# Check if CUDA is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
+
+# Load the Byaldi and Qwen2-VL models
+rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali").to(device)  # Move Byaldi to GPU
 qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
-)
+    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
+).to(device)  # Move Qwen2-VL to GPU
 
 # Processor for Qwen2-VL
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
 
 @spaces.GPU  # Decorate the function for GPU management
 def ocr_and_extract(image, text_query):
@@ -52,16 +56,13 @@ def ocr_and_extract(image, text_query):
     text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, _ = process_vision_info(messages)
 
+    # Move the image inputs and processor outputs to CUDA
     inputs = processor(
         text=[text_input],
         images=image_inputs,
         padding=True,
         return_tensors="pt",
-    )
-
-    # Move the Qwen2-VL model and inputs to GPU
-    qwen_model.to("cuda")
-    inputs = {k: v.to("cuda") for k, v in inputs.items()}
+    ).to(device)
 
     # Generate the output with Qwen2-VL
     generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
@@ -92,4 +93,4 @@ iface = gr.Interface(
 )
 
 # Launch the Gradio app
-iface.launch()
+iface.launch()
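
For context, a minimal self-contained sketch of the device-handling pattern this commit lands on: pick a device once, load the model onto it at startup, and move the processor's tensors with .to(device) instead of hard-coding .cuda(). The model ID, @spaces.GPU decorator, and max_new_tokens value come from the diff; the function name, message layout, and output decoding are illustrative assumptions, not part of the commit.

# Sketch of the device-agnostic load-and-infer pattern, assuming the model ID from the diff.
import torch
import spaces
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

# Fall back to CPU when no CUDA device is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load once at startup and place on the chosen device (no hard-coded .cuda()).
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

@spaces.GPU  # On ZeroGPU Spaces, a GPU is attached only while this function runs
def run_query(image, text_query):  # hypothetical name; the app's function is ocr_and_extract
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text_query},
        ],
    }]
    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)
    # The processor returns a BatchFeature; .to(device) moves all of its tensors at once,
    # replacing the per-tensor dict comprehension the commit removes.
    inputs = processor(
        text=[text_input], images=image_inputs, padding=True, return_tensors="pt"
    ).to(device)
    generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

One design note: on ZeroGPU hardware, CUDA is typically not visible at import time, so the module-level torch.cuda.is_available() check resolves to "cpu" at startup and the GPU only becomes usable inside the @spaces.GPU-decorated function.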