import re
from io import BytesIO
from typing import List, Dict, Any, Union, Text, Tuple, Iterable

import docx2txt
import streamlit as st
from pypdf import PdfReader
from openai.error import AuthenticationError

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS, VectorStore

from .prompts import STUFF_PROMPT

class PDFFile:
    """A PDF file class for typing purposes."""
    @staticmethod
    def is_pdf(file: Any) -> bool:
        return file.name.endswith(".pdf")

class DocxFile:
    """A Docx file class for typing purposes."""
    @staticmethod
    def is_docx(file: Any) -> bool:
        return file.name.endswith(".docx")

class TxtFile:
    """A Txt file class for typing purposes."""
    @staticmethod
    def is_txt(file: Any) -> bool:
        return file.name.endswith(".txt")

class CodeFile:
    """A scripting-file class for typing purposes."""
    @staticmethod
    def is_code(file: Any) -> bool:
        # Compare the extension without the leading dot, e.g. "script.py" -> "py".
        return file.name.split(".")[-1] in ["py", "json", "html", "css", "md"]

class HashDocument(Document):
    """A document that uses the page content and metadata as the hash."""
    def __hash__(self):
        # Metadata values may be ints (e.g. page numbers), so cast to str before joining.
        content = self.page_content + "".join(str(self.metadata[k]) for k in self.metadata.keys())
        return hash(content)

def parse_docx(file: BytesIO) -> str:
    text = docx2txt.process(file)
    # Remove multiple newlines
    text = re.sub(r"\n\s*\n", "\n\n", text)
    return text

def parse_pdf(file: BytesIO) -> List[str]:
    pdf = PdfReader(file)
    output = []
    for page in pdf.pages:
        text = page.extract_text()
        # Merge hyphenated words
        text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
        # Fix newlines in the middle of sentences
        text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
        # Remove multiple newlines
        text = re.sub(r"\n\s*\n", "\n\n", text)
        output.append(text)
    return output

def parse_txt(file: BytesIO) -> str:
    text = file.read().decode("utf-8")
    # Remove multiple newlines
    text = re.sub(r"\n\s*\n", "\n\n", text)
    return text

def get_text_splitter(
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        separators: Iterable[Text] = ("\n\n", "\n", ".", "!", "?", ",", " ", "")) -> RecursiveCharacterTextSplitter:
    """Returns a text splitter instance with the given parameters."""
    # text splitter to split the text into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,  # a limited chunk size ensures smaller chunks and more precise answers
        separators=list(separators),  # a list of separators to split the text on
        chunk_overlap=chunk_overlap,  # minimal overlap to capture semantic overlap across chunks
    )
    return text_splitter

def text_to_docs(text: Union[Text, Tuple[Text, ...]]) -> List[Document]:
    """
    Converts a string or tuple of strings to a list of Documents
    with metadata.
    """
    # sanity check on the input provided
    if not isinstance(text, (str, tuple)):
        raise ValueError(f"Text must be either a string or a tuple of strings. Got: {type(text)}")
    elif isinstance(text, str):
        # Take a single string as one page - make it a tuple so that it is hashable
        text = (text, )
    # map each page into a document instance
    page_docs = [HashDocument(page_content=page) for page in text]
    # Add page numbers as metadata
    for i, doc in enumerate(page_docs):
        doc.metadata["page"] = i + 1
    # Split pages into chunks
    doc_chunks = []
    # Get the text splitter
    text_splitter = get_text_splitter()
    for doc in page_docs:
        # this splits the page into chunks
        chunks = text_splitter.split_text(doc.page_content)
        for i, chunk in enumerate(chunks):
            # Create a new document for each individual chunk
            chunk_doc = HashDocument(
                page_content=chunk, metadata={"page": doc.metadata["page"], "chunk": i}
            )
            # Add sources to metadata for retrieval later on
            chunk_doc.metadata["source"] = f"{chunk_doc.metadata['page']}-{chunk_doc.metadata['chunk']}"
            doc_chunks.append(chunk_doc)
    return doc_chunks

def embed_docs(_docs: Tuple[Document, ...]) -> VectorStore:
    """Embeds a list of Documents and returns a FAISS index"""
    # Embed the chunks using the OpenAI embeddings model
    embeddings = OpenAIEmbeddings(openai_api_key=st.session_state.get("OPENAI_API_KEY"))
    index = FAISS.from_documents(list(_docs), embeddings)
    return index

def search_docs(_index: VectorStore, query: str, k: int = 5) -> List[Document]:
    """Searches a FAISS index for similar chunks to the query
    and returns a list of Documents."""
    # Search for similar chunks
    docs = _index.similarity_search(query, k=k)
    return docs

def get_answer(_docs: List[Document], query: str) -> Dict[str, Any]:
    """Gets an answer to a question from a list of Documents."""
    # Build a "stuff" chain that packs all retrieved chunks into a single prompt
    chain = load_qa_with_sources_chain(
        OpenAI(temperature=0,
               openai_api_key=st.session_state.get("OPENAI_API_KEY")),
        chain_type="stuff",
        prompt=STUFF_PROMPT,
    )
    # also returning the text of the sources used to form the answer
    answer = chain(
        {"input_documents": _docs, "question": query}
    )
    return answer

def get_sources(answer: Dict[str, Any], docs: List[Document]) -> List[Document]:
    """Gets the source documents for an answer."""
    # Extract the "page-chunk" source keys listed after "SOURCES: " in the model output
    source_keys = [s for s in answer["output_text"].split("SOURCES: ")[-1].split(", ")]
    source_docs = []
    for doc in docs:
        if doc.metadata["source"] in source_keys:
            source_docs.append(doc)
    return source_docs

def wrap_text_in_html(text: Union[str, List[str]]) -> str:
    """Wraps each text block separated by newlines in <p> tags"""
    if isinstance(text, list):
        # Add horizontal rules between pages
        text = "\n<hr/>\n".join(text)
    return "".join([f"<p>{line}</p>" for line in text.split("\n")])