# Hugging Face Spaces page header ("Spaces: Sleeping") — scrape residue, not code.
# --- Imports ---------------------------------------------------------------
from fastapi import FastAPI, Query
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from typing import List
import numpy as np
import base64
from PIL import Image
from io import BytesIO

app = FastAPI()

# Load the medicine dataset (train split) from the Hugging Face Hub.
dataset = load_dataset("MohamedAshraf701/medicine-dataset", split="train")
# NOTE(review): a stale comment said "Limit the dataset to 30,000 entries" but
# no limiting code exists — confirm whether dataset.select(range(30000)) was intended.

# Fields concatenated into one text per record when building/refreshing embeddings.
fields_for_embedding = [
    "product_name",
    "sub_category",
    "salt_composition",
    "product_manufactured",
    "medicine_desc",
    "side_effects",
    "drug",
    "brand",
    "effect",
]

# Sentence-Transformer model used to embed incoming search queries.
model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
# Generate Embeddings | |
def create_combined_text(item, fields=None):
    """
    Combine selected fields of a dataset record into one string for embedding.

    Args:
        item: Mapping (e.g. a dataset row) whose values may be strings, lists,
            or other str()-able objects; missing/falsy fields are skipped.
        fields: Optional list of field names to combine. Defaults to the
            module-level ``fields_for_embedding`` (backward compatible).

    Returns:
        A single space-separated string; list-valued fields are first joined
        with ", ".
    """
    if fields is None:
        fields = fields_for_embedding
    parts = []
    for field in fields:
        value = item.get(field)
        if not value:
            continue
        # Lists become comma-separated strings; everything else is str()-ed.
        if isinstance(value, list):
            parts.append(", ".join(map(str, value)))
        else:
            parts.append(str(value))
    return " ".join(parts)
embeddings = dataset["embeddings"] | |
@app.get("/")
def root():
    """Landing endpoint for the medicine search API.

    NOTE(review): the original definition had no route decorator, so "/" was
    never registered with FastAPI — restored here.
    """
    return {"message": "Welcome to the medicine Search API!"}
@app.get("/search")
def search_products(
    query: str = Query("", title="Search Query", description="Search term for medicine"),
    page: int = Query(1, ge=1, title="Page Number"),
    items_per_page: int = Query(10, ge=1, le=100, title="Items Per Page"),
):
    """
    Semantic search over the medicine dataset with pagination.

    Args:
        query: Free-text search term; when empty, rows are returned in
            dataset order.
        page: 1-based page number.
        items_per_page: Page size (1–100).

    Returns:
        Dict with ``status``, ``data`` (rows minus the "embeddings" column),
        ``totalpages``, and ``currentpage``.

    NOTE(review): the original definition had no route decorator, so this
    endpoint was never registered with FastAPI — restored here.
    """
    if query:
        # Embed the query and rank every row by cosine similarity, descending.
        query_embedding = model.encode(query, convert_to_tensor=True)
        scores = util.cos_sim(query_embedding, embeddings).squeeze().tolist()
        # atleast_1d guards the single-row dataset case, where squeeze()
        # collapses to a scalar and argsort(...)[::-1] would fail.
        ranked_indices = np.argsort(np.atleast_1d(scores))[::-1]
    else:
        # No query: keep the dataset's natural order.
        ranked_indices = np.arange(len(dataset))

    # Pagination (ceil division for the page count).
    total_items = len(ranked_indices)
    total_pages = (total_items + items_per_page - 1) // items_per_page
    start_idx = (page - 1) * items_per_page
    end_idx = start_idx + items_per_page
    paginated_indices = ranked_indices[start_idx:end_idx]

    # Materialize only the requested page of rows.
    paginated_dataset = dataset.select(paginated_indices)
    # Strip the bulky "embeddings" column from the response payload.
    results = [
        {key: value for key, value in item.items() if key != "embeddings"}
        for item in paginated_dataset
    ]

    return {
        "status": 200,
        "data": results,
        "totalpages": total_pages,
        "currentpage": page,
    }