feat_preo_cmod
#3
by
AmithAdiraju1694
- opened
- app.py +37 -177
- inference/config.py +16 -26
- inference/preprocess_image.py +57 -4
- inference/translate.py +41 -16
- pages.py +214 -0
- utils.py +15 -0
app.py
CHANGED
@@ -1,204 +1,64 @@
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit import session_state as sst
|
3 |
-
from typing import List, Optional
|
4 |
import asyncio
|
5 |
-
import pandas as pd
|
6 |
-
|
7 |
-
from inference.translate import (
|
8 |
-
extract_filter_img,
|
9 |
-
transcribe_menu_model
|
10 |
-
)
|
11 |
-
|
12 |
-
from inference.config import DEBUG_MODE
|
13 |
-
from PIL import Image
|
14 |
-
import time
|
15 |
-
|
16 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
17 |
-
import os
|
18 |
-
|
19 |
-
# Setting workers to be 70% of all available virtual cpus in system
|
20 |
-
cpu_count = os.cpu_count()
|
21 |
-
pool = ThreadPoolExecutor(max_workers=int(cpu_count*0.7) )
|
22 |
|
|
|
23 |
# Initialize session state variable to start with home page
|
24 |
if "page" not in sst:
|
25 |
sst["page"] = "Home"
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
Parameters:
|
33 |
-
page: str, required.
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
None
|
37 |
-
"""
|
38 |
-
|
39 |
-
sst["page"] = page
|
40 |
-
|
41 |
-
async def main_page() -> None:
|
42 |
-
"""
|
43 |
-
Function that contains content of main page i.e., image uploader and submit button to navigate to next page.
|
44 |
-
Upon submit , control goes to model inference 'page'.
|
45 |
-
|
46 |
-
Parameters:
|
47 |
-
None
|
48 |
-
|
49 |
-
Returns:
|
50 |
-
None
|
51 |
-
"""
|
52 |
-
|
53 |
-
# Streamlit app
|
54 |
-
first_title = st.empty()
|
55 |
-
first_title.title("App that explains your menu items ")
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
type=["jpg", "jpeg", "png"])
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
if uploaded_file is not None:
|
67 |
-
image = Image.open(uploaded_file)
|
68 |
-
|
69 |
-
# Only show if user wants to see
|
70 |
-
if st.checkbox('Show Uploaded Image'):
|
71 |
-
st.image(image,
|
72 |
-
caption='Uploaded Image',
|
73 |
-
use_column_width=True)
|
74 |
-
|
75 |
-
sst["input_image"] = image
|
76 |
-
|
77 |
-
# Submit button
|
78 |
-
st.button("Submit",
|
79 |
-
on_click = navigate_to,
|
80 |
-
args = ("Inference",))
|
81 |
-
|
82 |
-
|
83 |
-
st.info("""This application is for education purposes only. It uses AI, hence it's dietary
|
84 |
-
recommendations are not to be taken as medical advice, author doesn't bear responsibility
|
85 |
-
for incorrect dietary recommendations. Please proceed with caution.
|
86 |
-
""")
|
87 |
|
|
|
88 |
|
89 |
-
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
waiting for all threads to be done.
|
95 |
|
96 |
-
Parameters:
|
97 |
-
inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
|
98 |
-
|
99 |
-
Returns:
|
100 |
-
None
|
101 |
-
"""
|
102 |
-
|
103 |
-
df = pd.DataFrame([('ITEM NAME', 'EXPLANATION')]
|
104 |
-
)
|
105 |
-
|
106 |
-
sl_table = st.table(df)
|
107 |
-
tp_futures = { pool.submit(transcribe_menu_model, mi): mi for mi in inp_texts }
|
108 |
|
109 |
-
for tpftr in as_completed(tp_futures):
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
try:
|
114 |
-
exp = tpftr.result()
|
115 |
-
sl_table.add_rows([(item,exp)] )
|
116 |
-
|
117 |
-
except Exception as e:
|
118 |
-
print("Could not add a new row dynamically, because of this error:", e)
|
119 |
-
|
120 |
-
return
|
121 |
-
|
122 |
-
|
123 |
-
async def model_inference():
|
124 |
-
|
125 |
"""
|
126 |
-
|
127 |
-
and toggles state between pages if needed.
|
128 |
|
129 |
-
Parameters:
|
130 |
-
None
|
131 |
Returns:
|
132 |
None
|
133 |
-
|
134 |
"""
|
135 |
-
|
136 |
-
second_title = st.empty()
|
137 |
-
second_title.title(" Using ML to explain your menu items ... ")
|
138 |
-
|
139 |
-
if "input_image" in sst:
|
140 |
-
|
141 |
-
image = sst["input_image"]
|
142 |
-
|
143 |
-
msg1 = st.empty()
|
144 |
-
msg1.write("Pre-processing and extracting text out of your image ....")
|
145 |
-
st_filter = time.perf_counter()
|
146 |
-
|
147 |
-
# Call the extract_filter_img function
|
148 |
-
filtered_text = await extract_filter_img(image)
|
149 |
-
en_filter = time.perf_counter()
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
elif num_items_detected > 0:
|
157 |
-
st.write(f"Detected {num_items_detected} menu items from your input image ... ")
|
158 |
-
|
159 |
-
msg2 = st.empty()
|
160 |
-
msg2.write("All pre-processing done, transcribing your menu items now ....")
|
161 |
-
st_trans_llm = time.perf_counter()
|
162 |
-
|
163 |
-
await dist_llm_inference(filtered_text)
|
164 |
-
|
165 |
-
msg3 = st.empty()
|
166 |
-
msg3.write("Done transcribing ... ")
|
167 |
-
en_trans_llm = time.perf_counter()
|
168 |
-
|
169 |
-
msg1.empty(); msg2.empty(); msg3.empty()
|
170 |
-
st.success("Image processed successfully! " )
|
171 |
-
|
172 |
-
if DEBUG_MODE:
|
173 |
-
filter_time_sec = en_filter - st_filter
|
174 |
-
llm_time_sec = en_trans_llm - st_trans_llm
|
175 |
-
total_time_sec = filter_time_sec + llm_time_sec
|
176 |
-
|
177 |
-
st.write("Time took to extract and filter text {}".format(filter_time_sec))
|
178 |
-
st.write("Time took to summarize by LLM {}".format(llm_time_sec))
|
179 |
-
st.write('Overall time taken in seconds: {}'.format(total_time_sec))
|
180 |
-
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
async def main():
|
191 |
-
"""
|
192 |
-
Function that toggles between pages based on state variables.
|
193 |
|
194 |
-
Parameters:
|
195 |
-
None
|
196 |
-
Returns:
|
197 |
-
None
|
198 |
-
"""
|
199 |
-
if sst["page"] == "Home":
|
200 |
-
await main_page()
|
201 |
elif sst["page"] == "Inference":
|
202 |
-
await
|
203 |
|
204 |
asyncio.run(main())
|
|
|
1 |
+
from utils import navigate_to
|
2 |
+
from pages import manual_input_page, image_input_page, model_inference_page
|
3 |
+
|
4 |
import streamlit as st
|
5 |
from streamlit import session_state as sst
|
|
|
6 |
import asyncio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
#TODO: Fix model inference and post processing function befor emoving ot production.
|
9 |
# Initialize session state variable to start with home page
|
10 |
if "page" not in sst:
|
11 |
sst["page"] = "Home"
|
12 |
|
13 |
+
# function to remove all sesion variables from sst, except page.
|
14 |
+
def reset_sst():
|
15 |
+
for key in list(sst.keys()):
|
16 |
+
if key != "page":
|
17 |
+
sst.pop(key, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
# Landing page function
|
20 |
+
async def landing_page():
|
|
|
21 |
|
22 |
+
st.title("We will explain your menu like never before!")
|
23 |
+
st.write("\n")
|
24 |
+
st.write("\n")
|
25 |
+
st.write("\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
c1, c2= st.columns(2)
|
28 |
|
29 |
+
with c1:
|
30 |
+
# Navigate to manual input page if user clicks on the button
|
31 |
+
st.button("Enter Items Manually", on_click=navigate_to, args=("ManualInput",))
|
32 |
|
33 |
+
with c2:
|
34 |
+
# Navigate to image input page if user clicks on the button
|
35 |
+
st.button("Upload Items from Image", on_click=navigate_to, args=("ImageInput",))
|
|
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
|
|
38 |
|
39 |
+
# Main function to handle navigation
|
40 |
+
async def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
"""
|
42 |
+
Main function that handles the navigation logic based on the current page.
|
|
|
43 |
|
|
|
|
|
44 |
Returns:
|
45 |
None
|
|
|
46 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
+
# Navigation logic
|
49 |
+
if sst["page"] == "Home":
|
50 |
+
reset_sst() # reset all session state variables before navigating to the landing page
|
51 |
+
await landing_page() # Call the landing page function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
+
elif sst["page"] == "ManualInput":
|
54 |
+
reset_sst() # reset all session state variables before navigating to the landing page
|
55 |
+
await manual_input_page() # Call the manual input page function
|
56 |
+
|
57 |
+
elif sst["page"] == "ImageInput":
|
58 |
+
reset_sst() # reset all session state variables before navigating to the landing page
|
59 |
+
await image_input_page() # Call the image input page function
|
|
|
|
|
|
|
|
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
elif sst["page"] == "Inference":
|
62 |
+
await model_inference_page() # Call the model inference page function
|
63 |
|
64 |
asyncio.run(main())
|
inference/config.py
CHANGED
@@ -1,33 +1,23 @@
|
|
1 |
-
|
2 |
-
|
3 |
|
4 |
-
Item
|
5 |
-
|
6 |
-
|
7 |
-
It goes well with: White basmati rice or Indian flat bread.\n
|
8 |
-
Allergens: Paneer may cause digestive discomfort and intolerance to some.\n
|
9 |
-
Food Category: Vegetarian, Vegans may not like it, as paneer is usually made from cow milk.
|
10 |
|
|
|
|
|
11 |
|
12 |
-
Item -> rumali roti.\n
|
13 |
-
Explanation -> Major Ingredients here: roti.\n
|
14 |
-
How it is made: A small soft bread, made to size of a napkin ( a.k.a 'rumal' in hindi ); usually made with a combination of whole wheat and all purpose flour.\n
|
15 |
-
It goes well with: Most indian gravies such as palak paneer, tomato curry etc.\n
|
16 |
-
Allergens: May contain gluten, which is known to cause digestive discomfort and intolerance to some.\n
|
17 |
-
Food Category: Vegetarian, Vegan.
|
18 |
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
26 |
|
|
|
27 |
|
28 |
-
|
29 |
-
Item ->
|
30 |
-
"""
|
31 |
-
|
32 |
-
DEBUG_MODE = False
|
33 |
-
DEVICE = 'cpu'
|
|
|
1 |
+
import torch
|
2 |
+
import re
|
3 |
|
4 |
+
model_inf_inp_prompt = "INSTRUCTION: given food item name, explain these things:(major ingredients,making process,portion & spicy/sweet,pairs with,allergens,food type(veg/non-veg/vegan)). ensure to get allergens and food category factually correct.Item Name: {} "
|
5 |
+
header_pattern = r'Item Name: (.*?)\. Major Ingredients: (.*?)\. Making Process: (.*?)\. Portion and Spice Level: (.*?)\. Pairs With: (.*?)\. Allergens: (.*?)\. Food Type: (.*?)\.\s*</s>'
|
6 |
+
dots_pattern = re.compile(r'\.{3,}')
|
|
|
|
|
|
|
7 |
|
8 |
+
DEBUG_MODE = True
|
9 |
+
model_name = "AmithAdiraju1694/gpt-neo-125M_menuitemexp"
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
+
def get_device():
|
14 |
+
if torch.cuda.is_available():
|
15 |
+
device = torch.device("cuda")
|
16 |
+
print(f"Using GPU: {torch.cuda.get_device_name(0)}") #get the name of the GPU being used.
|
17 |
+
else:
|
18 |
+
device = torch.device("cpu")
|
19 |
+
print("Using CPU")
|
20 |
|
21 |
+
return device
|
22 |
|
23 |
+
DEVICE = get_device()
|
|
|
|
|
|
|
|
|
|
inference/preprocess_image.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
|
2 |
import numpy as np
|
3 |
-
from typing import List, Tuple, Optional, AnyStr
|
4 |
import nltk
|
5 |
nltk.download("stopwords")
|
6 |
nltk.download('punkt')
|
@@ -53,11 +53,64 @@ def image_to_np_arr(image) -> np.array:
|
|
53 |
return np.array(image)
|
54 |
|
55 |
async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
output_texts = []
|
58 |
for _, extr_text, _ in raw_extrc_text:
|
59 |
# remove all numbers, special characters from a string
|
60 |
prcsd_txt = preprocess_text(extr_text)
|
61 |
-
if len(prcsd_txt.split(" ")
|
|
|
62 |
|
63 |
-
return output_texts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
import numpy as np
|
3 |
+
from typing import List, Tuple, Optional, AnyStr, Dict
|
4 |
import nltk
|
5 |
nltk.download("stopwords")
|
6 |
nltk.download('punkt')
|
|
|
53 |
return np.array(image)
|
54 |
|
55 |
async def process_extracted_text(raw_extrc_text: List[Tuple]) -> List[AnyStr]:
|
56 |
+
"""
|
57 |
+
Function that processes extracted text by removing numbers and special characters,
|
58 |
+
and filters out text with less than 2 words.
|
59 |
+
|
60 |
+
Parameters:
|
61 |
+
raw_extrc_text: List[Tuple], required -> A list of tuples containing extracted text.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
List[AnyStr] -> A list of processed text strings.
|
65 |
+
"""
|
66 |
output_texts = []
|
67 |
for _, extr_text, _ in raw_extrc_text:
|
68 |
# remove all numbers, special characters from a string
|
69 |
prcsd_txt = preprocess_text(extr_text)
|
70 |
+
if len(prcsd_txt.split(" ")) >= 2:
|
71 |
+
output_texts.append(prcsd_txt)
|
72 |
|
73 |
+
return output_texts
|
74 |
+
|
75 |
+
def post_process_gen_outputs(gen_output: List[str], header_pattern: str, dots_pattern:str) -> List[Dict]:
|
76 |
+
|
77 |
+
# Define the regular expression pattern to match section names and placeholders
|
78 |
+
headers = ["Item Name", "Major Ingredients", "Making Process", "Portion and Spice Level", "Pairs With", "Allergens", "Food Type"]
|
79 |
+
|
80 |
+
# Function to clean the strings
|
81 |
+
def clean_string(input_string):
|
82 |
+
parts = input_string.split(',')
|
83 |
+
cleaned_parts = [part.strip() for part in parts if part.strip()]
|
84 |
+
return ', '.join(cleaned_parts)
|
85 |
+
|
86 |
+
for i in range(len(gen_output)):
|
87 |
+
# Find all matches
|
88 |
+
matches = re.findall(header_pattern, gen_output[i])
|
89 |
+
|
90 |
+
# Since re.findall returns a list of tuples, we need to extract the first tuple
|
91 |
+
if matches:
|
92 |
+
result = dict(zip(headers,matches[0]))
|
93 |
+
result['Major Ingredients'] = clean_string(result['Major Ingredients'])
|
94 |
+
|
95 |
+
# if any of dictionary values strings are emtpy, replace it with string "Sorry, can't explain this."
|
96 |
+
for k in result.keys():
|
97 |
+
if len(result[k]) < 3 or any(header in result[k] for header in headers):
|
98 |
+
result[k] = "Sorry, can't explain this."
|
99 |
+
|
100 |
+
gen_output[i] = result
|
101 |
+
|
102 |
+
else:
|
103 |
+
if headers[1] in gen_output[i]:
|
104 |
+
|
105 |
+
gen_output[i] = {"May contain misleading explanation":
|
106 |
+
dots_pattern.sub('' ,
|
107 |
+
gen_output[i].split(headers[1]
|
108 |
+
)[1].strip().replace('</s>', '')
|
109 |
+
)
|
110 |
+
}
|
111 |
+
else:
|
112 |
+
gen_output[i] = {"Sorry, can't explain this item": "NA"}
|
113 |
+
|
114 |
+
gen_output[i].pop('Item Name', None)
|
115 |
+
return gen_output
|
116 |
+
|
inference/translate.py
CHANGED
@@ -2,29 +2,50 @@ import streamlit as st
|
|
2 |
|
3 |
from inference.preprocess_image import (
|
4 |
image_to_np_arr,
|
5 |
-
process_extracted_text
|
|
|
6 |
)
|
7 |
|
8 |
-
from inference.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from typing import List, Tuple, Optional, AnyStr, Dict
|
10 |
-
from transformers import
|
11 |
import easyocr
|
12 |
import time
|
13 |
|
14 |
use_gpu = True
|
15 |
-
if DEVICE == 'cpu': use_gpu = False
|
16 |
|
17 |
@st.cache_resource
|
18 |
def load_models(item_summarizer: AnyStr) -> Tuple:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
text_extractor = easyocr.Reader(['en'],
|
20 |
gpu = use_gpu
|
21 |
)
|
22 |
-
|
23 |
-
model
|
|
|
|
|
24 |
|
25 |
return (text_extractor, tokenizer, model)
|
26 |
|
27 |
-
text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer =
|
28 |
|
29 |
|
30 |
# Define your extract_filter_img function
|
@@ -78,20 +99,24 @@ async def extract_filter_img(image) -> Dict:
|
|
78 |
|
79 |
def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
|
80 |
|
81 |
-
prompt_item =
|
82 |
-
|
83 |
-
|
84 |
-
"""
|
85 |
input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids
|
86 |
|
87 |
outputs = item_summarizer.generate(input_ids,
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
)
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
)
|
95 |
|
96 |
def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
|
97 |
return extrc_str
|
|
|
2 |
|
3 |
from inference.preprocess_image import (
|
4 |
image_to_np_arr,
|
5 |
+
process_extracted_text,
|
6 |
+
post_process_gen_outputs
|
7 |
)
|
8 |
|
9 |
+
from inference.config import (
|
10 |
+
model_inf_inp_prompt,
|
11 |
+
header_pattern,
|
12 |
+
dots_pattern,
|
13 |
+
DEVICE,
|
14 |
+
model_name
|
15 |
+
)
|
16 |
from typing import List, Tuple, Optional, AnyStr, Dict
|
17 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
18 |
import easyocr
|
19 |
import time
|
20 |
|
21 |
use_gpu = True
|
22 |
+
if DEVICE.type == 'cpu': use_gpu = False
|
23 |
|
24 |
@st.cache_resource
|
25 |
def load_models(item_summarizer: AnyStr) -> Tuple:
|
26 |
+
|
27 |
+
"""
|
28 |
+
Function to load the models required for the inference process. Cached to avoid loading the models, every time the function is called.
|
29 |
+
|
30 |
+
Parameters:
|
31 |
+
item_summarizer: str, required -> The LLM model name to be used for item summarization.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
Tuple -> Tuple containing the required models for the inference process.
|
35 |
+
"""
|
36 |
+
|
37 |
+
# model to extract text from image
|
38 |
text_extractor = easyocr.Reader(['en'],
|
39 |
gpu = use_gpu
|
40 |
)
|
41 |
+
|
42 |
+
# tokenizer and model to generate item summary
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(item_summarizer)
|
44 |
+
model = AutoModelForCausalLM.from_pretrained(item_summarizer)
|
45 |
|
46 |
return (text_extractor, tokenizer, model)
|
47 |
|
48 |
+
text_extractor,item_tokenizer,item_summarizer = load_models(item_summarizer = model_name)
|
49 |
|
50 |
|
51 |
# Define your extract_filter_img function
|
|
|
99 |
|
100 |
def transcribe_menu_model(menu_text: List[AnyStr]) -> Dict:
|
101 |
|
102 |
+
prompt_item = model_inf_inp_prompt.format(menu_text)
|
|
|
|
|
|
|
103 |
input_ids = item_tokenizer(prompt_item, return_tensors="pt").input_ids
|
104 |
|
105 |
outputs = item_summarizer.generate(input_ids,
|
106 |
+
max_new_tokens = 512,
|
107 |
+
num_beams = 4,
|
108 |
+
pad_token_id = item_tokenizer.pad_token_id,
|
109 |
+
eos_token_id = item_tokenizer.eos_token_id,
|
110 |
+
bos_token_id = item_tokenizer.bos_token_id
|
111 |
+
)
|
112 |
+
|
113 |
+
prediction = item_tokenizer.batch_decode(outputs,
|
114 |
+
skip_special_tokens=False
|
115 |
)
|
116 |
|
117 |
+
postpro_output = post_process_gen_outputs( prediction, header_pattern, dots_pattern )[0]
|
118 |
+
|
119 |
+
return postpro_output
|
|
|
120 |
|
121 |
def classify_menu_text(extrc_str: List[AnyStr]) -> List[AnyStr]:
|
122 |
return extrc_str
|
pages.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit import session_state as sst
|
3 |
+
|
4 |
+
|
5 |
+
from utils import navigate_to
|
6 |
+
from inference.config import DEBUG_MODE
|
7 |
+
|
8 |
+
from inference.translate import extract_filter_img, transcribe_menu_model,classify_menu_text
|
9 |
+
from inference.preprocess_image import preprocess_text
|
10 |
+
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
import pandas as pd
|
14 |
+
from PIL import Image
|
15 |
+
from typing import List
|
16 |
+
import json
|
17 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
18 |
+
|
19 |
+
# Setting workers to be 70% of all available virtual cpus in system
|
20 |
+
cpu_count = os.cpu_count()
|
21 |
+
pool = ThreadPoolExecutor(max_workers=int(cpu_count*0.7) )
|
22 |
+
|
23 |
+
# Function that handles logic of explaining menu items from manual input
|
24 |
+
async def manual_input_page():
|
25 |
+
|
26 |
+
"""
|
27 |
+
Function that takes text input from user in input box of streamlit, user can add multiple text boxes and submit finally.
|
28 |
+
|
29 |
+
Parameters:
|
30 |
+
None
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
List[str]: List of strings, containing item names of a menu in english.
|
34 |
+
"""
|
35 |
+
|
36 |
+
st.write("This is the Manual Input Page.")
|
37 |
+
st.write("Once done, click on 'Explain My Menu' button to get explanations for each item ... ")
|
38 |
+
|
39 |
+
inp_texts = []
|
40 |
+
num_text_boxes = st.number_input("Number of text boxes", min_value=1, step=1)
|
41 |
+
for i in range(num_text_boxes):
|
42 |
+
text_box = st.text_input(f"Food item {i+1}")
|
43 |
+
if text_box:
|
44 |
+
inp_texts.append(text_box)
|
45 |
+
|
46 |
+
if len(inp_texts) > 0:
|
47 |
+
|
48 |
+
# Show user submit button only if they have entered some text and set text in session state
|
49 |
+
sst["user_entered_items"] = inp_texts
|
50 |
+
st.button("Explain My Menu",on_click=navigate_to,args=("Inference",))
|
51 |
+
|
52 |
+
else:
|
53 |
+
st.write("Please enter some items to proceed ...")
|
54 |
+
|
55 |
+
|
56 |
+
st.button("Go back Home", on_click=navigate_to, args=("Home",))
|
57 |
+
|
58 |
+
|
59 |
+
# Function that handles logic of explaining menu items from image uploads
|
60 |
+
async def image_input_page():
|
61 |
+
"""
|
62 |
+
Function that contains content of main page i.e., image uploader and submit button to navigate to next page.
|
63 |
+
Upon submit , control goes to model inference 'page'.
|
64 |
+
|
65 |
+
Parameters:
|
66 |
+
None
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
None
|
70 |
+
"""
|
71 |
+
|
72 |
+
st.write("This is the Image Input Page.")
|
73 |
+
|
74 |
+
# Streamlit function to upload an image from any device
|
75 |
+
uploaded_file = st.file_uploader("Choose an image...",
|
76 |
+
type=["jpg", "jpeg", "png"])
|
77 |
+
|
78 |
+
# Remove preivous states' value of input image if it exists
|
79 |
+
sst.pop('input_image', None)
|
80 |
+
|
81 |
+
# Submit button
|
82 |
+
if uploaded_file is not None:
|
83 |
+
image = Image.open(uploaded_file)
|
84 |
+
|
85 |
+
# Only show if user wants to see
|
86 |
+
if st.checkbox('Show Uploaded Image'):
|
87 |
+
st.image(image,
|
88 |
+
caption='Uploaded Image',
|
89 |
+
use_column_width=True)
|
90 |
+
|
91 |
+
sst["input_image"] = image
|
92 |
+
|
93 |
+
# Show user submit button only if they have uploaded an image
|
94 |
+
st.button("Translate My Menu",
|
95 |
+
on_click = navigate_to,
|
96 |
+
args = ("Inference",))
|
97 |
+
|
98 |
+
|
99 |
+
# Warning message to user
|
100 |
+
st.info("""This application is for education purposes only. It uses AI, hence it's dietary
|
101 |
+
recommendations are not to be taken as medical advice, author doesn't bear responsibility
|
102 |
+
for incorrect dietary recommendations. Please proceed with caution.
|
103 |
+
""")
|
104 |
+
|
105 |
+
# if user wants to go back, make sure to reset the session state
|
106 |
+
st.button("Go back Home", on_click=navigate_to, args=("Home",))
|
107 |
+
|
108 |
+
|
109 |
+
# Function that handles model inference
|
110 |
+
async def model_inference_page():
|
111 |
+
|
112 |
+
"""
|
113 |
+
Function that pre-processes input text from state variables, does concurrent inference
|
114 |
+
and toggles state between pages if needed.
|
115 |
+
|
116 |
+
Parameters:
|
117 |
+
None
|
118 |
+
Returns:
|
119 |
+
None
|
120 |
+
|
121 |
+
"""
|
122 |
+
|
123 |
+
second_title = st.empty()
|
124 |
+
second_title.title(" Using ML to explain your menu items ... ")
|
125 |
+
|
126 |
+
# User can either upload an image or enter text manually, we check for both
|
127 |
+
if "input_image" in sst:
|
128 |
+
image = sst["input_image"]
|
129 |
+
|
130 |
+
msg1 = st.empty()
|
131 |
+
msg1.write("Pre-processing and extracting text out of your image ....")
|
132 |
+
# Call the extract_filter_img function
|
133 |
+
filtered_text = await extract_filter_img(image)
|
134 |
+
num_items_detected = len(filtered_text)
|
135 |
+
|
136 |
+
|
137 |
+
if "user_entered_items" in sst:
|
138 |
+
user_text = sst["user_entered_items"]
|
139 |
+
st.write("Pre-processing and filtering text from user input ....")
|
140 |
+
|
141 |
+
filtered_text = [preprocess_text(ut) for ut in user_text]
|
142 |
+
|
143 |
+
num_items_detected = len(filtered_text)
|
144 |
+
|
145 |
+
|
146 |
+
# irrespective of source of user entry , we check if we have any items to process
|
147 |
+
if num_items_detected == 0:
|
148 |
+
st.write("We couldn't detect any menu items ( indian for now ) from your image, please try a different image by going back.")
|
149 |
+
|
150 |
+
elif num_items_detected > 0:
|
151 |
+
st.write(f"Detected {num_items_detected} menu items from your input image ... ")
|
152 |
+
|
153 |
+
msg2 = st.empty()
|
154 |
+
msg2.write("All pre-processing done, transcribing your menu items now ....")
|
155 |
+
st_trans_llm = time.perf_counter()
|
156 |
+
|
157 |
+
await dist_llm_inference(filtered_text)
|
158 |
+
|
159 |
+
msg3 = st.empty()
|
160 |
+
msg3.write("Done transcribing ... ")
|
161 |
+
en_trans_llm = time.perf_counter()
|
162 |
+
|
163 |
+
msg2.empty(); msg3.empty()
|
164 |
+
st.success("Image processed successfully! " )
|
165 |
+
|
166 |
+
# Some basic stats for debug mode
|
167 |
+
if DEBUG_MODE:
|
168 |
+
llm_time_sec = en_trans_llm - st_trans_llm
|
169 |
+
st.write("Time took to summarize by LLM {}".format(llm_time_sec))
|
170 |
+
|
171 |
+
|
172 |
+
# If user clicked in "translate_another" button reset all session state variables and go back to home
|
173 |
+
st.button("Go back Home", on_click=navigate_to, args=("Home",))
|
174 |
+
|
175 |
+
|
176 |
+
# Function that performs LLM inference on a single item
|
177 |
+
async def dist_llm_inference(inp_texts: List[str]) -> None:
|
178 |
+
|
179 |
+
"""
|
180 |
+
Function that performs concurrent LLM inference using threadpool. It displays
|
181 |
+
results of those threads that are done with execution, as a dynamic row to streamlit table, rather than
|
182 |
+
waiting for all threads to be done.
|
183 |
+
|
184 |
+
Parameters:
|
185 |
+
inp_texts: List[str], required -> List of strings, containing item names of a menu in english.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
None
|
189 |
+
"""
|
190 |
+
|
191 |
+
df = pd.DataFrame([('ITEM NAME', 'EXPLANATION')]
|
192 |
+
)
|
193 |
+
|
194 |
+
sl_table = st.table(df)
|
195 |
+
tp_futures = { pool.submit(transcribe_menu_model, mi): mi for mi in inp_texts }
|
196 |
+
|
197 |
+
for tpftr in as_completed(tp_futures):
|
198 |
+
|
199 |
+
item = tp_futures[tpftr]
|
200 |
+
|
201 |
+
try:
|
202 |
+
exp = tpftr.result()
|
203 |
+
|
204 |
+
|
205 |
+
sl_table.add_rows([(item,
|
206 |
+
str(exp ))
|
207 |
+
]
|
208 |
+
)
|
209 |
+
|
210 |
+
except Exception as e:
|
211 |
+
print("Could not add a new row dynamically, because of this error:", e)
|
212 |
+
|
213 |
+
return
|
214 |
+
|
utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from streamlit import session_state as sst
|
3 |
+
def navigate_to(page: str) -> None:
|
4 |
+
"""
|
5 |
+
Function to set the current page in the state of streamlit. A helper for
|
6 |
+
simulating navigation in streamlit.
|
7 |
+
|
8 |
+
Parameters:
|
9 |
+
page: str, required.
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
None
|
13 |
+
"""
|
14 |
+
|
15 |
+
sst["page"] = page
|