|
""" |
|
inference steps - to be run whenever a new image is uploaded |
|
input: image and textual prompt |
|
steps: |
|
1. generate an image (or more than one) with stable diffusion |
|
2. GPT-4o - detect main pieces of furniture |
|
3. perform object detection on the image looking for the main pieces of furniture |
|
4. generate embeddings for the image and the subimages |
|
5. perform a similarity search on the index of ikea products |
|
6. return the results: generated image, main pieces of furniture, similar ikea products |
|
""" |
|
|
|
import logging |
|
from datetime import datetime |
|
|
|
|
|
log_filename = f"inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" |
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format='%(asctime)s - %(levelname)s - %(message)s', |
|
handlers=[ |
|
logging.FileHandler(log_filename), |
|
logging.StreamHandler() |
|
] |
|
) |
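# One timestamped log file per run; messages are also mirrored to stdout.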
|
|
|
import replicate |
|
from pydantic import BaseModel, Field |
|
from openai import OpenAI |
|
import base64 |
|
import numpy as np |
|
import os |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
import sys |
|
import dotenv |
|
import pandas as pd |
|
import faiss |
|
|
|
logging.info("Loading environment variables...") |
|
dotenv.load_dotenv() |
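# OPENAI_API_KEY and REPLICATE_API_TOKEN are expected in the environment (or .env).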
|
client = OpenAI() |
|
|
|
|
|
class Prompt(BaseModel): |
|
prompt: str = Field(description="A detailed prompt for a diffusion model") |
|
|
|
def generate_prompt_for_flux(user_prompt): |
|
completion = client.beta.chat.completions.parse( |
|
model="gpt-4o-mini", |
|
messages=[ |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{"type": "text", "text": f"Generate a prompt for a diffusion model that is a more detailed version of the following prompt: {user_prompt}. Keep it succinct but more descriptive than the original. Just return a few words, listing possible relevant furniture elements and objects that will be present in the image. "}, |
|
], |
|
} |
|
], |
|
response_format=Prompt, |
|
) |
|
|
|
analysis = completion.choices[0].message.parsed |
|
return analysis.prompt |
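# Example (hypothetical output):
#   generate_prompt_for_flux("a cozy living room")
#   -> "cozy living room, fabric sofa, wooden coffee table, floor lamp, wool rug"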
|
|
|
def search_similar_products(image_path, index, metadata_df, top_k=5): |
|
""" |
|
Search for similar products given a local image |
|
|
|
Args: |
|
        image_path: Local file path or URL of the query image
|
index: FAISS index |
|
metadata_df: DataFrame containing product metadata |
|
top_k (int): Number of similar items to return |
|
|
|
Returns: |
|
list: List of dictionaries containing similar product information |
|
""" |
|
logging.info(f"\nGenerating embedding for image: {image_path}") |
|
|
|
    # Replicate accepts a URL or an open file handle; open local paths here so
    # callers (e.g. the per-detection crop loop) can pass plain path strings.
    if isinstance(image_path, str) and not image_path.startswith(('http://', 'https://')):
        image_input = open(image_path, "rb")
    else:
        image_input = image_path

    output = replicate.run(
        "krthr/clip-embeddings:1c0371070cb827ec3c7f2f28adcdde54b50dcd239aa6faea0bc98b174ef03fb4",
        input={"image": image_input}
    )

    # The model may return {"embedding": [...]} or a bare vector; FAISS expects
    # a float32 array of shape (n_queries, dim).
    if isinstance(output, dict) and 'embedding' in output:
        query_embedding = np.array(output['embedding'], dtype='float32')
    else:
        query_embedding = np.array(output, dtype='float32')
    query_embedding = query_embedding.reshape(1, -1)

    logging.info("Searching FAISS index...")

    distances, indices = index.search(query_embedding, top_k)
|
|
|
logging.info("Processing search results...") |
|
|
|
results = [] |
|
for idx, distance in zip(indices[0], distances[0]): |
|
result = { |
|
'product_url': metadata_df.iloc[idx]['product_url'], |
|
'image_url': metadata_df.iloc[idx]['image_url'], |
|
'distance': float(distance) |
|
} |
|
results.append(result) |
|
|
|
return results |
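# Minimal usage sketch (assumes the index/metadata loaded by load_index/load_metadata):
#   results = search_similar_products("room.jpg", index, metadata_df, top_k=3)
#   results[0] -> {'product_url': ..., 'image_url': ..., 'distance': ...}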
|
|
|
|
|
def find_similar_ikea_products(image_input, index, metadata_df, top_k=5, process_detections_list=None): |
|
""" |
|
Convenience function to find similar IKEA products |
|
|
|
Args: |
|
image_input (str): Path to local image file or URL of image |
|
top_k (int): Number of similar items to return |
|
process_detections_list (list, optional): List of detection dictionaries containing bbox and label |
|
""" |
|
logging.info("\nProcessing full image...") |
|
|
|
logging.info("\nProcessing full image:") |
|
|
|
    if image_input.startswith(('http://', 'https://')):
        logging.info(f"Processing URL image: {image_input}")
    else:
        logging.info(f"Processing local image: {image_input}")
        if not os.path.exists(image_input):
            raise FileNotFoundError(f"Local image file not found: {image_input}")

    results = search_similar_products(image_input, index, metadata_df, top_k)
|
|
|
logging.info(f"\nTop {top_k} similar products (overall image):") |
|
for i, result in enumerate(results, 1): |
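        # FAISS returns L2 distances; 1 / (1 + d) maps them to a (0, 1] similarity score.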
|
logging.info(f"\n{i}. Similarity score: {1 / (1 + result['distance']):.3f}") |
|
logging.info(f"Product URL: {result['product_url']}") |
|
logging.info(f"Image URL: {result['image_url']}") |
|
|
|
|
|
if process_detections_list: |
|
logging.info("\nProcessing object detections...") |
|
|
|
        # Load the full image for cropping; fetch it first if it is a URL.
        if isinstance(image_input, str) and image_input.startswith(('http://', 'https://')):
            response = requests.get(image_input)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(image_input)

        detections = process_detections_list['detections']
|
for i, detection in enumerate(detections): |
|
logging.info(f"\nProcessing detection {i+1}: {detection['label']}") |
|
logging.info(f"Confidence: {detection['confidence']:.3f}") |
|
|
|
|
|
x1, y1, x2, y2 = detection['bbox'] |
|
|
|
logging.info(f"Cropping image to bbox: ({x1}, {y1}, {x2}, {y2})") |
|
|
|
cropped = image.crop((x1, y1, x2, y2)) |
|
|
|
|
|
temp_path = f"temp_crop_{i}.jpg" |
|
cropped.save(temp_path) |
|
logging.info(f"Saved temporary crop to: {temp_path}") |
|
|
|
try: |
|
|
|
logging.info(f"\nFinding similar products for {detection['label']}:") |
|
sub_results = search_similar_products(temp_path, index, metadata_df, top_k) |
|
|
|
logging.info(f"\nTop {top_k} similar products for {detection['label']}:") |
|
for j, result in enumerate(sub_results, 1): |
|
logging.info(f"\n{j}. Similarity score: {1 / (1 + result['distance']):.3f}") |
|
logging.info(f"Product URL: {result['product_url']}") |
|
logging.info(f"Image URL: {result['image_url']}") |
|
|
|
except Exception as e: |
|
logging.error(f"Error processing detection {i+1}: {e}") |
|
|
|
finally: |
|
|
|
if os.path.exists(temp_path): |
|
logging.info(f"Cleaning up temporary file: {temp_path}") |
|
os.remove(temp_path) |
|
|
|
return results |
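# Usage sketch with detections (shape assumed from the Grounding DINO output consumed above):
#   dets = {'detections': [{'label': 'sofa', 'confidence': 0.81, 'bbox': [40, 80, 420, 560]}]}
#   find_similar_ikea_products("room.jpg", index, metadata_df, top_k=5,
#                              process_detections_list=dets)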
|
|
|
|
|
def generate_image(prompt, control_image, guidance_scale, output_quality, negative_prompt, control_strength, image_to_image_strength, control_type): |
|
logging.info("\nGenerating image with Stable Diffusion...") |
|
logging.info(f"Prompt: {prompt}") |
|
|
|
    # Accept a local path, a URL, or an already-open file handle.
    if isinstance(control_image, str):
        if control_image.startswith(('http://', 'https://')):
            image = control_image
        else:
            image = open(control_image, "rb")
    else:
        image = control_image
|
|
|
    model_input = {
        "prompt": prompt,
        "control_image": image,
        "guidance_scale": guidance_scale,
        "output_quality": output_quality,
        "negative_prompt": negative_prompt,
        "control_strength": control_strength,
        "image_to_image_strength": image_to_image_strength,
        "control_type": control_type
    }

    logging.info("Running image generation model...")
    output = replicate.run(
        "xlabs-ai/flux-dev-controlnet:9a8db105db745f8b11ad3afe5c8bd892428b2a43ade0b67edc4e0ccd52ff2fda",
        input=model_input
    )
    logging.info("Saving generated images...")
    for i, item in enumerate(output):
        with open(f"output_{i}.jpg", "wb") as file:
            file.write(item.read())
        logging.info(f"Saved output_{i}.jpg")
    return output
|
|
|
|
|
|
|
def analyze_image(image_path): |
|
logging.info(f"\nAnalyzing image with GPT-4V: {image_path}") |
|
|
|
class ImageAnalysis(BaseModel): |
|
objects: list[str] = Field(description="A list of objects in the image") |
|
|
|
|
|
def encode_image(image_path): |
|
logging.info("Encoding image to base64...") |
|
with open(image_path, "rb") as image_file: |
|
return base64.b64encode(image_file.read()).decode('utf-8') |
|
|
|
|
|
encoded_image = encode_image(image_path) |
|
|
|
logging.info("Sending request to GPT-4o-mini vision...") |
|
completion = client.beta.chat.completions.parse( |
|
model="gpt-4o-mini", |
|
messages=[ |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{"type": "text", "text": "Analyze this image and list the main objects of furniture in the image."}, |
|
{ |
|
"type": "image_url", |
|
"image_url": { |
|
"url": f"data:image/jpeg;base64,{encoded_image}" |
|
}, |
|
}, |
|
], |
|
} |
|
], |
|
response_format=ImageAnalysis, |
|
) |
|
|
|
    analysis = completion.choices[0].message.parsed
    main_objects = ', '.join(analysis.objects)
    logging.info(f"Objects: {main_objects}")
    return main_objects
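# Example (hypothetical): analyze_image("output_0.jpg") -> "sofa, coffee table, floor lamp"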
|
|
|
|
|
|
|
def detect_objects(image_path, main_objects): |
|
logging.info(f"\nDetecting objects in image: {image_path}") |
|
    # Keep the file handle open for the duration of the model call.
    with open(image_path, "rb") as image:
        model_input = {
            "image": image,
            "query": main_objects,
            "box_threshold": 0.2,
            "text_threshold": 0.2
        }

        logging.info("Running object detection model...")
        output = replicate.run(
            "adirik/grounding-dino:efd10a8ddc57ea28773327e881ce95e20cc1d734c589f7dd01d2036921ed78aa",
            input=model_input
        )
    logging.info("Detection results:")
    logging.info(output)
    return output
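# The detection output feeds find_similar_ikea_products(process_detections_list=...),
# which expects a dict with a 'detections' list of {'label', 'confidence', 'bbox'} entries.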
|
|
|
|
|
def search_index(image_path, index, metadata_df, top_k=5):
    logging.info(f"\nSearching index for similar products to: {image_path}")

    results = find_similar_ikea_products(image_path, index, metadata_df, top_k=top_k)
    return results
|
|
|
|
|
|
|
|
|
def main(prompt, control_image, index, metadata_df): |
|
""" |
|
Main function to orchestrate the entire inference pipeline |
|
|
|
Args: |
|
prompt (str): Text prompt for image generation |
|
control_image: Input image for controlled generation |
|
index: FAISS index for similarity search |
|
metadata_df: DataFrame containing product metadata |
|
|
|
Returns: |
|
dict: Results containing generated images, detected objects, and similar products |
|
""" |
|
|
|
|
|
logging.info("\nStarting inference pipeline...") |
|
results = {} |
|
|
|
logging.info("\nStep 0: Generating a detailed prompt for the diffusion model...") |
|
|
|
prompt = generate_prompt_for_flux(prompt) |
|
|
|
prompt += ", realistic, high quality, 8K, photorealistic, high detail, sharp focus" |
|
|
|
logging.info(f"\nGenerated prompt: {prompt}") |
|
|
|
logging.info("\nStep 1: Generating image...") |
|
|
|
generated_images = generate_image( |
|
prompt=prompt, |
|
control_image=control_image, |
|
guidance_scale=2.5, |
|
output_quality=100, |
|
negative_prompt="low quality, ugly, distorted, artefacts, low detail, low quality, low resolution, low definition, imaginary, unrealistic, fictional", |
|
control_strength=0.5, |
|
image_to_image_strength=0.1, |
|
control_type="canny" |
|
) |
|
results['generated_images'] = generated_images |
|
results['generated_image_path'] = "output_0.jpg" |
|
|
|
|
|
|
|
    # Steps 2-3 (furniture analysis and object detection) are currently disabled;
    # set obj_detection = True to re-enable them.
    obj_detection = False
    if obj_detection:
        logging.info("\nSteps 2-3: Detecting main pieces of furniture...")
        main_objects = analyze_image("output_0.jpg")
        results['detected_furniture'] = main_objects
        results['detections'] = detect_objects("output_0.jpg", main_objects)

    logging.info("\nSteps 4-5: Generating embeddings and searching for similar products...")
|
|
|
similar_products = search_index( |
|
image_path="output_0.jpg", |
|
index=index, |
|
metadata_df=metadata_df, |
|
top_k=5 |
|
) |
|
results['similar_products'] = similar_products |
|
|
|
return results |
|
|
|
def load_index(): |
|
logging.info("\nLoading FAISS index...") |
|
return faiss.read_index("data/ikea_faiss.index") |
|
|
|
def load_metadata(): |
|
logging.info("Loading metadata...") |
|
return pd.read_csv("data/filtered_metadata.csv") |
|
|
|
if __name__ == "__main__": |
|
|
|
logging.info("\nStarting program...") |
|
|
|
if len(sys.argv) != 3: |
|
logging.error("Usage: python inference.py <prompt> <link to control image>") |
|
sys.exit(1) |
|
|
|
prompt = sys.argv[1] |
|
control_image = sys.argv[2] |
|
|
|
logging.info("\nLoading required data...") |
|
|
|
index = load_index() |
|
metadata_df = load_metadata() |
|
|
|
results = main(prompt, control_image, index, metadata_df) |
|
logging.info("\nPipeline completed successfully!") |