"""FastAPI service that embeds a parquet dataset into Qdrant and serves semantic search."""

import os
import uuid
from itertools import islice

import torch
from datasets import load_dataset
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

FILEPATH_PATTERN = "structured_data_doc.parquet"
# NOTE(review): absolute path at the filesystem root; possibly intended to be
# "./.cache" or "~/.cache" — confirm the process can write here.
CACHE_DIR = "/.cache"
# os.cpu_count() returns None when the CPU count cannot be determined.
NUM_PROC = os.cpu_count() or 1
COLLECTION_NAME = "law"
# Number of points sent to Qdrant per upsert request.
batch_size = 100


def batched(iterable, n):
    """Yield successive lists of at most ``n`` items from ``iterable``."""
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch


# Determine device based on GPU availability.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# The model must exist before the collection is created: the vector size
# is taken from it.
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=device,
)

# Local, on-disk (not in-memory) Qdrant storage at the path "database.db".
client2 = QdrantClient(path="database.db")

# The storage persists across restarts, and create_collection fails when the
# collection already exists, so only create it the first time.
_existing_collections = {c.name for c in client2.get_collections().collections}
if COLLECTION_NAME not in _existing_collections:
    client2.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )

app = FastAPI()

# NOTE(review): wildcard origins combined with credentials is very permissive;
# consider restricting allow_origins outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def generate_embeddings(dataset, batch_size=32):
    """Encode the ``content`` column of ``dataset`` in batches.

    Args:
        dataset: a ``datasets.Dataset`` with a ``content`` column.
        batch_size: number of rows passed to ``model.encode`` per call.

    Returns:
        A list with one embedding per row, in row order.
    """
    embeddings = []
    total = len(dataset)
    with tqdm(total=total, desc="Generating embeddings for dataset") as pbar:
        for start in range(0, total, batch_size):
            # Slice the rows first so only this batch's text is materialized;
            # dataset["content"][...] can build the full column on every pass.
            batch_sentences = dataset[start:start + batch_size]["content"]
            embeddings.extend(model.encode(batch_sentences))
            pbar.update(len(batch_sentences))
    return embeddings


@app.post("/uploadfile/")
def create_upload_file(file: UploadFile = File(...)):
    """Embed the dataset at FILEPATH_PATTERN and upsert it into the collection.

    Declared as a plain ``def`` so FastAPI runs this long, blocking embedding
    job in its threadpool instead of stalling the event loop.

    NOTE(review): the uploaded file's contents are never read; only its
    filename is echoed back. The indexed data always comes from
    FILEPATH_PATTERN.
    """
    full_dataset = load_dataset(
        "parquet",
        data_files=FILEPATH_PATTERN,
        split="train",
        keep_in_memory=True,
        cache_dir=CACHE_DIR,
        num_proc=NUM_PROC * 2,
    )

    law_embeddings = generate_embeddings(full_dataset)
    full_dataset = full_dataset.add_column("embeddings", law_embeddings)

    # Without a stable "uuid" column, fresh ids are generated on every call,
    # so repeated uploads of the same data add duplicate points.
    if "uuid" not in full_dataset.column_names:
        full_dataset = full_dataset.add_column(
            "uuid", [str(uuid.uuid4()) for _ in range(len(full_dataset))]
        )

    # Each row is a dict; pop id and vector so the remaining columns
    # become the point payload.
    for batch in batched(full_dataset, batch_size):
        ids = [point.pop("uuid") for point in batch]
        vectors = [point.pop("embeddings") for point in batch]
        client2.upsert(
            collection_name=COLLECTION_NAME,
            points=models.Batch(ids=ids, vectors=vectors, payloads=batch),
        )

    return {"filename": file.filename, "message": "Done"}


@app.get("/search")
def search(prompt: str):
    """Return the 5 stored points nearest to ``prompt``.

    Response shape: ``{"detail": [{"payload": ..., "score": ...}, ...]}``,
    an empty list when nothing matches.
    """
    hits = client2.search(
        collection_name=COLLECTION_NAME,
        query_vector=model.encode(prompt).tolist(),
        limit=5,
    )
    results = []
    for hit in hits:
        print(hit.payload, "score:", hit.score)
        results.append({"payload": hit.payload, "score": hit.score})
    return {"detail": results}


@app.get("/")
def api_home():
    return {"detail": "Welcome to FastAPI Qdrant importer!"}