roberto ceraolo committed
Commit d166583 · 1 Parent(s): 48a578d

app

Browse files:
- app.py +42 -34
- faiss_uploader.py +48 -0
- inference.py +379 -0
app.py CHANGED
@@ -1,48 +1,56 @@
 import gradio as gr
-import os
-import faiss
-import numpy as np
-from pathlib import Path
-import shutil
+import sys
+import logging
+from inference import main, load_index, load_metadata
+from PIL import Image

-# Constants
-DATA_DIR = "data"
-INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.index")
-
-def save_index(file_obj):
+def run_pipeline(prompt, image):
     """
-    Save uploaded FAISS index to the data directory
+    Gradio interface function to run the inference pipeline
     """
-    # Create data directory if it doesn't exist
-    Path(DATA_DIR).mkdir(exist_ok=True)
-
-    # Check if index already exists
-    if os.path.exists(INDEX_PATH):
-        return "⚠️ A FAISS index already exists in the data directory. Please remove it first."
-
     try:
-        # Copy the temporary file to our target location
-        shutil.copy2(file_obj.name, INDEX_PATH)
+        logging.info("Loading required data...")
+        index = load_index()
+        metadata_df = load_metadata()
+
+        logging.info("Starting inference pipeline...")
+        results = main(prompt, image, index, metadata_df)
+
+        # Return the generated image and similar products
+        image_path = results['generated_image_path']
+        similar_products = results['similar_products']
+
+        image_output = Image.open(image_path)
+
+        # Format product URLs as a numbered list with similarity scores
+        product_urls = []
+        for i, product in enumerate(similar_products, 1):
+            similarity = 1 / (1 + product['distance'])
+            product_urls.append(f"{i}. Similarity: {similarity:.2f}\nProduct: {product['product_url']}\n")

-        # Verify the saved file is a valid FAISS index
-        faiss.read_index(INDEX_PATH)
-        return "✅ FAISS index successfully uploaded and saved!"
+        formatted_urls = "\n".join(product_urls)
+        return image_output, formatted_urls

     except Exception as e:
-        # If there was an error, remove the file if it was created
-        if os.path.exists(INDEX_PATH):
-            os.remove(INDEX_PATH)
-        return f"❌ Error: Invalid FAISS index file - {str(e)}"
+        logging.error(f"Error in pipeline: {str(e)}")
+        return None, None

 # Create Gradio interface
-demo = gr.Interface(
-    fn=save_index,
-    inputs=gr.File(label="Upload FAISS Index", file_types=[".index"]),
-    outputs=gr.Textbox(label="Status"),
-    title="FAISS Index Uploader",
-    description="Upload a FAISS index file to store in the HuggingFace Space data directory.",
+iface = gr.Interface(
+    fn=run_pipeline,
+    inputs=[
+        gr.Textbox(label="Enter your prompt", placeholder="e.g., modern living room with minimalist furniture"),
+        gr.Image(label="Upload control image", type="filepath")
+    ],
+    outputs=[
+        gr.Image(label="Generated Image"),
+        gr.Textbox(label="Similar IKEA Products", lines=15)
+    ],
+    title="Interior Design Image Generator",
+    description="Upload an image and provide a prompt to generate interior design variations and find similar IKEA products.",
+    theme="default",
     allow_flagging="never"
 )

 if __name__ == "__main__":
-    demo.launch()
+    iface.launch(share=True)
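A note on the score surfaced in the UI: run_pipeline maps each raw FAISS distance to a bounded score with similarity = 1 / (1 + distance), so a distance of 0 yields 1.0 and the score decays toward 0 as the distance grows. A minimal sketch of that mapping (the distances below are made-up values for illustration):

import numpy as np

# Hypothetical distances, as would be returned by index.search()
distances = np.array([0.0, 0.5, 1.0, 4.0])
similarities = 1 / (1 + distances)
print(similarities)  # -> approximately [1.0, 0.667, 0.5, 0.2]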
faiss_uploader.py ADDED
@@ -0,0 +1,48 @@
import gradio as gr
import os
import faiss
import numpy as np
from pathlib import Path
import shutil

# Constants
DATA_DIR = "data"
INDEX_PATH = os.path.join(DATA_DIR, "faiss_index.index")

def save_index(file_obj):
    """
    Save uploaded FAISS index to the data directory
    """
    # Create data directory if it doesn't exist
    Path(DATA_DIR).mkdir(exist_ok=True)

    # Check if index already exists
    if os.path.exists(INDEX_PATH):
        return "⚠️ A FAISS index already exists in the data directory. Please remove it first."

    try:
        # Copy the temporary file to our target location
        shutil.copy2(file_obj.name, INDEX_PATH)

        # Verify the saved file is a valid FAISS index
        faiss.read_index(INDEX_PATH)
        return "✅ FAISS index successfully uploaded and saved!"

    except Exception as e:
        # If there was an error, remove the file if it was created
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        return f"❌ Error: Invalid FAISS index file - {str(e)}"

# Create Gradio interface
demo = gr.Interface(
    fn=save_index,
    inputs=gr.File(label="Upload FAISS Index", file_types=[".index"]),
    outputs=gr.Textbox(label="Status"),
    title="FAISS Index Uploader",
    description="Upload a FAISS index file to store in the HuggingFace Space data directory.",
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
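For reference, a minimal sketch of producing a .index file that save_index would accept — assuming a flat L2 index over float32 embeddings (the 768 dimension below is an assumption; it must match whatever embedding model populates the index):

import faiss
import numpy as np

# Hypothetical embeddings: one row per catalog image, float32 as FAISS requires
embeddings = np.random.rand(1000, 768).astype('float32')

index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 search
index.add(embeddings)
faiss.write_index(index, "faiss_index.index")  # upload this file via the Gradio UI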
inference.py ADDED
@@ -0,0 +1,379 @@
"""
inference steps - to be run whenever a new image is uploaded
input: image and textual prompt
steps:
1. generate an image (or more than one) with stable diffusion
2. GPT-4o - detect main pieces of furniture
3. perform object detection on the image looking for the main pieces of furniture
4. generate embeddings for the image and the subimages
5. perform a similarity search on the index of ikea products
6. return the results: generated image, main pieces of furniture, similar ikea products
"""

import logging
from datetime import datetime

# Set up logging to both file and console
log_filename = f"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_filename),
        logging.StreamHandler()
    ]
)

import replicate
from pydantic import BaseModel, Field
from openai import OpenAI
import base64
import numpy as np
import os
import requests
from PIL import Image
from io import BytesIO
import sys
import dotenv
import pandas as pd
import faiss

logging.info("Loading environment variables...")
dotenv.load_dotenv()
client = OpenAI()

# step 1

class Prompt(BaseModel):
    prompt: str = Field(description="A detailed prompt for a diffusion model")

def generate_prompt_for_flux(user_prompt):
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f"Generate a prompt for a diffusion model that is a more detailed version of the following prompt: {user_prompt}. Keep it succinct but more descriptive than the original. Just return 1 sentence, do not include specific elements, since you do not know the available space in the room."},
                ],
            }
        ],
        response_format=Prompt,
    )

    analysis = completion.choices[0].message.parsed
    return analysis.prompt

def search_similar_products(image_path, index, metadata_df, top_k=5):
    """
    Search for similar products given a query image

    Args:
        image_path: Path, URL, or open file object for the query image
        index: FAISS index
        metadata_df: DataFrame containing product metadata
        top_k (int): Number of similar items to return

    Returns:
        list: List of dictionaries containing similar product information
    """
    logging.info(f"\nGenerating embedding for image: {image_path}")
    # Generate embedding for the query image
    output = replicate.run(
        "krthr/clip-embeddings:1c0371070cb827ec3c7f2f28adcdde54b50dcd239aa6faea0bc98b174ef03fb4",
        input={"image": image_path}
    )
    if 'embedding' in output:
        query_embedding = np.array(output['embedding'])
    else:
        query_embedding = np.array(output)
    # FAISS expects a float32 matrix of shape (n_queries, dim)
    query_embedding = query_embedding.astype('float32').reshape(1, -1)

    logging.info("Searching FAISS index...")
    # Search the index
    distances, indices = index.search(query_embedding, top_k)

    logging.info("Processing search results...")
    # Get the metadata for the similar products
    results = []
    for idx, distance in zip(indices[0], distances[0]):
        result = {
            'product_url': metadata_df.iloc[idx]['product_url'],
            'image_url': metadata_df.iloc[idx]['image_url'],
            'distance': float(distance)
        }
        results.append(result)

    return results


def find_similar_ikea_products(image_input, index, metadata_df, top_k=5, process_detections_list=None):
    """
    Convenience function to find similar IKEA products

    Args:
        image_input (str): Path to local image file or URL of image
        top_k (int): Number of similar items to return
        process_detections_list (dict, optional): Detection results with bbox and label entries
    """
    # First process the full image
    logging.info("\nProcessing full image...")
    # Handle both local files and URLs
    if image_input.startswith(('http://', 'https://')):
        logging.info(f"Processing URL image: {image_input}")
        image_path = image_input
    else:
        logging.info(f"Processing local image: {image_input}")
        if not os.path.exists(image_input):
            raise FileNotFoundError(f"Local image file not found: {image_input}")
        image_path = open(image_input, "rb")

    results = search_similar_products(image_path, index, metadata_df, top_k)

    logging.info(f"\nTop {top_k} similar products (overall image):")
    for i, result in enumerate(results, 1):
        logging.info(f"\n{i}. Similarity score: {1 / (1 + result['distance']):.3f}")
        logging.info(f"Product URL: {result['product_url']}")
        logging.info(f"Image URL: {result['image_url']}")

    # If detections are provided, process sub-images
    if process_detections_list:
        logging.info("\nProcessing object detections...")
        # Load the image (downloading it first if the input is a URL)
        if isinstance(image_input, str) and image_input.startswith(('http://', 'https://')):
            image = Image.open(BytesIO(requests.get(image_input).content))
        else:
            image = Image.open(image_input)

        # Process each detection
        detections = process_detections_list['detections']
        for i, detection in enumerate(detections):
            logging.info(f"\nProcessing detection {i+1}: {detection['label']}")
            logging.info(f"Confidence: {detection['confidence']:.3f}")

            # Extract bounding box coordinates
            x1, y1, x2, y2 = detection['bbox']

            logging.info(f"Cropping image to bbox: ({x1}, {y1}, {x2}, {y2})")
            # Crop the image to the bounding box
            cropped = image.crop((x1, y1, x2, y2))

            # Save the cropped image temporarily
            temp_path = f"temp_crop_{i}.jpg"
            cropped.save(temp_path)
            logging.info(f"Saved temporary crop to: {temp_path}")

            try:
                # Find similar products for this crop
                logging.info(f"\nFinding similar products for {detection['label']}:")
                sub_results = search_similar_products(open(temp_path, "rb"), index, metadata_df, top_k)

                logging.info(f"\nTop {top_k} similar products for {detection['label']}:")
                for j, result in enumerate(sub_results, 1):
                    logging.info(f"\n{j}. Similarity score: {1 / (1 + result['distance']):.3f}")
                    logging.info(f"Product URL: {result['product_url']}")
                    logging.info(f"Image URL: {result['image_url']}")

            except Exception as e:
                logging.error(f"Error processing detection {i+1}: {e}")

            finally:
                # Clean up temporary file
                if os.path.exists(temp_path):
                    logging.info(f"Cleaning up temporary file: {temp_path}")
                    os.remove(temp_path)

    return results


def generate_image(prompt, control_image, guidance_scale, output_quality, negative_prompt, control_strength):
    logging.info("\nGenerating image with FLUX ControlNet...")
    logging.info(f"Prompt: {prompt}")
    # NOTE: the arguments above are currently overridden by these fixed settings
    guidance_scale = 2.5
    output_quality = 100
    negative_prompt = "low quality, ugly, distorted, artefacts"
    control_strength = 0.45
    # Accept either a file path or an already-open file object
    if isinstance(control_image, str):
        image = open(control_image, "rb")
    else:
        image = control_image

    input = {
        "prompt": prompt,
        "control_image": image,
        "guidance_scale": guidance_scale,
        "output_quality": output_quality,
        "negative_prompt": negative_prompt,
        "control_strength": control_strength
    }

    logging.info("Running image generation model...")
    output = replicate.run(
        "xlabs-ai/flux-dev-controlnet:9a8db105db745f8b11ad3afe5c8bd892428b2a43ade0b67edc4e0ccd52ff2fda",
        input=input
    )
    logging.info("Saving generated images...")
    for index, item in enumerate(output):
        with open(f"output_{index}.jpg", "wb") as file:
            file.write(item.read())
        logging.info(f"Saved output_{index}.jpg")
    return output

# step 2

def analyze_image(image_path):
    logging.info(f"\nAnalyzing image with GPT-4o-mini vision: {image_path}")

    class ImageAnalysis(BaseModel):
        objects: list[str] = Field(description="A list of objects in the image")

    # Function to encode the image
    def encode_image(image_path):
        logging.info("Encoding image to base64...")
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    # Encode the image
    encoded_image = encode_image(image_path)

    logging.info("Sending request to GPT-4o-mini vision...")
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Analyze this image and list the main objects of furniture in the image."},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{encoded_image}"
                        },
                    },
                ],
            }
        ],
        response_format=ImageAnalysis,
    )

    analysis = completion.choices[0].message.parsed
    main_objects = ', '.join(analysis.objects)
    logging.info(f"Objects: {main_objects}")
    return main_objects

# step 3

def detect_objects(image_path, main_objects):
    logging.info(f"\nDetecting objects in image: {image_path}")
    image = open(image_path, "rb")
    input = {
        "image": image,
        "query": main_objects,
        "box_threshold": 0.2,
        "text_threshold": 0.2
    }

    logging.info("Running object detection model...")
    output = replicate.run(
        "adirik/grounding-dino:efd10a8ddc57ea28773327e881ce95e20cc1d734c589f7dd01d2036921ed78aa",
        input=input
    )
    logging.info("Detection results:")
    logging.info(output)
    return output

# step 4, 5
def search_index(image_path, index, metadata_df, main_objects=None, top_k=5):
    logging.info(f"\nSearching index for similar products to: {image_path}")
    # process_detections_list = detect_objects(image_path, main_objects)
    results = find_similar_ikea_products(image_path, index, metadata_df, top_k=top_k)
    return results


def main(prompt, control_image, index, metadata_df):
    """
    Main function to orchestrate the entire inference pipeline

    Args:
        prompt (str): Text prompt for image generation
        control_image: Input image for controlled generation
        index: FAISS index for similarity search
        metadata_df: DataFrame containing product metadata

    Returns:
        dict: Results containing generated images, detected objects, and similar products
    """
    logging.info("\nStarting inference pipeline...")
    results = {}

    # Step 0: Generate a detailed prompt for the diffusion model
    logging.info("\nStep 0: Generating a detailed prompt for the diffusion model...")
    prompt = generate_prompt_for_flux(prompt)
    logging.info(f"\nGenerated prompt: {prompt}")

    # Step 1: Generate image
    logging.info("\nStep 1: Generating image...")
    generated_images = generate_image(
        prompt=prompt,
        control_image=control_image,
        guidance_scale=2.5,
        output_quality=100,
        negative_prompt="low quality, ugly, distorted, artefacts",
        control_strength=0.95
    )
    results['generated_images'] = generated_images
    results['generated_image_path'] = "output_0.jpg"

    # Step 2 (currently disabled): Analyze the generated image with GPT-4o-mini
    obj_detection = False
    if obj_detection:
        logging.info("\nStep 2: Analyzing generated image...")
        main_objects = analyze_image("output_0.jpg")  # Using the first generated image
        results['detected_furniture'] = main_objects

    # Steps 3-5: Detect objects and search for similar products
    logging.info("\nSteps 3-5: Detecting objects and searching for similar products...")
    similar_products = search_index(
        image_path="output_0.jpg",
        index=index,
        metadata_df=metadata_df,
        top_k=5
    )
    results['similar_products'] = similar_products

    return results

def load_index():
    logging.info("\nLoading FAISS index...")
    return faiss.read_index("data/ikea_faiss.index")

def load_metadata():
    logging.info("Loading metadata...")
    return pd.read_csv("data/filtered_metadata.csv")

if __name__ == "__main__":
    # Example usage
    logging.info("\nStarting program...")

    if len(sys.argv) != 3:
        logging.error("Usage: python inference.py <prompt> <link to control image>")
        sys.exit(1)

    prompt = sys.argv[1]
    control_image = sys.argv[2]

    logging.info("\nLoading required data...")
    # Load the FAISS index and product metadata
    index = load_index()
    metadata_df = load_metadata()

    results = main(prompt, control_image, index, metadata_df)
    logging.info("\nPipeline completed successfully!")
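Two assumptions in this file are worth making explicit: search_similar_products resolves hits with metadata_df.iloc[idx], so row i of data/filtered_metadata.csv must correspond to vector i of the FAISS index, and the CSV must contain at least product_url and image_url columns. Note also that load_index reads data/ikea_faiss.index, while faiss_uploader.py saves uploads to data/faiss_index.index, so an uploaded index would need to be renamed (or one of the paths aligned). A minimal sketch of building a matching index/metadata pair under those assumptions (the catalog rows and the 768 embedding dimension are hypothetical):

import faiss
import numpy as np
import pandas as pd

# Hypothetical catalog: embeddings[i] must describe the product in row i
metadata_df = pd.DataFrame({
    "product_url": ["https://www.ikea.com/p/example-1", "https://www.ikea.com/p/example-2"],
    "image_url": ["https://example.com/images/1.jpg", "https://example.com/images/2.jpg"],
})
embeddings = np.random.rand(len(metadata_df), 768).astype('float32')

index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)  # vector i <-> metadata_df row i

faiss.write_index(index, "data/ikea_faiss.index")
metadata_df.to_csv("data/filtered_metadata.csv", index=False)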