girishwangikar commited on
Commit
a4c76b3
·
verified ·
1 Parent(s): 4fda22c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import ViltProcessor, ViltForQuestionAnswering
5
+ import torch.nn.functional as F
6
+ from torchvision.models import resnet50
7
+ import torchvision.transforms as transforms
8
+
9
+ # Load the ViLT model and processor
10
+ processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
11
+ model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
12
+
13
+ # Load pre-trained ResNet model
14
+ resnet50_model = resnet50(pretrained=True)
15
+ resnet50_model.eval()
16
+
17
+ # Simplified list of common objects
18
+ common_objects = ['person', 'animal', 'vehicle', 'furniture', 'electronic device', 'food', 'plant', 'building', 'clothing', 'sports equipment']
19
+
20
+ def get_image_features(image, model):
21
+ transform = transforms.Compose([
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
26
+ ])
27
+ img_tensor = transform(image).unsqueeze(0)
28
+ with torch.no_grad():
29
+ features = model(img_tensor)
30
+ return features
31
+
32
+ def suggest_questions(image):
33
+ features = get_image_features(image, resnet50_model)
34
+ _, predicted = features.max(1)
35
+ class_name = common_objects[predicted.item() % len(common_objects)]
36
+
37
+ suggested_questions = [
38
+ f"What is the main object in this image?",
39
+ f"Is there a {class_name} in this picture?",
40
+ "What colors are prominent in this image?",
41
+ "What is the setting or background of this image?",
42
+ "Are there any people in this image?"
43
+ ]
44
+ return suggested_questions
45
+
46
+ def predict(image, question):
47
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
48
+ encoding = processor(image, question, return_tensors="pt")
49
+
50
+ with torch.no_grad():
51
+ outputs = model(**encoding)
52
+ logits = outputs.logits
53
+ probs = F.softmax(logits, dim=-1)
54
+
55
+ # Get top 5 answers and their probabilities
56
+ top_5_probs, top_5_indices = probs.topk(5)
57
+
58
+ answers = []
59
+ for prob, idx in zip(top_5_probs[0], top_5_indices[0]):
60
+ answer = model.config.id2label[idx.item()]
61
+ answers.append((answer, prob.item()))
62
+
63
+ main_answer = answers[0][0]
64
+ confidence = answers[0][1]
65
+
66
+ alternative_answers = [f"{ans} ({prob:.2f})" for ans, prob in answers[1:]]
67
+
68
+ suggested_questions = suggest_questions(image)
69
+
70
+ return (
71
+ main_answer,
72
+ f"{confidence:.2f}",
73
+ ", ".join(alternative_answers),
74
+ "\n".join(suggested_questions)
75
+ )
76
+
77
+ # Create the Gradio interface
78
+ interface = gr.Interface(
79
+ fn=predict,
80
+ inputs=[
81
+ gr.Image(type="numpy"),
82
+ gr.Textbox(lines=1, placeholder="Ask a question...")
83
+ ],
84
+ outputs=[
85
+ gr.Textbox(label="Main Answer"),
86
+ gr.Textbox(label="Confidence Score"),
87
+ gr.Textbox(label="Alternative Answers"),
88
+ gr.Textbox(label="Suggested Questions")
89
+ ],
90
+ title="Enhanced ViLT Visual Question Answering",
91
+ description="Upload an image and ask a question about it. The model will provide the main answer, confidence score, alternative answers, and suggest additional questions."
92
+ )
93
+
94
+ # Launch the Gradio interface
95
+ interface.launch()