from fastapi import FastAPI, Query
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
import numpy as np
import torch

app = FastAPI()

# Load Dataset
dataset = load_dataset("MohamedAshraf701/medicine-dataset", split="train")

# Limit the dataset to 30,000 entries
dataset = dataset.select(range(min(30_000, len(dataset))))

# Define fields for embedding
fields_for_embedding = [
    "product_name",
    "sub_category",
    "salt_composition",
    "product_manufactured",
    "medicine_desc",
    "side_effects",
    "drug",
    "brand",
    "effect",
]

# Load Sentence Transformer Model
model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

# Helper: build the text a row's embedding is computed from
def create_combined_text(item):
    """
    Combines fields from an item into a single string for embedding,
    converting lists to comma-separated strings where necessary.
    """
    combined_text = []
    for field in fields_for_embedding:
        value = item.get(field)
        if value:
            # If the field is a list, join its elements into a single string
            if isinstance(value, list):
                combined_text.append(", ".join(map(str, value)))
            else:
                combined_text.append(str(value))
    return " ".join(combined_text)

# Precomputed embeddings shipped with the dataset; convert them to a tensor once
# so cosine similarity does not have to re-convert the Python list on every request.
embeddings = torch.tensor(dataset["embeddings"], dtype=torch.float32)
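
# Note: `create_combined_text` is not called at request time; the "embeddings"
# column is expected to ship precomputed with the dataset. A minimal offline
# sketch for (re)building that column with this model (assumes the field list
# above and enough memory to encode the whole split in one pass):
#
#   texts = [create_combined_text(item) for item in dataset]
#   vectors = model.encode(texts, show_progress_bar=True)
#   dataset = dataset.add_column("embeddings", [v.tolist() for v in vectors])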
@app.get("/gen")
def root():
# Prepare Data
return {"message": "Welcome to the medicine Search API!"}
@app.get("/meds")
def search_products(
query: str = Query("", title="Search Query", description="Search term for medicine"),
page: int = Query(1, ge=1, title="Page Number"),
items_per_page: int = Query(10, ge=1, le=100, title="Items Per Page"),
):
# Perform Search
if query:
query_embedding = model.encode(query, convert_to_tensor=True)
scores = util.cos_sim(query_embedding, embeddings).squeeze().tolist()
ranked_indices = np.argsort(scores)[::-1]
else:
ranked_indices = np.arange(len(dataset))
# Pagination
total_items = len(ranked_indices)
total_pages = (total_items + items_per_page - 1) // items_per_page
start_idx = (page - 1) * items_per_page
end_idx = start_idx + items_per_page
paginated_indices = ranked_indices[start_idx:end_idx]
# Prepare Response using select()
paginated_dataset = dataset.select(paginated_indices)
# Exclude the 'embeddings' column
results = [
{key: value for key, value in item.items() if key != "embeddings"}
for item in paginated_dataset
]
# Construct the API response
return {
"status": 200,
"data": results,
"totalpages": total_pages,
"currentpage": page
}
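
# To run this app locally (a sketch, assuming uvicorn is installed and this
# file is named app.py; Hugging Face Spaces conventionally expose port 7860):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Example request (hypothetical query term):
#
#   curl "http://localhost:7860/meds?query=paracetamol&page=1&items_per_page=5"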