Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, UploadFile, File | |
from fastapi.responses import FileResponse | |
from datasets import load_dataset | |
from fastapi.middleware.cors import CORSMiddleware | |
import pdfplumber | |
import pytesseract | |
from models import Article, Chapter, Law | |
# Loading | |
import os | |
import zipfile | |
import shutil | |
from os import makedirs,getcwd | |
from os.path import join,exists,dirname | |
import torch | |
import json | |
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore | |
# Options for the CORS middleware: every origin, method and header is accepted, and
# credentialed (cookie / auth header) cross-origin requests are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is a very permissive
# policy (browsers do not honour a literal "*" origin for credentialed requests); confirm
# this is intended rather than an explicit list of frontend origins.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
# Number of CPUs reported by the OS (os.cpu_count() returns None if it cannot be determined).
NUM_PROC = os.cpu_count()
# Scratch directory for uploaded and extracted files: "<parent of the working dir>/temp".
parent_path = dirname(getcwd())
temp_path = join(parent_path, "temp")
# exist_ok=True replaces the `if not exists(...): makedirs(...)` check, which races when two
# processes start at the same time (the loser would crash with FileExistsError).
makedirs(temp_path, exist_ok=True)
# Pick the compute device: "cuda" when torch reports an available GPU, otherwise "cpu".
# NOTE(review): `device` is not passed to any component in the visible code — confirm it is used.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
import logging

# Root logging: WARNING and above, formatted as "LEVEL - logger name - message".
logging.basicConfig(
    level=logging.WARNING,
    format="%(levelname)s - %(name)s - %(message)s",
)
# Lower the threshold of the "haystack" logger hierarchy to INFO, more verbose than the
# WARNING level configured for the root logger above.
_haystack_logger = logging.getLogger("haystack")
_haystack_logger.setLevel(logging.INFO)
# Qdrant document store backed by the local path "database", configured for both dense
# (embedding_dim-sized) vectors and sparse embeddings.
# NOTE(review): recreate_index=True presumably drops any existing collection whenever this
# module is imported, so indexed documents would not survive a restart — confirm intended.
# NOTE(review): 384 must match the dense embedder's output size (presumably the MiniLM model
# used for indexing) — confirm if the model is ever changed.
document_store = QdrantDocumentStore(
    embedding_dim=384,
    use_sparse_embeddings=True,
    recreate_index=True,
    path="database",
)
def extract_zip(zip_path, target_folder):
    """
    Extract every member of a ZIP archive into a folder and return the extracted paths.

    Each member is extracted with ``ZipFile.extract``. That call strips absolute-path
    prefixes and ``..`` components from member names and returns the path it actually
    wrote. The returned list therefore always names real files (or directories) inside
    ``target_folder``. Joining the raw archive names instead would, for such members,
    point at paths that were never written and could lie outside the folder.

    Args:
        zip_path (str): Path to the ZIP file.
        target_folder (str): Folder where the files will be extracted (parent folders are
            created as needed).
    Returns:
        List[str]: Extracted paths in archive order (directory entries included).
    """
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        return [zip_ref.extract(member, target_folder) for member in zip_ref.infolist()]
def extract_text_from_pdf(pdf_path):
    """
    Extract the embedded text layer of every page of a PDF with pdfplumber.

    Args:
        pdf_path (str): Path to the PDF file.
    Returns:
        str: Text of all pages, joined with a newline between consecutive pages. A page
        with no extractable text contributes an empty string.
    """
    with pdfplumber.open(pdf_path) as pdf:
        # `extract_text()` may return None for a page without a text layer; plain `+=`
        # concatenation would then raise TypeError, so such pages count as "".
        # A newline between pages keeps the last word of one page from being glued to
        # the first word of the next.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
def extract_ocr_text_from_pdf(pdf_path, lang="vie"):
    """
    OCR every page of a PDF and return the concatenated text.

    Each page is rasterised with pdf2image's ``convert_from_path`` and passed to Tesseract
    through ``pytesseract.image_to_string``.

    Args:
        pdf_path (str): Path to the PDF file.
        lang (str): Tesseract language code(s). Defaults to "vie", the value previously
            hard-coded here, so existing callers behave exactly as before.
    Returns:
        str: OCR text of all pages, concatenated in page order with no extra separator.
    """
    # Local import: pdf2image is only needed when OCR is requested.
    from pdf2image import convert_from_path

    images = convert_from_path(pdf_path)
    return "".join(pytesseract.image_to_string(image, lang=lang) for image in images)
async def create_upload_file(text_field: str, file: UploadFile = File(...), ocr: bool = False):
    """
    Save an uploaded file, turn it into Haystack documents and index them into `document_store`.

    Supported uploads (by file-name extension, case-insensitive):
      * ``.json`` / ``.jsonl``: read as JSON Lines, one JSON object per non-blank line.
        ``obj[text_field]`` becomes the document content and the whole object its metadata.
      * ``.zip``: extracted into ``temp_path``. Each ``.pdf`` inside becomes one document,
        using its text layer, or Tesseract OCR when ``ocr`` is true. Metadata is
        ``{text_field: <text>, "file_path": <extracted path>}``.

    Documents then go through join -> clean -> split (1000-word chunks, no overlap) ->
    sparse embedding -> dense embedding -> write. Existing ids are overwritten.

    Args:
        text_field: Key holding the document text (in each JSON object; also the content key
            in PDF metadata).
        file: The uploaded file.
        ocr: OCR PDFs instead of reading their embedded text.
    Returns:
        dict: ``filename`` (as uploaded), ``message`` ("Done") and ``execution_time`` (seconds).
    Raises:
        NotImplementedError: For any other file type.
        KeyError: If a JSON object lacks ``text_field``.
    """
    import time
    from haystack import Document, Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.components.preprocessors import DocumentSplitter, DocumentCleaner
    from haystack.components.joiners import DocumentJoiner
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder,
    )

    start_time = time.perf_counter()

    # Keep only the final path component so a crafted filename such as "../../x.json"
    # cannot write outside temp_path.
    safe_name = os.path.basename(file.filename)
    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, "wb") as f:
        shutil.copyfileobj(file.file, f)

    # Dispatch on the file name's extension only (not on substrings of the whole path).
    lowered_name = safe_name.lower()
    documents = []
    if lowered_name.endswith((".json", ".jsonl")):
        with open(file_savePath, encoding="utf-8") as fd:
            for line in fd:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                obj = json.loads(line)
                documents.append(Document(content=obj[text_field], meta=obj))
    elif lowered_name.endswith(".zip"):
        extracted_files_list = extract_zip(file_savePath, temp_path)
        print("Extracted files:")
        for file_path in extracted_files_list:
            print(file_path)
            if not file_path.lower().endswith(".pdf"):
                continue
            if ocr:
                text = extract_ocr_text_from_pdf(file_path)
            else:
                text = extract_text_from_pdf(file_path)
            # The path is stored under the literal key "file_path"; previously the path
            # string itself was used as the key.
            obj = {text_field: text, "file_path": file_path}
            documents.append(Document(content=text, meta=obj))
    else:
        raise NotImplementedError("This feature is not supported yet")

    # Indexing pipeline.
    indexing = Pipeline()
    indexing.add_component("document_joiner", DocumentJoiner())
    indexing.add_component("document_cleaner", DocumentCleaner())
    indexing.add_component(
        "document_splitter",
        DocumentSplitter(split_by="word", split_length=1000, split_overlap=0),
    )
    indexing.add_component(
        "sparse_doc_embedder",
        FastembedSparseDocumentEmbedder(model="Qdrant/bm42-all-minilm-l6-v2-attentions"),
    )
    indexing.add_component(
        "dense_doc_embedder",
        FastembedDocumentEmbedder(model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"),
    )
    indexing.add_component(
        "writer",
        DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE),
    )
    indexing.connect("document_joiner", "document_cleaner")
    indexing.connect("document_cleaner", "document_splitter")
    indexing.connect("document_splitter", "sparse_doc_embedder")
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    indexing.run({"document_joiner": {"documents": documents}})

    elapsed_time = time.perf_counter() - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
def search(prompt: str):
    """
    Answer `prompt` with hybrid retrieval over `document_store` followed by an LLM call.

    The prompt is embedded twice: once with a sparse model (BM42) and once with a dense
    model (multilingual MiniLM). A metadata extractor turns the prompt into retriever
    filters. QdrantHybridRetriever fetches candidates; these are joined, re-ranked against
    the prompt and rendered into a prompt template. The rendered prompt is sent to an
    OpenAI-compatible endpoint (OctoAI), using the API key from the OCTOAI_TOKEN env var.

    Args:
        prompt: The user's question.
    Returns:
        The result of ``querying.run(...)`` for the assembled pipeline.
    """
    import time
    from haystack import Pipeline
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
    from haystack_integrations.components.embedders.fastembed import (
        FastembedTextEmbedder,
        FastembedSparseTextEmbedder,
    )
    from haystack.components.rankers import TransformersSimilarityRanker
    from haystack.components.joiners import DocumentJoiner
    from haystack.components.generators import OpenAIGenerator
    from haystack.utils import Secret
    from haystack.components.builders import PromptBuilder

    start_time = time.perf_counter()

    # Querying
    template = """
Given the following information, answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{question}}
Answer:
"""
    prompt_builder = PromptBuilder(template=template)
    generator = OpenAIGenerator(
        api_key=Secret.from_env_var("OCTOAI_TOKEN"),
        api_base_url="https://text.octoai.run/v1",
        model="mixtral-8x22b-finetuned",
        generation_kwargs={"max_tokens": 512},
    )
    # NOTE(review): QueryMetadataExtractor is neither imported nor defined in the visible part
    # of this file; search() raises NameError unless it is defined elsewhere — confirm.
    metadata_extractor = QueryMetadataExtractor()

    querying = Pipeline()
    querying.add_component(
        "sparse_text_embedder",
        FastembedSparseTextEmbedder(model="Qdrant/bm42-all-minilm-l6-v2-attentions"),
    )
    # NOTE(review): this instruction prefix is applied to queries only (documents are embedded
    # without it in create_upload_file) and looks like a BGE-style query instruction; confirm
    # it suits paraphrase-multilingual-MiniLM-L12-v2.
    querying.add_component(
        "dense_text_embedder",
        FastembedTextEmbedder(
            model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            prefix="Represent this sentence for searching relevant passages: ",
        ),
    )
    querying.add_component(instance=metadata_extractor, name="metadata_extractor")
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.add_component("document_joiner", DocumentJoiner())
    # NOTE(review): TransformersSimilarityRanker expects a cross-encoder style model;
    # "BAAI/bge-m3" appears to be an embedding model (a bge-reranker may be intended) — verify.
    querying.add_component("ranker", TransformersSimilarityRanker(model="BAAI/bge-m3"))
    querying.add_component("prompt_builder", prompt_builder)
    querying.add_component("llm", generator)

    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")
    querying.connect("metadata_extractor.filters", "retriever.filters")
    querying.connect("retriever", "document_joiner")
    querying.connect("document_joiner", "ranker")
    querying.connect("ranker.documents", "prompt_builder.documents")
    querying.connect("prompt_builder", "llm")
    # NOTE(review): plain attribute assignment on the pipeline; confirm the installed Haystack
    # version actually reads a `debug` attribute.
    querying.debug = True

    metadata_fields = {"publish_date", "publisher", "document_type"}
    results = querying.run(
        {
            "dense_text_embedder": {"text": prompt},
            "sparse_text_embedder": {"text": prompt},
            "metadata_extractor": {"query": prompt, "metadata_fields": metadata_fields},
            "ranker": {"query": prompt},
            "prompt_builder": {"question": prompt},
        }
    )

    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return results
async def download_database():
    """
    Zip the ``<cwd>/database`` directory and return it as a ``database.zip`` download.

    The archive is written to ``<cwd>/database.zip``, replacing any previous one, and
    served with ``FileResponse``.

    NOTE(review): the zipping is synchronous inside an async endpoint, so it blocks the
    event loop while it runs. If the Qdrant store is being written concurrently, the
    archive may capture an inconsistent snapshot — confirm this is acceptable.
    """
    import time

    start_time = time.perf_counter()
    # Path to the database directory (the QdrantDocumentStore uses the relative path "database").
    database_dir = join(os.getcwd(), "database")
    # make_archive appends ".zip" to the base name itself and returns the archive path.
    # Using that return value avoids `zip_path.replace(".zip", "")`, which also removed
    # ".zip" wherever it occurred in the working-directory path.
    archive_base = join(os.getcwd(), "database")
    zip_path = shutil.make_archive(archive_base, "zip", database_dir)

    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Return the zip file as a response for download.
    return FileResponse(zip_path, media_type="application/zip", filename="database.zip")
async def convert_upload_file(file: UploadFile = File(...)):
    """
    OCR an uploaded PDF and ask the LLM for structured ``Law`` metadata about it.

    The upload is saved into ``temp_path``. Every page is rasterised with pdf2image and
    OCR'd with Tesseract (``lang="vie"``). The first page with non-empty OCR text is then
    sent to the OctoAI chat endpoint, with a ``json_object`` response format constrained
    by ``Law.model_json_schema()``.

    Args:
        file: The uploaded PDF.
    Returns:
        dict: ``content``, the OCR text of all pages with each page followed by a newline;
        and ``metadate``, the model's reply content.
    """
    from pdf2image import convert_from_path
    from octoai.client import OctoAI
    from octoai.text_gen import ChatCompletionResponseFormat, ChatMessage

    # Keep only the final path component so a crafted filename (e.g. "../../x.pdf")
    # cannot write outside temp_path.
    file_savePath = join(temp_path, os.path.basename(file.filename))
    with open(file_savePath, "wb") as f:
        shutil.copyfileobj(file.file, f)

    # Convert the PDF to images and OCR each page.
    page_texts = [
        pytesseract.image_to_string(image, lang="vie")
        for image in convert_from_path(file_savePath)
    ]
    text = "".join(page_text + "\n" for page_text in page_texts)
    # First page whose OCR text is non-empty ("" when there is none).
    first_page = next((page_text for page_text in page_texts if page_text), "")

    # NOTE(review): only the raw first-page text is sent as the user message, with no
    # extraction instruction — confirm this is intended.
    client = OctoAI()
    completion = client.text_gen.create_chat_completion(
        model="mixtral-8x22b-finetuned",
        messages=[
            ChatMessage(role="system", content="You are a helpful assistant."),
            ChatMessage(role="user", content=first_page),
        ],
        presence_penalty=0,
        temperature=0.1,
        top_p=0.9,
        response_format=ChatCompletionResponseFormat(
            type="json_object",
            schema=Law.model_json_schema(),
        ),
    )
    # NOTE(review): "metadate" looks like a typo for "metadata", but it is the key clients
    # receive, so it is kept unchanged.
    return {"content": text, "metadate": completion.choices[0].message.content}
def api_home():
    """Return the static welcome payload for the API root."""
    greeting = "Welcome to FastAPI Qdrant importer!"
    return {"detail": greeting}