Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,23 +1,47 @@
|
|
1 |
import gradio as gr
|
2 |
import json
|
3 |
import pandas as pd
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Charger les produits depuis le fichier JSON
|
6 |
with open("products.json", "r", encoding="utf-8") as file:
|
7 |
data = json.load(file)["products"]
|
8 |
|
9 |
-
# Convertir
|
10 |
df = pd.DataFrame(data)
|
11 |
|
12 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def search_products(query, category, min_price, max_price):
|
14 |
results = df.copy()
|
15 |
|
16 |
-
#
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
# Filtrer par catégorie
|
21 |
if category and category != "Toutes":
|
22 |
results = results[results["product_type"].str.contains(category, case=False, na=False)]
|
23 |
|
@@ -25,27 +49,30 @@ def search_products(query, category, min_price, max_price):
|
|
25 |
results["price"] = results["price"].str.replace(" EUR", "").astype(float)
|
26 |
results = results[(results["price"] >= min_price) & (results["price"] <= max_price)]
|
27 |
|
|
|
|
|
|
|
28 |
# Générer l'affichage des résultats
|
29 |
if results.empty:
|
30 |
return "Aucun produit ne correspond à votre recherche."
|
31 |
-
|
32 |
products_list = []
|
33 |
-
for _, row in results.iterrows():
|
34 |
products_list.append(f"🔹 **{row['title']}**\n💰 {row['price']} EUR\n🔗 [Voir le produit]({row['link']})")
|
35 |
-
|
36 |
return "\n\n".join(products_list)
|
37 |
|
38 |
-
# Interface utilisateur
|
39 |
with gr.Blocks() as demo:
|
40 |
-
gr.Markdown("# 🔍 Recherche avancée Straburo")
|
41 |
|
42 |
with gr.Row():
|
43 |
query_input = gr.Textbox(label="Rechercher un produit")
|
44 |
category_input = gr.Dropdown(["Toutes"] + list(df["product_type"].dropna().unique()), label="Catégorie")
|
45 |
-
|
46 |
with gr.Row():
|
47 |
-
min_price = gr.Slider(0, df["price"].
|
48 |
-
max_price = gr.Slider(0, df["price"].
|
49 |
|
50 |
search_button = gr.Button("🔍 Rechercher")
|
51 |
results_output = gr.Markdown()
|
|
|
1 |
import gradio as gr
|
2 |
import json
|
3 |
import pandas as pd
|
4 |
+
import torch
|
5 |
+
from transformers import AutoModel, AutoProcessor
|
6 |
+
from PIL import Image
|
7 |
+
import requests
|
8 |
|
9 |
# Charger les produits depuis le fichier JSON
|
10 |
with open("products.json", "r", encoding="utf-8") as file:
|
11 |
data = json.load(file)["products"]
|
12 |
|
13 |
+
# Convertir en DataFrame Pandas
|
14 |
df = pd.DataFrame(data)
|
15 |
|
16 |
+
# Charger le modèle Marqo-Ecommerce-Embeddings-L
|
17 |
+
model_name = "Marqo/marqo-ecommerce-embeddings-L"
|
18 |
+
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
19 |
+
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
20 |
+
|
21 |
+
# Fonction pour extraire des embeddings textuels
|
22 |
+
def get_text_embedding(text):
|
23 |
+
processed = processor(text=[text], return_tensors="pt")
|
24 |
+
with torch.no_grad():
|
25 |
+
embedding = model.get_text_features(processed["input_ids"], normalize=True)
|
26 |
+
return embedding
|
27 |
+
|
28 |
+
# Fonction de recherche avancée avec Marqo
|
29 |
def search_products(query, category, min_price, max_price):
|
30 |
results = df.copy()
|
31 |
|
32 |
+
# Générer l'embedding de la requête
|
33 |
+
query_embedding = get_text_embedding(query)
|
34 |
+
|
35 |
+
# Générer les embeddings des titres de produits
|
36 |
+
product_embeddings = torch.stack([get_text_embedding(title) for title in df["title"]])
|
37 |
+
|
38 |
+
# Calculer les similarités cosinus
|
39 |
+
similarities = (query_embedding @ product_embeddings.T).softmax(dim=-1).squeeze()
|
40 |
+
|
41 |
+
# Ajouter les scores de similarité au DataFrame
|
42 |
+
results["score"] = similarities.numpy()
|
43 |
|
44 |
+
# Filtrer par catégorie si sélectionnée
|
45 |
if category and category != "Toutes":
|
46 |
results = results[results["product_type"].str.contains(category, case=False, na=False)]
|
47 |
|
|
|
49 |
results["price"] = results["price"].str.replace(" EUR", "").astype(float)
|
50 |
results = results[(results["price"] >= min_price) & (results["price"] <= max_price)]
|
51 |
|
52 |
+
# Trier par score de pertinence
|
53 |
+
results = results.sort_values(by="score", ascending=False)
|
54 |
+
|
55 |
# Générer l'affichage des résultats
|
56 |
if results.empty:
|
57 |
return "Aucun produit ne correspond à votre recherche."
|
58 |
+
|
59 |
products_list = []
|
60 |
+
for _, row in results.head(10).iterrows():
|
61 |
products_list.append(f"🔹 **{row['title']}**\n💰 {row['price']} EUR\n🔗 [Voir le produit]({row['link']})")
|
62 |
+
|
63 |
return "\n\n".join(products_list)
|
64 |
|
65 |
+
# Interface utilisateur Gradio
|
66 |
with gr.Blocks() as demo:
|
67 |
+
gr.Markdown("# 🔍 Recherche avancée Straburo (IA)")
|
68 |
|
69 |
with gr.Row():
|
70 |
query_input = gr.Textbox(label="Rechercher un produit")
|
71 |
category_input = gr.Dropdown(["Toutes"] + list(df["product_type"].dropna().unique()), label="Catégorie")
|
72 |
+
|
73 |
with gr.Row():
|
74 |
+
min_price = gr.Slider(0, df["price"].max(), value=0, label="Prix min (€)")
|
75 |
+
max_price = gr.Slider(0, df["price"].max(), value=500, label="Prix max (€)")
|
76 |
|
77 |
search_button = gr.Button("🔍 Rechercher")
|
78 |
results_output = gr.Markdown()
|