adel67460 commited on
Commit
d12b30b
·
verified ·
1 Parent(s): c968d12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -14
app.py CHANGED
@@ -1,23 +1,47 @@
1
  import gradio as gr
2
  import json
3
  import pandas as pd
 
 
 
 
4
 
5
  # Charger les produits depuis le fichier JSON
6
  with open("products.json", "r", encoding="utf-8") as file:
7
  data = json.load(file)["products"]
8
 
9
- # Convertir les données en DataFrame Pandas
10
  df = pd.DataFrame(data)
11
 
12
- # Fonction de recherche avec filtres
 
 
 
 
 
 
 
 
 
 
 
 
13
  def search_products(query, category, min_price, max_price):
14
  results = df.copy()
15
 
16
- # Filtrer selon la recherche textuelle
17
- if query:
18
- results = results[results["title"].str.contains(query, case=False, na=False)]
 
 
 
 
 
 
 
 
19
 
20
- # Filtrer par catégorie
21
  if category and category != "Toutes":
22
  results = results[results["product_type"].str.contains(category, case=False, na=False)]
23
 
@@ -25,27 +49,30 @@ def search_products(query, category, min_price, max_price):
25
  results["price"] = results["price"].str.replace(" EUR", "").astype(float)
26
  results = results[(results["price"] >= min_price) & (results["price"] <= max_price)]
27
 
 
 
 
28
  # Générer l'affichage des résultats
29
  if results.empty:
30
  return "Aucun produit ne correspond à votre recherche."
31
-
32
  products_list = []
33
- for _, row in results.iterrows():
34
  products_list.append(f"🔹 **{row['title']}**\n💰 {row['price']} EUR\n🔗 [Voir le produit]({row['link']})")
35
-
36
  return "\n\n".join(products_list)
37
 
38
- # Interface utilisateur avec Gradio
39
  with gr.Blocks() as demo:
40
- gr.Markdown("# 🔍 Recherche avancée Straburo")
41
 
42
  with gr.Row():
43
  query_input = gr.Textbox(label="Rechercher un produit")
44
  category_input = gr.Dropdown(["Toutes"] + list(df["product_type"].dropna().unique()), label="Catégorie")
45
-
46
  with gr.Row():
47
- min_price = gr.Slider(0, df["price"].str.replace(" EUR", "").astype(float).max(), value=0, label="Prix min (€)")
48
- max_price = gr.Slider(0, df["price"].str.replace(" EUR", "").astype(float).max(), value=500, label="Prix max (€)")
49
 
50
  search_button = gr.Button("🔍 Rechercher")
51
  results_output = gr.Markdown()
 
1
  import gradio as gr
2
  import json
3
  import pandas as pd
4
+ import torch
5
+ from transformers import AutoModel, AutoProcessor
6
+ from PIL import Image
7
+ import requests
8
 
9
  # Charger les produits depuis le fichier JSON
10
  with open("products.json", "r", encoding="utf-8") as file:
11
  data = json.load(file)["products"]
12
 
13
+ # Convertir en DataFrame Pandas
14
  df = pd.DataFrame(data)
15
 
16
+ # Charger le modèle Marqo-Ecommerce-Embeddings-L
17
+ model_name = "Marqo/marqo-ecommerce-embeddings-L"
18
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
19
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
20
+
21
+ # Fonction pour extraire des embeddings textuels
22
+ def get_text_embedding(text):
23
+ processed = processor(text=[text], return_tensors="pt")
24
+ with torch.no_grad():
25
+ embedding = model.get_text_features(processed["input_ids"], normalize=True)
26
+ return embedding
27
+
28
+ # Fonction de recherche avancée avec Marqo
29
  def search_products(query, category, min_price, max_price):
30
  results = df.copy()
31
 
32
+ # Générer l'embedding de la requête
33
+ query_embedding = get_text_embedding(query)
34
+
35
+ # Générer les embeddings des titres de produits
36
+ product_embeddings = torch.stack([get_text_embedding(title) for title in df["title"]])
37
+
38
+ # Calculer les similarités cosinus
39
+ similarities = (query_embedding @ product_embeddings.T).softmax(dim=-1).squeeze()
40
+
41
+ # Ajouter les scores de similarité au DataFrame
42
+ results["score"] = similarities.numpy()
43
 
44
+ # Filtrer par catégorie si sélectionnée
45
  if category and category != "Toutes":
46
  results = results[results["product_type"].str.contains(category, case=False, na=False)]
47
 
 
49
  results["price"] = results["price"].str.replace(" EUR", "").astype(float)
50
  results = results[(results["price"] >= min_price) & (results["price"] <= max_price)]
51
 
52
+ # Trier par score de pertinence
53
+ results = results.sort_values(by="score", ascending=False)
54
+
55
  # Générer l'affichage des résultats
56
  if results.empty:
57
  return "Aucun produit ne correspond à votre recherche."
58
+
59
  products_list = []
60
+ for _, row in results.head(10).iterrows():
61
  products_list.append(f"🔹 **{row['title']}**\n💰 {row['price']} EUR\n🔗 [Voir le produit]({row['link']})")
62
+
63
  return "\n\n".join(products_list)
64
 
65
+ # Interface utilisateur Gradio
66
  with gr.Blocks() as demo:
67
+ gr.Markdown("# 🔍 Recherche avancée Straburo (IA)")
68
 
69
  with gr.Row():
70
  query_input = gr.Textbox(label="Rechercher un produit")
71
  category_input = gr.Dropdown(["Toutes"] + list(df["product_type"].dropna().unique()), label="Catégorie")
72
+
73
  with gr.Row():
74
+ min_price = gr.Slider(0, df["price"].max(), value=0, label="Prix min (€)")
75
+ max_price = gr.Slider(0, df["price"].max(), value=500, label="Prix max (€)")
76
 
77
  search_button = gr.Button("🔍 Rechercher")
78
  results_output = gr.Markdown()