"""Semantic medicine search API.

Loads the ``MohamedAshraf701/medicine-dataset`` dataset and a Sentence
Transformer model at import time, then exposes two HTTP endpoints:

* ``GET /gen``  - static welcome message.
* ``GET /meds`` - paginated listing, ranked by cosine similarity to the
  query when a query string is supplied.
"""

import base64
from io import BytesIO
from typing import List

import numpy as np
import torch
from datasets import load_dataset
from fastapi import FastAPI, Query
from PIL import Image
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util

app = FastAPI()

# Load the full "train" split. NOTE: no row limit is applied here; every
# entry of the split is searchable.
dataset = load_dataset("MohamedAshraf701/medicine-dataset", split="train")

# Fields that make up the text representation of an item
# (consumed by create_combined_text).
fields_for_embedding = [
    "product_name",
    "sub_category",
    "salt_composition",
    "product_manufactured",
    "medicine_desc",
    "side_effects",
    "drug",
    "brand",
    "effect",
]

# Model used to embed incoming search queries.
model = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")


def create_combined_text(item: dict) -> str:
    """Combine the embedding fields of ``item`` into one string.

    Fields listed in ``fields_for_embedding`` are read in order; missing or
    falsy values are skipped. List values are flattened into a
    comma-separated string, everything else is passed through ``str``.

    This helper is not called anywhere in this module; the document
    embeddings are read from the dataset's ``embeddings`` column instead.

    Args:
        item: Mapping-like dataset row (must support ``.get``).

    Returns:
        The non-empty field values joined by single spaces.
    """
    combined_text = []
    for field in fields_for_embedding:
        value = item.get(field)
        if not value:
            continue
        if isinstance(value, list):
            combined_text.append(", ".join(map(str, value)))
        else:
            combined_text.append(str(value))
    return " ".join(combined_text)


# Document embeddings shipped with the dataset (presumably produced with the
# same model - TODO confirm). Converted to a tensor exactly once: util.cos_sim
# would otherwise run this same torch.tensor(...) conversion over the whole
# column on every single search request.
embeddings = torch.tensor(dataset["embeddings"])


@app.get("/gen")
def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the medicine Search API!"}


@app.get("/meds")
def search_products(
    query: str = Query("", title="Search Query", description="Search term for medicine"),
    page: int = Query(1, ge=1, title="Page Number"),
    items_per_page: int = Query(10, ge=1, le=100, title="Items Per Page"),
):
    """Search medicines and return one page of results.

    With a non-empty ``query`` the rows are ordered by descending cosine
    similarity between the query embedding and the stored embeddings.
    With an empty ``query`` the rows are returned in dataset order.
    A ``page`` past the last page yields an empty ``data`` list.

    Returns a JSON object with ``status``, ``data`` (rows without the
    ``embeddings`` column), ``totalpages`` and ``currentpage``.
    """
    if query:
        query_embedding = model.encode(query, convert_to_tensor=True)
        # cos_sim returns a (1, n_rows) matrix for a single query; take that
        # row explicitly instead of relying on squeeze().
        scores = util.cos_sim(query_embedding, embeddings)[0].tolist()
        # argsort is ascending, so reverse it to get best matches first.
        ranked_indices = np.argsort(scores)[::-1]
    else:
        ranked_indices = np.arange(len(dataset))

    # Pagination (ceil division for the page count).
    total_items = len(ranked_indices)
    total_pages = (total_items + items_per_page - 1) // items_per_page
    start_idx = (page - 1) * items_per_page
    end_idx = start_idx + items_per_page
    paginated_indices = ranked_indices[start_idx:end_idx]

    # Materialise only the rows of the requested page.
    paginated_dataset = dataset.select(paginated_indices)

    # Drop the raw vectors from the payload.
    results = [
        {key: value for key, value in item.items() if key != "embeddings"}
        for item in paginated_dataset
    ]

    return {
        "status": 200,
        "data": results,
        "totalpages": total_pages,
        "currentpage": page,
    }