File size: 2,059 Bytes
3605bde
b84c0d1
 
 
 
 
0e3adc8
 
10271df
b84c0d1
 
 
 
0e3adc8
b84c0d1
 
0e3adc8
 
b84c0d1
 
0e3adc8
 
 
 
b84c0d1
0e3adc8
b84c0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fca33e
b84c0d1
 
99efb3d
 
b84c0d1
 
 
0e3adc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import base64
import io
import cv2
import torch
import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
from huggingface_hub import login
 
# Step 1: Log in to Hugging Face
# The token is read from the HF_TOKEN environment variable. Fail early with an
# explicit message when it is missing or empty instead of a bare KeyError.
access_token = os.environ.get("HF_TOKEN")
if not access_token:
    raise RuntimeError(
        "The HF_TOKEN environment variable is not set; "
        "it must hold a Hugging Face access token."
    )
login(token=access_token)

# Step 2: Setup device and load model
# Repository identifiers are named once so the loaders below cannot drift apart.
ADAPTER_REPO_ID = "anushettypsl/paligemma_vqav2"
BASE_MODEL_ID = "google/paligemma-3b-pt-448"

# Prefer the GPU when one is visible to torch, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): dtype is defined but not passed to any loader in this block, so
# the weights are loaded with the library default precision. Confirm whether
# bfloat16 loading was intended before wiring it in.
dtype = torch.bfloat16

# Load configuration and model
# NOTE(review): the adapter config is loaded but not consulted afterwards in
# this block; the base checkpoint is selected by BASE_MODEL_ID.
config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)
base_model = PaliGemmaForConditionalGeneration.from_pretrained(BASE_MODEL_ID)
# Wrap the base model with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_REPO_ID, device_map=device)
# The processor only prepares inputs (tokenisation / image preprocessing); the
# resulting tensors are moved to the device by the caller, so no device
# argument is passed here.
processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)

# Move the combined model (base + adapter) onto the selected device.
model.to(device)
 
# Step 3: Define prediction function
def predict(input_image, input_text):
    """Answer a text prompt about an uploaded image using the loaded model.

    Args:
        input_image: PIL image supplied by the Gradio image component, or
            ``None`` when the form was submitted without an image.
        input_text: Prompt string passed to the processor together with the
            image.

    Returns:
        The decoded text of the newly generated tokens only (the prompt that
        was fed to the model is not repeated in the result).

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio passes None when nothing was uploaded; surface a readable message
    # in the UI instead of an AttributeError on ``None.convert``.
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Convert the uploaded image to RGB format
    rgb_image = input_image.convert("RGB")

    # Prepare the model inputs and move them to the same device as the model
    model_inputs = processor(text=input_text, images=rgb_image, return_tensors="pt").to(device)
    # Remember how many tokens belong to the prompt so they can be dropped
    # from the generated sequence below.
    prompt_length = model_inputs["input_ids"].shape[-1]

    # Perform inference (greedy decoding, at most 100 new tokens)
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

    # generate() returns the prompt tokens followed by the new tokens; keep
    # only the continuation so the answer does not echo the prompt.
    answer_tokens = generation[0][prompt_length:]

    # Decode the output
    return processor.decode(answer_tokens, skip_special_tokens=True)
 
# Step 4: Build the Gradio UI
# The two input components are created up front so the Interface call below
# reads as plain wiring between the widgets and the prediction function.
image_input = gr.Image(type="pil", label="Upload Image")
prompt_input = gr.Textbox(
    label="Input Prompt",
    value="Detect whether the pathology is malignant or benign? If malignant, then detect the grade G1, G2, or G3.",
)

# predict receives (image, prompt) in this order and its return value is
# rendered as plain text.
interface = gr.Interface(
    fn=predict,
    inputs=[image_input, prompt_input],
    outputs="text",
    title="anushettypsl/paligemma_vqav2",
    description="Upload an image to predict grade of cancer",
)

# Step 5: Start serving the app
interface.launch()