ariG23498 HF Staff committed on
Commit
e7427b0
·
1 Parent(s): e7c21b5

type fixes

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -4,9 +4,13 @@ import torch
4
  from transformers import AutoModelForImageTextToText, AutoProcessor
5
 
6
  # Load model and processor
7
- MODEL_PATH = "google/gemma-3n-E2B-it"
8
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
9
- model = AutoModelForImageTextToText.from_pretrained(MODEL_PATH, torch_dtype="auto", device_map="auto")
 
 
 
 
10
 
11
  @spaces.GPU
12
  def process_inputs(image, audio):
@@ -18,10 +22,11 @@ def process_inputs(image, audio):
18
  ).to(model.device, dtype=model.dtype)
19
 
20
  # Generate text output
21
- outputs = model.generate(
22
- **inputs,
23
- max_new_tokens=256
24
- )
 
25
 
26
  # Decode and return text
27
  text = processor.batch_decode(
 
4
  from transformers import AutoModelForImageTextToText, AutoProcessor
5
 
6
  # Load model and processor
7
+ MODEL_PATH = "google/gemma-3n-E4B-it"
8
  processor = AutoProcessor.from_pretrained(MODEL_PATH)
9
+ model = AutoModelForImageTextToText.from_pretrained(
10
+ MODEL_PATH,
11
+ torch_dtype=torch.bfloat16,
12
+ device_map="cuda"
13
+ ).eval()
14
 
15
  @spaces.GPU
16
  def process_inputs(image, audio):
 
22
  ).to(model.device, dtype=model.dtype)
23
 
24
  # Generate text output
25
+ with torch.inference_mode():
26
+ outputs = model.generate(
27
+ **inputs,
28
+ max_new_tokens=256
29
+ )
30
 
31
  # Decode and return text
32
  text = processor.batch_decode(