# QDrantRAG9 / app.py
# Page header captured with the file: committer "dinhquangson",
# commit "Update app.py" (3f50d8b, verified), file size 4.98 kB.
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from datasets import load_dataset
from fastapi.middleware.cors import CORSMiddleware
# Loading
import os
import shutil
from os import makedirs,getcwd
from os.path import join,exists,dirname
import torch
# Application object; every route in this file registers itself on it.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation -- confirm whether credentials
# are really needed, or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

NUM_PROC = os.cpu_count()

# Uploaded files are staged in a "temp" directory located next to (not
# inside) the current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok=True replaces the previous "check, then create" sequence, which
# could fail with FileExistsError when several processes import this module
# at the same time.
makedirs(temp_path, exist_ok=True)

# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

import logging

# Keep third-party logging at WARNING, but let haystack report at INFO.
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)
# Handle to the Qdrant document store most recently opened by this process.
# It is shared by the upload and search endpoints so that searches run
# against the index the last upload wrote.
_document_store = None


def _open_document_store(recreate_index):
    """Open the on-disk Qdrant store in ``database`` and remember the handle.

    Args:
        recreate_index: Passed through to ``QdrantDocumentStore``; ``True``
            starts from an empty index, ``False`` reuses what is on disk.

    Returns:
        The newly created ``QdrantDocumentStore``; it is also stored in the
        module-level ``_document_store``.
    """
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    global _document_store
    # Drop the previous handle before opening a new one.
    # NOTE(review): embedded Qdrant is understood to allow a single client
    # per storage folder -- confirm the old handle is released promptly.
    _document_store = None
    _document_store = QdrantDocumentStore(
        path="database",
        recreate_index=recreate_index,
        use_sparse_embeddings=True,
        embedding_dim=384,
    )
    return _document_store


def _read_json_lines(path, text_field):
    """Build one haystack ``Document`` per non-blank JSON line of *path*.

    Args:
        path: File containing one JSON object per line.
        text_field: Key of each object whose value becomes the document
            content; the whole object is kept as the document's ``meta``.

    Returns:
        List of ``Document`` objects, in file order.

    Raises:
        HTTPException: 400 when a line is not valid JSON or does not
            provide ``text_field``.
    """
    import json

    from fastapi import HTTPException
    from haystack import Document

    documents = []
    with open(path, encoding="utf-8") as fd:
        for line_number, line in enumerate(fd, start=1):
            # Blank lines (for example a trailing newline) carry no record.
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
                content = obj[text_field]
            except (json.JSONDecodeError, KeyError, TypeError) as err:
                raise HTTPException(
                    status_code=400,
                    detail=f"Line {line_number}: cannot read field {text_field!r} ({err})",
                ) from err
            documents.append(Document(content=content, meta=obj))
    return documents


@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...)):
    """Store an uploaded JSON-lines file and index it into Qdrant.

    Args:
        text_field: Key of each JSON record that holds the text to embed.
        file: Uploaded file; its name must end in ``.json`` or ``.jsonl``.

    Returns:
        Dict with the stored file name, a status message and the elapsed
        time in seconds.

    Raises:
        HTTPException: 400 for an unsupported file name or unreadable
            records.
    """
    import time

    from fastapi import HTTPException
    from haystack import Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder,
    )

    start_time = time.time()

    # The file name is client-controlled: keep only its last path component
    # so the upload cannot be written outside the staging directory.
    safe_name = os.path.basename(file.filename or "")
    if not safe_name.lower().endswith((".json", ".jsonl")):
        raise HTTPException(status_code=400, detail="This feature is not supported yet")

    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    documents = _read_json_lines(file_savePath, text_field)

    # Indexing: every upload rebuilds the index from scratch.
    document_store = _open_document_store(recreate_index=True)

    indexing = Pipeline()
    indexing.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1"))
    indexing.add_component("dense_doc_embedder", FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5"))
    indexing.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")

    indexing.run({"sparse_doc_embedder": {"documents": documents}})

    elapsed_time = time.time() - start_time
    return {"filename": safe_name, "message": "Done", "execution_time": elapsed_time}


@app.get("/search")
def search(prompt: str):
    """Run a hybrid (sparse + dense) retrieval for *prompt*.

    Args:
        prompt: Query text; it is embedded by both the sparse and the dense
            text embedder.

    Returns:
        The documents produced by the pipeline's ``retriever`` component.
    """
    import time

    from haystack import Pipeline
    from haystack_integrations.components.embedders.fastembed import (
        FastembedSparseTextEmbedder,
        FastembedTextEmbedder,
    )
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever

    start_time = time.time()

    # Reuse the store opened by the last upload; after a restart, fall back
    # to opening whatever index is already on disk.
    document_store = _document_store
    if document_store is None:
        document_store = _open_document_store(recreate_index=False)

    # Querying
    querying = Pipeline()
    querying.add_component("sparse_text_embedder", FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1"))
    querying.add_component("dense_text_embedder", FastembedTextEmbedder(
        model="BAAI/bge-small-en-v1.5", prefix="Represent this sentence for searching relevant passages: ")
    )
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")

    # Search for the caller's prompt (previously a fixed sample question was
    # used and the parameter was ignored).
    results = querying.run(
        {"dense_text_embedder": {"text": prompt},
         "sparse_text_embedder": {"text": prompt}}
    )

    elapsed_time = time.time() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return results["retriever"]["documents"]
@app.get("/download-database/")
async def download_database():
    """Zip the ``database`` directory and return it as a download.

    Returns:
        A ``FileResponse`` serving ``database.zip`` from the current working
        directory.

    Raises:
        HTTPException: 404 when the ``database`` directory does not exist.
    """
    import time

    from fastapi import HTTPException

    start_time = time.time()

    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    if not os.path.isdir(database_dir):
        raise HTTPException(status_code=404, detail="Database directory not found")

    # make_archive appends ".zip" to the base name and returns the final
    # path, so the base name is passed directly instead of being derived by
    # stripping ".zip" from a string (which also altered any ".zip" that
    # appeared earlier in the working-directory path).
    zip_path = shutil.make_archive(database_dir, 'zip', database_dir)

    elapsed_time = time.time() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")

    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
@app.get("/")
def api_home():
    """Return the static welcome payload served at the API root."""
    welcome = dict(detail='Welcome to FastAPI Qdrant importer!')
    return welcome