Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, UploadFile, File | |
from fastapi.responses import FileResponse | |
from datasets import load_dataset | |
from fastapi.middleware.cors import CORSMiddleware | |
# Loading | |
import os | |
import shutil | |
from os import makedirs,getcwd | |
from os.path import join,exists,dirname | |
import torch | |
# The FastAPI application served by this module.
app = FastAPI()

# Let browsers call the API cross-origin from any host, with any HTTP method
# and any request header.
# NOTE(review): credentials are allowed together with a wildcard origin list;
# confirm that is intended for a publicly deployed service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Number of CPUs reported by the OS. os.cpu_count() can return None; this name
# is not used anywhere in the code visible here.
NUM_PROC = os.cpu_count()

# Scratch directory for uploaded files: "<parent of the working dir>/temp".
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok=True replaces a separate exists() check, which could race with
# another worker creating the directory in between and raise FileExistsError.
makedirs(temp_path, exist_ok=True)
import logging

# Pick the compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Root logging stays at WARNING; haystack's own logger is raised to INFO so its
# progress messages remain visible.
logging.basicConfig(
    level=logging.WARNING,
    format="%(levelname)s - %(name)s - %(message)s",
)
logging.getLogger("haystack").setLevel(logging.INFO)
async def create_upload_file(text_field: str, file: UploadFile = File(...)):
    """Index an uploaded JSON Lines file into the local Qdrant document store.

    The upload is saved under ``temp_path``. Each non-blank line is parsed as a
    JSON object and becomes a haystack ``Document`` whose content is
    ``obj[text_field]`` and whose metadata is the whole object. The documents
    are embedded with a sparse model (``prithvida/Splade_PP_en_v1``) and a
    dense model (``BAAI/bge-small-en-v1.5``). They are then written to a Qdrant
    collection persisted at ``./database``, which is recreated on every call.

    NOTE(review): no route decorator is visible in this copy of the file;
    confirm how this handler is registered on ``app``.
    NOTE(review): all the work below (disk I/O, model inference) is blocking
    inside an ``async def`` and so holds the event loop for the whole request.
    Consider a plain ``def`` handler or ``run_in_threadpool``.

    Args:
        text_field: Key in each JSON object whose value becomes the document text.
        file: The uploaded file; must have a ``.json`` or ``.jsonl`` extension.

    Returns:
        dict with the client-supplied ``filename``, the message ``"Done"`` and
        the elapsed ``execution_time`` in seconds.

    Raises:
        ValueError: if the upload carries no usable file name.
        NotImplementedError: if the upload is not a ``.json``/``.jsonl`` file.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``text_field`` key.
    """
    # Imports (function-local, as elsewhere in this module)
    import json
    import time
    from haystack import Document, Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder,
    )

    start_time = time.perf_counter()

    # The file name comes from the client. Keep only its last path component so
    # names like "../../x" or "/etc/x" cannot write outside temp_path.
    safe_name = os.path.basename(file.filename or "")
    if not safe_name:
        raise ValueError("Uploaded file has no usable filename")
    # Check the file's own extension rather than a substring of the whole
    # path, which would also match a parent directory or a name like
    # "x.json.exe".
    if os.path.splitext(safe_name)[1].lower() not in (".json", ".jsonl"):
        raise NotImplementedError("This feature is not supported yet")

    save_path = join(temp_path, safe_name)
    with open(save_path, "wb") as f:
        shutil.copyfileobj(file.file, f)

    # Parse as JSON Lines: one JSON object per line. Blank lines (e.g. a
    # trailing newline at EOF) are skipped instead of crashing json.loads.
    documents = []
    with open(save_path, encoding="utf-8") as fd:
        for line in fd:
            if not line.strip():
                continue
            obj = json.loads(line)
            documents.append(Document(content=obj[text_field], meta=obj))

    # Indexing
    document_store = QdrantDocumentStore(
        path="database",
        recreate_index=True,  # every upload replaces the previous index
        use_sparse_embeddings=True,
        embedding_dim=384,  # must match the dense embedder's output size
    )
    indexing = Pipeline()
    indexing.add_component(
        "sparse_doc_embedder",
        FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1"),
    )
    indexing.add_component(
        "dense_doc_embedder",
        FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5"),
    )
    indexing.add_component(
        "writer",
        DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE),
    )
    # Documents pass through both embedders in turn, then to the writer.
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    indexing.run({"sparse_doc_embedder": {"documents": documents}})

    elapsed_time = time.perf_counter() - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
def search(prompt: str):
    """Run a hybrid (sparse + dense) query for ``prompt`` against the Qdrant store.

    The store persisted at ``./database`` is opened without recreating it.
    ``prompt`` is then embedded with the same sparse model
    (``prithvida/Splade_PP_en_v1``) and dense model (``BAAI/bge-small-en-v1.5``)
    used for indexing, and both embeddings go to ``QdrantHybridRetriever``.

    NOTE(review): Qdrant's local (path-based) mode may refuse a second client
    on the same folder while another one is still open in this process.
    Confirm, or share a single store instance between handlers.

    Args:
        prompt: Free-text query.

    Returns:
        The list of documents produced by the retriever.
    """
    import time
    from haystack import Pipeline
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
    from haystack_integrations.components.embedders.fastembed import (
        FastembedSparseTextEmbedder,
        FastembedTextEmbedder,
    )

    start_time = time.perf_counter()

    # The pipeline classes and the store are not module-level names, so they
    # are created here. recreate_index=False ensures a search never wipes the
    # indexed data. The other settings must match the ones used for indexing.
    document_store = QdrantDocumentStore(
        path="database",
        recreate_index=False,
        use_sparse_embeddings=True,
        embedding_dim=384,
    )

    # Querying
    querying = Pipeline()
    querying.add_component(
        "sparse_text_embedder",
        FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1"),
    )
    querying.add_component(
        "dense_text_embedder",
        FastembedTextEmbedder(
            model="BAAI/bge-small-en-v1.5",
            prefix="Represent this sentence for searching relevant passages: ",
        ),
    )
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")

    # Query with the caller's prompt, not a hard-coded sample question.
    results = querying.run(
        {
            "dense_text_embedder": {"text": prompt},
            "sparse_text_embedder": {"text": prompt},
        }
    )

    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return results["retriever"]["documents"]
async def download_database():
    """Zip the ``database`` directory under the working directory for download.

    The archive is written to ``<cwd>/database.zip``, overwriting any previous
    one, and returned as a file response.

    NOTE(review): archiving is blocking work inside an ``async def``.

    Returns:
        FileResponse serving the archive as ``database.zip`` (``application/zip``).

    Raises:
        HTTPException: 404 if there is no ``database`` directory to archive.
    """
    import time
    from fastapi import HTTPException

    start_time = time.perf_counter()
    cwd = os.getcwd()
    # Path to the database directory
    database_dir = join(cwd, 'database')
    if not os.path.isdir(database_dir):
        raise HTTPException(status_code=404, detail="No database directory to download")
    # make_archive appends ".zip" to base_name and returns the archive's path.
    # Using that return value avoids str.replace('.zip', ''), which would also
    # strip ".zip" from any directory name in the absolute path.
    zip_path = shutil.make_archive(join(cwd, 'database'), 'zip', database_dir)
    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
def api_home():
    """Return the static welcome payload for the service root."""
    greeting = 'Welcome to FastAPI Qdrant importer!'
    return dict(detail=greeting)