Spaces:
Sleeping
Sleeping
File size: 4,984 Bytes
a151662 3f50d8b 4beb7b0 fe8dc94 d020550 4beb7b0 cea66bb d020550 4beb7b0 67cd4b7 4beb7b0 d020550 4beb7b0 3f50d8b 4beb7b0 3f50d8b 4beb7b0 c92562c 3f50d8b b2ba730 3f50d8b b2ba730 fe8dc94 4beb7b0 3f50d8b 67cd4b7 3f50d8b 67cd4b7 3f50d8b 67cd4b7 4beb7b0 67cd4b7 3f50d8b 67cd4b7 4beb7b0 67cd4b7 3f50d8b 67cd4b7 3f50d8b 67cd4b7 3f50d8b 22aa66a 67cd4b7 22aa66a 67cd4b7 22aa66a d020550 4beb7b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from datasets import load_dataset
from fastapi.middleware.cors import CORSMiddleware
# Loading
import os
import shutil
from os import makedirs,getcwd
from os.path import join,exists,dirname
import torch
import logging

# Application object; the route decorators further down register on it.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests. The FastAPI CORS docs advise
# listing origins explicitly when credentials are enabled - confirm this is
# intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

NUM_PROC = os.cpu_count()

# Uploaded files are saved in a "temp" directory located in the parent of the
# current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')
# exist_ok avoids the check-then-create race of `if not exists(): makedirs()`
# when several worker processes start at the same time.
makedirs(temp_path, exist_ok=True)

# Determine device based on GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Keep third-party logging at WARNING, but show INFO messages from haystack.
logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.INFO)
# Settings shared by the indexing and querying pipelines. Both endpoints must
# point at the same storage folder and use the same embedding models.
_QDRANT_PATH = "database"
_EMBEDDING_DIM = 384
_SPARSE_MODEL = "prithvida/Splade_PP_en_v1"
_DENSE_MODEL = "BAAI/bge-small-en-v1.5"
_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

# Document store handle shared by /uploadfile/ and /search. It is None until
# one of the two endpoints has opened the store.
_document_store = None


def _open_document_store(recreate_index):
    """Create the Qdrant document store and remember it as the shared handle.

    Args:
        recreate_index: Forwarded to ``QdrantDocumentStore``. The upload
            endpoint passes ``True`` (as the original code did); the search
            endpoint passes ``False`` so it does not ask for a rebuild.

    Returns:
        The newly created ``QdrantDocumentStore``.
    """
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    global _document_store
    # Drop the previous handle before opening a new one on the same folder.
    # NOTE(review): on-disk Qdrant storage presumably allows a single client
    # per folder - confirm the old client is released once dereferenced.
    _document_store = None
    _document_store = QdrantDocumentStore(
        path=_QDRANT_PATH,
        recreate_index=recreate_index,
        use_sparse_embeddings=True,
        embedding_dim=_EMBEDDING_DIM,
    )
    return _document_store


def _load_jsonl_documents(path, text_field):
    """Read a JSON Lines file and build one haystack ``Document`` per record.

    Blank lines are skipped. The value stored under ``text_field`` becomes the
    document content and the whole record becomes its metadata.

    Args:
        path: Location of the file on disk.
        text_field: Key of the record that holds the text to index.

    Returns:
        A list of ``Document`` objects, in file order.

    Raises:
        ValueError: If a line is not valid JSON, a record does not provide
            ``text_field``, or the file is not valid UTF-8
            (``UnicodeDecodeError`` is a ``ValueError`` subclass).
    """
    import json

    from haystack import Document

    documents = []
    with open(path, encoding="utf-8") as fd:
        for line_number, line in enumerate(fd, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
                content = record[text_field]
            except json.JSONDecodeError as err:
                raise ValueError(f"line {line_number}: invalid JSON ({err.msg})") from err
            except (KeyError, TypeError) as err:
                raise ValueError(
                    f"line {line_number}: record has no field {text_field!r}"
                ) from err
            documents.append(Document(content=content, meta=record))
    return documents


@app.post("/uploadfile/")
async def create_upload_file(text_field: str, file: UploadFile = File(...)):
    """Save an uploaded JSON Lines file and index its records in Qdrant.

    Each non-blank line of the file is parsed as a JSON object. The records
    are embedded with a sparse and a dense Fastembed model and written to the
    document store, which then becomes the store queried by ``/search``.

    Args:
        text_field: Key inside every JSON record that holds the text to index.
        file: The uploaded file; its name must contain ``.json``.

    Returns:
        A dict with the uploaded filename, the message ``"Done"`` and the
        elapsed time in seconds.

    Raises:
        HTTPException: 400 when the filename is empty or the content cannot
            be parsed, 415 when the file is not a JSON file.
    """
    # Imports
    import time

    from fastapi import HTTPException
    from haystack import Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder,
    )

    start_time = time.time()

    # Keep only the final path component: the filename is client-controlled
    # and must not be able to point outside temp_path.
    safe_name = os.path.basename(file.filename or "")
    if not safe_name:
        raise HTTPException(status_code=400, detail="The uploaded file has no filename")
    # Test the filename rather than the full save path, so that a ".json"
    # inside a directory name cannot make every upload look supported.
    if '.json' not in safe_name:
        raise HTTPException(status_code=415, detail="This feature is not supported yet")

    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    try:
        documents = _load_jsonl_documents(file_savePath, text_field)
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail=f"Cannot parse {safe_name}: {err}"
        ) from err

    # Indexing
    document_store = _open_document_store(recreate_index=True)
    indexing = Pipeline()
    indexing.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder(model=_SPARSE_MODEL))
    indexing.add_component("dense_doc_embedder", FastembedDocumentEmbedder(model=_DENSE_MODEL))
    indexing.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    indexing.run({"sparse_doc_embedder": {"documents": documents}})

    end_time = time.time()
    elapsed_time = end_time - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}


@app.get("/search")
def search(prompt: str):
    """Run a hybrid (sparse + dense) retrieval for ``prompt``.

    The query is embedded with the same sparse and dense models used for
    indexing and passed to a ``QdrantHybridRetriever``. When no upload has
    happened in this process yet, the store under the ``database`` folder is
    opened without requesting an index rebuild.

    Args:
        prompt: The text to search for.

    Returns:
        The ``documents`` output of the retriever component.
    """
    import time

    from haystack import Pipeline
    from haystack_integrations.components.embedders.fastembed import (
        FastembedSparseTextEmbedder,
        FastembedTextEmbedder,
    )
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever

    start_time = time.time()

    document_store = _document_store
    if document_store is None:
        document_store = _open_document_store(recreate_index=False)

    # Querying
    querying = Pipeline()
    querying.add_component("sparse_text_embedder", FastembedSparseTextEmbedder(model=_SPARSE_MODEL))
    querying.add_component(
        "dense_text_embedder",
        FastembedTextEmbedder(model=_DENSE_MODEL, prefix=_QUERY_PREFIX),
    )
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")

    # Search for the caller's prompt (the original queried a fixed sentence).
    results = querying.run(
        {
            "dense_text_embedder": {"text": prompt},
            "sparse_text_embedder": {"text": prompt},
        }
    )

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return results["retriever"]["documents"]
@app.get("/download-database/")
async def download_database():
    """Zip the ``database`` directory and return it as a file download.

    The archive is written to ``database.zip`` in the current working
    directory and is overwritten on every call.

    Returns:
        A ``FileResponse`` serving the archive as ``database.zip``.

    Raises:
        HTTPException: 404 when the ``database`` directory does not exist.
    """
    import time

    from fastapi import HTTPException

    start_time = time.time()

    # Path to the database directory
    database_dir = join(os.getcwd(), 'database')
    if not os.path.isdir(database_dir):
        raise HTTPException(status_code=404, detail="No database has been created yet")

    # make_archive appends ".zip" to the base name and returns the final path.
    # Using that return value avoids str.replace('.zip', ''), which would also
    # strip ".zip" from directory names in the working directory path.
    archive_base = join(os.getcwd(), 'database')
    zip_path = shutil.make_archive(archive_base, 'zip', database_dir)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")

    # Return the zip file as a response for download
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
@app.get("/")
def api_home():
    """Return the static greeting served at the API root."""
    greeting = 'Welcome to FastAPI Qdrant importer!'
    return dict(detail=greeting)
|