dejanseo commited on
Commit
9c70ea5
·
verified ·
1 Parent(s): 714bd70

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import requests
6
+ from io import BytesIO
7
+ from selenium import webdriver
8
+ from selenium.webdriver.chrome.service import Service
9
+ from selenium.webdriver.common.by import By
10
+ from webdriver_manager.chrome import ChromeDriverManager
11
+ import time
12
+ import pandas as pd
13
+ import base64
14
+
15
+ def load_model(model_path):
16
+ interpreter = tf.lite.Interpreter(model_path=model_path)
17
+ interpreter.allocate_tensors()
18
+ return interpreter
19
+
20
+ def preprocess_image(image, input_size):
21
+ image = image.convert('RGB')
22
+ image = image.resize((input_size, input_size))
23
+ image_np = np.array(image, dtype=np.float32)
24
+ image_np = np.expand_dims(image_np, axis=0)
25
+ image_np = image_np / 255.0 # Normalize to [0, 1]
26
+ return image_np
27
+
28
+ def run_inference(interpreter, input_data):
29
+ input_details = interpreter.get_input_details()
30
+ output_details = interpreter.get_output_details()
31
+
32
+ interpreter.set_tensor(input_details[0]['index'], input_data)
33
+ interpreter.invoke()
34
+
35
+ output_data_shopping_intent = interpreter.get_tensor(output_details[0]['index'])
36
+ output_data_sensitive = interpreter.get_tensor(output_details[1]['index'])
37
+
38
+ return output_data_shopping_intent, output_data_sensitive
39
+
40
+ def fetch_images_from_url(url):
41
+ options = webdriver.ChromeOptions()
42
+ options.add_argument('--headless')
43
+ options.add_argument('--no-sandbox')
44
+ options.add_argument('--disable-dev-shm-usage')
45
+ options.add_argument('--disable-gpu')
46
+
47
+ service = Service(ChromeDriverManager().install())
48
+ driver = webdriver.Chrome(service=service, options=options)
49
+ driver.get(url)
50
+
51
+ # Give the page some time to load and execute JavaScript
52
+ time.sleep(10)
53
+
54
+ images = driver.find_elements(By.TAG_NAME, 'img')
55
+ img_urls = [img.get_attribute('src') for img in images if img.get_attribute('src')]
56
+
57
+ driver.quit()
58
+ return img_urls
59
+
60
+ def image_to_base64(image):
61
+ buffered = BytesIO()
62
+ image.save(buffered, format="PNG")
63
+ return base64.b64encode(buffered.getvalue()).decode()
64
+
65
+ def main():
66
+ st.set_page_config(layout="wide")
67
+ st.title("Image Classification with TFLite")
68
+ st.write("Enter a URL to fetch and classify all images on the page.")
69
+
70
+ model_path = "model.tflite"
71
+ url = st.text_input("Enter URL")
72
+
73
+ if url:
74
+ img_urls = fetch_images_from_url(url)
75
+ if img_urls:
76
+ st.write(f"Found {len(img_urls)} images")
77
+ interpreter = load_model(model_path)
78
+ input_details = interpreter.get_input_details()
79
+ input_shape = input_details[0]['shape']
80
+ input_size = input_shape[1] # assuming square input
81
+
82
+ data = []
83
+ errors = []
84
+
85
+ for img_url in img_urls:
86
+ try:
87
+ response = requests.get(img_url)
88
+ image = Image.open(BytesIO(response.content))
89
+
90
+ input_data = preprocess_image(image, input_size)
91
+ output_data_shopping_intent, output_data_sensitive = run_inference(interpreter, input_data)
92
+
93
+ # Convert image to Base64
94
+ image.thumbnail((100, 100))
95
+ thumbnail_base64 = image_to_base64(image)
96
+ thumbnail_data_url = f"data:image/png;base64,{thumbnail_base64}"
97
+
98
+ data.append({
99
+ 'Thumbnail': thumbnail_data_url,
100
+ 'URL': img_url,
101
+ 'Shopping Intent': output_data_shopping_intent.flatten().tolist(),
102
+ 'Sensitivity': output_data_sensitive.flatten().tolist()
103
+ })
104
+ except Exception as e:
105
+ errors.append(f"Could not process image {img_url}: {e}")
106
+
107
+ # Convert data to DataFrame
108
+ df = pd.DataFrame(data)
109
+
110
+ # Configure DataFrame display with images, URLs, and classifications
111
+ st.data_editor(df, column_config={
112
+ "Thumbnail": st.column_config.ImageColumn("Thumbnail", help="Image thumbnails"),
113
+ "URL": st.column_config.LinkColumn("URL"),
114
+ "Shopping Intent": st.column_config.BarChartColumn("Shopping Intent", width="small"),
115
+ "Sensitivity": st.column_config.BarChartColumn("Sensitivity", width="small")
116
+ })
117
+
118
+ # Display errors in an expandable section
119
+ if errors:
120
+ with st.expander(f"Could not process {len(errors)} images"):
121
+ for error in errors:
122
+ st.write(error)
123
+
124
+ if __name__ == "__main__":
125
+ main()