roberto ceraolo commited on
Commit
d166583
·
1 Parent(s): 48a578d
Files changed (3) hide show
  1. app.py +42 -34
  2. faiss_uploader.py +48 -0
  3. inference.py +379 -0
app.py CHANGED
@@ -1,48 +1,56 @@
1
  import gradio as gr
2
- import os
3
- import faiss
4
- import numpy as np
5
- from pathlib import Path
6
- import shutil
7
 
8
- # Constants
9
- DATA_DIR = "data"
10
- INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.index")
11
-
12
- def save_index(file_obj):
13
  """
14
- Save uploaded FAISS index to the data directory
15
  """
16
- # Create data directory if it doesn't exist
17
- Path(DATA_DIR).mkdir(exist_ok=True)
18
-
19
- # Check if index already exists
20
- if os.path.exists(INDEX_PATH):
21
- return "⚠️ A FAISS index already exists in the data directory. Please remove it first."
22
-
23
  try:
24
- # Copy the temporary file to our target location
25
- shutil.copy2(file_obj.name, INDEX_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Verify the saved file is a valid FAISS index
28
- faiss.read_index(INDEX_PATH)
29
- return "✅ FAISS index successfully uploaded and saved!"
30
 
31
  except Exception as e:
32
- # If there was an error, remove the file if it was created
33
- if os.path.exists(INDEX_PATH):
34
- os.remove(INDEX_PATH)
35
- return f"❌ Error: Invalid FAISS index file - {str(e)}"
36
 
37
  # Create Gradio interface
38
- demo = gr.Interface(
39
- fn=save_index,
40
- inputs=gr.File(label="Upload FAISS Index", file_types=[".index"]),
41
- outputs=gr.Textbox(label="Status"),
42
- title="FAISS Index Uploader",
43
- description="Upload a FAISS index file to store in the HuggingFace Space data directory.",
 
 
 
 
 
 
 
44
  allow_flagging="never"
45
  )
46
 
47
  if __name__ == "__main__":
48
- demo.launch()
 
1
  import gradio as gr
2
+ import sys
3
+ import logging
4
+ from inference import main, load_index, load_metadata
5
+ from PIL import Image
 
6
 
7
+ def run_pipeline(prompt, image):
 
 
 
 
8
  """
9
+ Gradio interface function to run the inference pipeline
10
  """
 
 
 
 
 
 
 
11
  try:
12
+ logging.info("Loading required data...")
13
+ index = load_index()
14
+ metadata_df = load_metadata()
15
+
16
+ logging.info("Starting inference pipeline...")
17
+ results = main(prompt, image, index, metadata_df)
18
+
19
+ # Return the generated image and similar products
20
+ image_path = results['generated_image_path']
21
+ similar_products = results['similar_products']
22
+
23
+ image_output = Image.open(image_path)
24
+
25
+ # Format product URLs as a numbered list with similarity scores
26
+ product_urls = []
27
+ for i, product in enumerate(similar_products, 1):
28
+ similarity = 1 / (1 + product['distance'])
29
+ product_urls.append(f"{i}. Similarity: {similarity:.2f}\nProduct: {product['product_url']}\n")
30
 
31
+ formatted_urls = "\n".join(product_urls)
32
+ return image_output, formatted_urls
 
33
 
34
  except Exception as e:
35
+ logging.error(f"Error in pipeline: {str(e)}")
36
+ return None, None
 
 
37
 
38
  # Create Gradio interface
39
+ iface = gr.Interface(
40
+ fn=run_pipeline,
41
+ inputs=[
42
+ gr.Textbox(label="Enter your prompt", placeholder="e.g., modern living room with minimalist furniture"),
43
+ gr.Image(label="Upload control image", type="filepath")
44
+ ],
45
+ outputs=[
46
+ gr.Image(label="Generated Image"),
47
+ gr.Textbox(label="Similar IKEA Products", lines=15)
48
+ ],
49
+ title="Interior Design Image Generator",
50
+ description="Upload an image and provide a prompt to generate interior design variations and find similar IKEA products.",
51
+ theme="default",
52
  allow_flagging="never"
53
  )
54
 
55
  if __name__ == "__main__":
56
+ iface.launch(share=True)
faiss_uploader.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import faiss
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import shutil
7
+
8
+ # Constants
9
+ DATA_DIR = "data"
10
+ INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.index")
11
+
12
+ def save_index(file_obj):
13
+ """
14
+ Save uploaded FAISS index to the data directory
15
+ """
16
+ # Create data directory if it doesn't exist
17
+ Path(DATA_DIR).mkdir(exist_ok=True)
18
+
19
+ # Check if index already exists
20
+ if os.path.exists(INDEX_PATH):
21
+ return "⚠️ A FAISS index already exists in the data directory. Please remove it first."
22
+
23
+ try:
24
+ # Copy the temporary file to our target location
25
+ shutil.copy2(file_obj.name, INDEX_PATH)
26
+
27
+ # Verify the saved file is a valid FAISS index
28
+ faiss.read_index(INDEX_PATH)
29
+ return "✅ FAISS index successfully uploaded and saved!"
30
+
31
+ except Exception as e:
32
+ # If there was an error, remove the file if it was created
33
+ if os.path.exists(INDEX_PATH):
34
+ os.remove(INDEX_PATH)
35
+ return f"❌ Error: Invalid FAISS index file - {str(e)}"
36
+
37
+ # Create Gradio interface
38
+ demo = gr.Interface(
39
+ fn=save_index,
40
+ inputs=gr.File(label="Upload FAISS Index", file_types=[".index"]),
41
+ outputs=gr.Textbox(label="Status"),
42
+ title="FAISS Index Uploader",
43
+ description="Upload a FAISS index file to store in the HuggingFace Space data directory.",
44
+ allow_flagging="never"
45
+ )
46
+
47
+ if __name__ == "__main__":
48
+ demo.launch()
inference.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference steps - to be run whenever a new image is uploaded
3
+ input: image and textual prompt
4
+ steps:
5
+ 1. generate an image (or more than one) with stable diffusion
6
+ 2. GPT-4o - detect main pieces of furniture
7
+ 3. perform object detection on the image looking for the main pieces of furniture
8
+ 4. generate embeddings for the image and the subimages
9
+ 5. perform a similarity search on the index of ikea products
10
+ 6. return the results: generated image, main pieces of furniture, similar ikea products
11
+ """
12
+
13
+ import logging
14
+ from datetime import datetime
15
+
16
+ # Set up logging to both file and console
17
+ log_filename = f"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s - %(levelname)s - %(message)s',
21
+ handlers=[
22
+ logging.FileHandler(log_filename),
23
+ logging.StreamHandler()
24
+ ]
25
+ )
26
+
27
+ import replicate
28
+ from pydantic import BaseModel, Field
29
+ from openai import OpenAI
30
+ import base64
31
+ import numpy as np
32
+ import os
33
+ import requests
34
+ from PIL import Image
35
+ from io import BytesIO
36
+ import sys
37
+ import dotenv
38
+ import pandas as pd
39
+ import faiss
40
+
41
+ logging.info("Loading environment variables...")
42
+ dotenv.load_dotenv()
43
+ client = OpenAI()
44
+ # step 1
45
+
46
+ class Prompt(BaseModel):
47
+ prompt: str = Field(description="A detailed prompt for a diffusion model")
48
+
49
+ def generate_prompt_for_flux(user_prompt):
50
+ completion = client.beta.chat.completions.parse(
51
+ model="gpt-4o-mini",
52
+ messages=[
53
+ {
54
+ "role": "user",
55
+ "content": [
56
+ {"type": "text", "text": f"Generate a prompt for a diffusion model that is a more detailed version of the following prompt: {user_prompt}. Keep it succinct but more descriptive than the original. Just return 1 sentence, do not include specific elements, since you do not know the available space in the room. "},
57
+ ],
58
+ }
59
+ ],
60
+ response_format=Prompt,
61
+ )
62
+
63
+ analysis = completion.choices[0].message.parsed
64
+ return analysis.prompt
65
+
66
+ def search_similar_products(image_path, index, metadata_df, top_k=5):
67
+ """
68
+ Search for similar products given a local image
69
+
70
+ Args:
71
+ image_path (str): Path to the query image
72
+ index: FAISS index
73
+ metadata_df: DataFrame containing product metadata
74
+ top_k (int): Number of similar items to return
75
+
76
+ Returns:
77
+ list: List of dictionaries containing similar product information
78
+ """
79
+ logging.info(f"\nGenerating embedding for image: {image_path}")
80
+ # Generate embedding for the query image
81
+ output = replicate.run(
82
+ "krthr/clip-embeddings:1c0371070cb827ec3c7f2f28adcdde54b50dcd239aa6faea0bc98b174ef03fb4",
83
+ input={"image": image_path}
84
+ )
85
+ if 'embedding' in output:
86
+ query_embedding = np.array(output['embedding'])
87
+ else:
88
+ query_embedding = np.array(output).astype('float32').reshape(1, -1)
89
+
90
+ logging.info("Searching FAISS index...")
91
+ # Search the index
92
+ distances, indices = index.search(np.array([query_embedding]), top_k)
93
+
94
+ logging.info("Processing search results...")
95
+ # Get the metadata for the similar products
96
+ results = []
97
+ for idx, distance in zip(indices[0], distances[0]):
98
+ result = {
99
+ 'product_url': metadata_df.iloc[idx]['product_url'],
100
+ 'image_url': metadata_df.iloc[idx]['image_url'],
101
+ 'distance': float(distance)
102
+ }
103
+ results.append(result)
104
+
105
+ return results
106
+
107
+
108
+ def find_similar_ikea_products(image_input, index, metadata_df, top_k=5, process_detections_list=None):
109
+ """
110
+ Convenience function to find similar IKEA products
111
+
112
+ Args:
113
+ image_input (str): Path to local image file or URL of image
114
+ top_k (int): Number of similar items to return
115
+ process_detections_list (list, optional): List of detection dictionaries containing bbox and label
116
+ """
117
+ logging.info("\nProcessing full image...")
118
+ # First process the full image
119
+ logging.info("\nProcessing full image:")
120
+ # Handle both local files and URLs
121
+ if image_input.startswith(('http://', 'https://')):
122
+ logging.info(f"Processing URL image: {image_input}")
123
+ image_path = image_input
124
+ else:
125
+ logging.info(f"Processing local image: {image_input}")
126
+ if not os.path.exists(image_input):
127
+ raise FileNotFoundError(f"Local image file not found: {image_input}")
128
+ image_path = open(image_input, "rb")
129
+
130
+ results = search_similar_products(image_path, index, metadata_df, top_k)
131
+
132
+ logging.info(f"\nTop {top_k} similar products (overall image):")
133
+ for i, result in enumerate(results, 1):
134
+ logging.info(f"\n{i}. Similarity score: {1 / (1 + result['distance']):.3f}")
135
+ logging.info(f"Product URL: {result['product_url']}")
136
+ logging.info(f"Image URL: {result['image_url']}")
137
+
138
+ # If detections are provided, process sub-images
139
+ if process_detections_list:
140
+ logging.info("\nProcessing object detections...")
141
+ # Load the image
142
+ if isinstance(image_input, str) and image_input.startswith(('http://', 'https://')):
143
+ image_path = image_input
144
+ else:
145
+ # local image processing -
146
+ image = Image.open(image_input)
147
+
148
+ # Process each detection
149
+ detections = process_detections_list['detections']
150
+ for i, detection in enumerate(detections):
151
+ logging.info(f"\nProcessing detection {i+1}: {detection['label']}")
152
+ logging.info(f"Confidence: {detection['confidence']:.3f}")
153
+
154
+ # Extract bounding box coordinates
155
+ x1, y1, x2, y2 = detection['bbox']
156
+
157
+ logging.info(f"Cropping image to bbox: ({x1}, {y1}, {x2}, {y2})")
158
+ # Crop the image to the bounding box
159
+ cropped = image.crop((x1, y1, x2, y2))
160
+
161
+ # Save the cropped image temporarily
162
+ temp_path = f"temp_crop_{i}.jpg"
163
+ cropped.save(temp_path)
164
+ logging.info(f"Saved temporary crop to: {temp_path}")
165
+
166
+ try:
167
+ # Find similar products for this crop
168
+ logging.info(f"\nFinding similar products for {detection['label']}:")
169
+ sub_results = search_similar_products(temp_path, index, metadata_df, top_k)
170
+
171
+ logging.info(f"\nTop {top_k} similar products for {detection['label']}:")
172
+ for j, result in enumerate(sub_results, 1):
173
+ logging.info(f"\n{j}. Similarity score: {1 / (1 + result['distance']):.3f}")
174
+ logging.info(f"Product URL: {result['product_url']}")
175
+ logging.info(f"Image URL: {result['image_url']}")
176
+
177
+ except Exception as e:
178
+ logging.error(f"Error processing detection {i+1}: {e}")
179
+
180
+ finally:
181
+ # Clean up temporary file
182
+ if os.path.exists(temp_path):
183
+ logging.info(f"Cleaning up temporary file: {temp_path}")
184
+ os.remove(temp_path)
185
+
186
+ return results
187
+
188
+
189
+ def generate_image(prompt, control_image, guidance_scale, output_quality, negative_prompt, control_strength):
190
+ logging.info("\nGenerating image with Stable Diffusion...")
191
+ logging.info(f"Prompt: {prompt}")
192
+ guidance_scale = 2.5
193
+ output_quality = 100
194
+ negative_prompt = "low quality, ugly, distorted, artefacts"
195
+ control_strength = 0.45
196
+ # Modify the input handling to work with file paths
197
+ if isinstance(control_image, str):
198
+ image = open(control_image, "rb")
199
+ else:
200
+ image = control_image
201
+
202
+ input = {
203
+ "prompt": prompt,
204
+ "control_image": image,
205
+ "guidance_scale": guidance_scale,
206
+ "output_quality": output_quality,
207
+ "negative_prompt": negative_prompt,
208
+ "control_strength": control_strength
209
+ }
210
+
211
+ logging.info("Running image generation model...")
212
+ output = replicate.run(
213
+ "xlabs-ai/flux-dev-controlnet:9a8db105db745f8b11ad3afe5c8bd892428b2a43ade0b67edc4e0ccd52ff2fda",
214
+ input=input
215
+ )
216
+ logging.info("Saving generated images...")
217
+ for index, item in enumerate(output):
218
+ with open(f"output_{index}.jpg", "wb") as file:
219
+ file.write(item.read())
220
+ logging.info(f"Saved output_{index}.jpg")
221
+ return output
222
+
223
+ # step 2
224
+
225
+
226
+ def analyze_image(image_path):
227
+ logging.info(f"\nAnalyzing image with GPT-4V: {image_path}")
228
+
229
+ class ImageAnalysis(BaseModel):
230
+ objects: list[str] = Field(description="A list of objects in the image")
231
+
232
+ # Function to encode the image
233
+ def encode_image(image_path):
234
+ logging.info("Encoding image to base64...")
235
+ with open(image_path, "rb") as image_file:
236
+ return base64.b64encode(image_file.read()).decode('utf-8')
237
+
238
+ # Encode the image
239
+ encoded_image = encode_image(image_path)
240
+
241
+ logging.info("Sending request to GPT-4o-mini vision...")
242
+ completion = client.beta.chat.completions.parse(
243
+ model="gpt-4o-mini",
244
+ messages=[
245
+ {
246
+ "role": "user",
247
+ "content": [
248
+ {"type": "text", "text": "Analyze this image and list the main objects of furniture in the image."},
249
+ {
250
+ "type": "image_url",
251
+ "image_url": {
252
+ "url": f"data:image/jpeg;base64,{encoded_image}"
253
+ },
254
+ },
255
+ ],
256
+ }
257
+ ],
258
+ response_format=ImageAnalysis,
259
+ )
260
+
261
+ analysis = completion.choices[0].message.parsed
262
+ main_objects = ', '.join(analysis.objects)
263
+ logging.info(f"""
264
+ Objects: {', '.join(analysis.objects)}
265
+ """)
266
+ return main_objects
267
+
268
+ # step 3
269
+
270
+ def detect_objects(image_path, main_objects):
271
+ logging.info(f"\nDetecting objects in image: {image_path}")
272
+ image = open(image_path, "rb")
273
+ input = {
274
+ "image": image,
275
+ "query": main_objects,
276
+ "box_threshold": 0.2,
277
+ "text_threshold": 0.2
278
+ }
279
+
280
+ logging.info("Running object detection model...")
281
+ output = replicate.run(
282
+ "adirik/grounding-dino:efd10a8ddc57ea28773327e881ce95e20cc1d734c589f7dd01d2036921ed78aa",
283
+ input=input
284
+ )
285
+ logging.info("Detection results:")
286
+ logging.info(output)
287
+ return output
288
+
289
+ # step 4, 5
290
+ def search_index(image_path, index, metadata_df, main_objects = None, top_k=5):
291
+ logging.info(f"\nSearching index for similar products to: {image_path}")
292
+ #process_detections_list = detect_objects(image_path, main_objects)
293
+ results = find_similar_ikea_products(image_path, index, metadata_df, top_k=5)
294
+ return results
295
+
296
+
297
+
298
+
299
+ def main(prompt, control_image, index, metadata_df):
300
+ """
301
+ Main function to orchestrate the entire inference pipeline
302
+
303
+ Args:
304
+ prompt (str): Text prompt for image generation
305
+ control_image: Input image for controlled generation
306
+ index: FAISS index for similarity search
307
+ metadata_df: DataFrame containing product metadata
308
+
309
+ Returns:
310
+ dict: Results containing generated images, detected objects, and similar products
311
+ """
312
+
313
+
314
+ logging.info("\nStarting inference pipeline...")
315
+ results = {}
316
+
317
+ logging.info("\nStep 0: Generating a detailed prompt for the diffusion model...")
318
+ # Step 0: Generate a detailed prompt for the diffusion model
319
+ prompt = generate_prompt_for_flux(prompt)
320
+ logging.info(f"\nGenerated prompt: {prompt}")
321
+
322
+ logging.info("\nStep 1: Generating image...")
323
+ # Step 1: Generate image
324
+ generated_images = generate_image(
325
+ prompt=prompt,
326
+ control_image=control_image,
327
+ guidance_scale=2.5,
328
+ output_quality=100,
329
+ negative_prompt="low quality, ugly, distorted, artefacts",
330
+ control_strength=0.95
331
+ )
332
+ results['generated_images'] = generated_images
333
+ results['generated_image_path'] = "output_0.jpg"
334
+
335
+ # logging.info("\nStep 2: Analyzing generated image...")
336
+ # # Step 2: Analyze generated image with GPT-4V
337
+ obj_detection = False
338
+ if obj_detection:
339
+ main_objects = analyze_image("output_0.jpg") # Using the first generated image
340
+ results['detected_furniture'] = main_objects
341
+
342
+ logging.info("\nSteps 3-5: Detecting objects and searching for similar products...")
343
+ # Step 3 & 4 & 5: Detect objects and search for similar products
344
+ similar_products = search_index(
345
+ image_path="output_0.jpg",
346
+ index=index,
347
+ metadata_df=metadata_df,
348
+ top_k=5
349
+ )
350
+ results['similar_products'] = similar_products
351
+
352
+ return results
353
+
354
+ def load_index():
355
+ logging.info("\nLoading FAISS index...")
356
+ return faiss.read_index("data/ikea_faiss.index")
357
+
358
+ def load_metadata():
359
+ logging.info("Loading metadata...")
360
+ return pd.read_csv("data/filtered_metadata.csv")
361
+
362
+ if __name__ == "__main__":
363
+ # Example usage
364
+ logging.info("\nStarting program...")
365
+
366
+ if len(sys.argv) != 3:
367
+ logging.error("Usage: python inference.py <prompt> <link to control image>")
368
+ sys.exit(1)
369
+
370
+ prompt = sys.argv[1]
371
+ control_image = sys.argv[2]
372
+
373
+ logging.info("\nLoading required data...")
374
+ # Load your FAISS index and metadata_df here
375
+ index = load_index()
376
+ metadata_df = load_metadata()
377
+
378
+ results = main(prompt, control_image, index, metadata_df)
379
+ logging.info("\nPipeline completed successfully!")