vvinayakkk committed
Commit ad51373 · verified · 1 Parent(s): a6f1afa

Create app.py

Files changed (1)
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from qwen_vl_utils import process_vision_info
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ import time
+
+
+ # Cache the model and processor so they are loaded once per session,
+ # not on every Streamlit rerun.
+ @st.cache_resource
+ def load_model():
+     model = Qwen2VLForConditionalGeneration.from_pretrained(
+         "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
+     ).eval()
+     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+     return model, processor
+
+ model, processor = load_model()
+
+ st.title("Image Query App")
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ st.sidebar.title("Suggested Questions")
+ predefined_questions = [
+     "What is the main object in this image?",
+     "Describe the scene in the image.",
+     "Are there any people in the image?",
+     "What is the background of the image?",
+ ]
+ selected_question = st.sidebar.radio("Choose a question", predefined_questions)
+
+ question = st.sidebar.text_input("Or ask your own question here:")
+
+ submit_button = st.sidebar.button("Submit")
+
+ response = ""
+
+ if uploaded_file is not None:
+     image = Image.open(uploaded_file)
+
+     original_size = image.size
+     st.write(f"Original image dimensions: {original_size}")
+
+     # Downscale large images in place so preprocessing and generation stay fast.
+     max_size = (700, 700)
+     if image.size[0] > 1000 or image.size[1] > 1000:
+         image.thumbnail(max_size)
+         resized_size = image.size
+         st.write(f"Image resized to: {resized_size}")
+     else:
+         st.write("Image size is within acceptable limits.")
+
+     # Fall back to the selected suggested question when no custom question is typed.
+     if not question:
+         question = selected_question
+
+     if submit_button:
+         st.sidebar.markdown("<h3 style='color:blue;'>Fetching the answer might take 2-3 minutes depending on the question; hold tight while we process your request!</h3>", unsafe_allow_html=True)
+         start_time = time.time()  # Start the timer
+
+         if question:
+             # Build a single-turn chat message holding the image and the question.
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "image", "image": image},
+                         {"type": "text", "text": question},
+                     ],
+                 }
+             ]
+
+             text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             image_inputs, video_inputs = process_vision_info(messages)
+             inputs = processor(
+                 text=[text],
+                 images=image_inputs,
+                 videos=video_inputs,
+                 padding=True,
+                 return_tensors="pt",
+             )
+
+             with st.spinner('Fetching the answer...'):
+                 with torch.no_grad():
+                     new_generated_ids = model.generate(**inputs, max_new_tokens=180)
+
+                 # Drop the prompt tokens so only the newly generated answer is decoded.
+                 new_generated_ids_trimmed = [
+                     out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, new_generated_ids)
+                 ]
+                 response = processor.batch_decode(
+                     new_generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+                 )[0]
+         else:
+             st.warning("Please enter a question.")
+
+         elapsed_time = time.time() - start_time  # Calculate elapsed time
+
+         if response:
+             st.markdown(f"<h4 style='color:green;'>Response:</h4><p style='font-size:18px;'>{response}</p>", unsafe_allow_html=True)
+             st.markdown(f"<p style='color:gray;'>Time taken to fetch the answer: {elapsed_time:.2f} seconds</p>", unsafe_allow_html=True)
+
+ if uploaded_file is not None:
+     st.image(image, caption='Uploaded Image', use_column_width=True)
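
For a quick sanity check outside the Streamlit UI, the same inference flow can be replayed as a plain script. The sketch below only restates the steps from app.py, assuming torch, transformers, qwen-vl-utils, and Pillow are installed; "sample.jpg" is a hypothetical placeholder for any local test image.

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
).eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)

# "sample.jpg" is a hypothetical placeholder for any local image file.
image = Image.open("sample.jpg")
messages = [{"role": "user", "content": [
    {"type": "image", "image": image},
    {"type": "text", "text": "Describe the scene in the image."},
]}]

# Same preprocessing pipeline as the app: chat template plus vision inputs.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs,
                   padding=True, return_tensors="pt")

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=180)

# Strip the prompt tokens before decoding, as the app does.
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])

The app itself is launched with `streamlit run app.py`; Streamlit is only needed for the UI, not for this sketch.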