File size: 3,399 Bytes
9177215 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from io import BytesIO
from typing import List, Any, Optional
import re
import docx2txt
from langchain.docstore.document import Document
import fitz
from hashlib import md5
from abc import abstractmethod, ABC
from copy import deepcopy
class File(ABC):
    """Represents an uploaded file comprised of Documents.

    Concrete subclasses implement ``from_bytes`` to parse one file format
    into a list of ``Document`` objects.

    Attributes:
        name: Name of the uploaded file, as given to the constructor.
        id: Identifier of the file, as given to the constructor.
        metadata: Arbitrary metadata about the file.
        docs: The Documents extracted from the file.
    """

    def __init__(
        self,
        name: str,
        id: str,
        metadata: Optional[dict[str, Any]] = None,
        docs: Optional[List[Document]] = None,
    ):
        """Store the file attributes.

        A falsy ``metadata`` or ``docs`` argument (``None`` or empty) is
        replaced by a new empty dict / list, so instances never share a
        mutable default.
        """
        self.name = name
        self.id = id
        self.metadata = metadata or {}
        self.docs = docs or []

    @classmethod
    @abstractmethod
    def from_bytes(cls, file: BytesIO) -> "File":
        """Creates a File from a BytesIO object"""

    def __repr__(self) -> str:
        """Return a debug representation including metadata and docs."""
        # Both fragments must be f-strings: implicit concatenation does not
        # propagate the ``f`` prefix, and a plain second fragment would emit
        # the literal text "{self.metadata}" instead of the value.
        return (
            f"File(name={self.name}, id={self.id},"
            f" metadata={self.metadata}, docs={self.docs})"
        )

    def __str__(self) -> str:
        """Return a short representation that omits the docs."""
        return f"File(name={self.name}, id={self.id}, metadata={self.metadata})"

    def copy(self) -> "File":
        """Create a deep copy of this File.

        The copy is an instance of the same class as ``self``; ``metadata``
        and ``docs`` are deep-copied so changes to the copy do not affect
        the original.
        """
        return self.__class__(
            name=self.name,
            id=self.id,
            metadata=deepcopy(self.metadata),
            docs=deepcopy(self.docs),
        )
def strip_consecutive_newlines(text: str) -> str:
    """Collapse every newline run into a single newline.

    Each match of a newline together with all whitespace directly before
    and after it (spaces, tabs, carriage returns, further newlines) is
    replaced by exactly one newline character. Whitespace that does not
    touch a newline is left as is.
    """
    newline_run = re.compile(r"\s*\n\s*")
    return newline_run.sub("\n", text)
class DocxFile(File):
    """A File whose single Document holds the text of a .docx upload."""

    @classmethod
    def from_bytes(cls, file: BytesIO) -> "DocxFile":
        """Create a DocxFile from an in-memory .docx stream.

        The text is extracted with ``docx2txt``, newline runs are collapsed,
        and the result is stored as one Document whose ``source`` metadata
        is ``"p-1"``. The file id is the MD5 hex digest of the stream
        content.

        Args:
            file: Stream holding the .docx data. It must also expose a
                ``name`` attribute, which a plain ``BytesIO`` does not have.

        Returns:
            A DocxFile named after ``file.name`` containing one Document.
        """
        text = docx2txt.process(file)
        text = strip_consecutive_newlines(text)
        doc = Document(page_content=text.strip())
        doc.metadata["source"] = "p-1"
        # Extraction may have moved the stream position; rewind so the
        # digest covers the whole file rather than whatever is left to read.
        file.seek(0)
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])
class PdfFile(File):
    """A File with one Document per page of a PDF upload."""

    @classmethod
    def from_bytes(cls, file: BytesIO) -> "PdfFile":
        """Create a PdfFile from an in-memory PDF stream.

        Every page becomes one Document holding the page text with newline
        runs collapsed. Each Document records its 1-based page number under
        ``page`` and ``"p-<page number>"`` under ``source``. The file id is
        the MD5 hex digest of the stream content.

        Args:
            file: Stream holding the PDF data. It must also expose a
                ``name`` attribute, which a plain ``BytesIO`` does not have.

        Returns:
            A PdfFile named after ``file.name`` with one Document per page.
        """
        pdf = fitz.open(stream=file.read(), filetype="pdf")  # type: ignore
        try:
            docs = []
            for page_number, page in enumerate(pdf, start=1):
                text = page.get_text(sort=True)
                text = strip_consecutive_newlines(text)
                doc = Document(page_content=text.strip())
                doc.metadata["page"] = page_number
                doc.metadata["source"] = f"p-{page_number}"
                docs.append(doc)
        finally:
            # Only plain strings were taken from the pages, so the document
            # can be released as soon as extraction ends, even on error.
            pdf.close()
        # file.read() mutates the file object, which can affect caching
        # so we need to reset the file pointer to the beginning
        file.seek(0)
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=docs)
class TxtFile(File):
    """A File whose single Document holds the text of a plain-text upload."""

    @classmethod
    def from_bytes(cls, file: BytesIO) -> "TxtFile":
        """Create a TxtFile from an in-memory text stream.

        The bytes are decoded as UTF-8 (undecodable bytes are replaced),
        newline runs are collapsed, and the stripped text is stored as one
        Document whose ``source`` metadata is ``"p-1"``. The file id is the
        MD5 hex digest of the stream content.
        """
        raw_bytes = file.read()
        decoded = raw_bytes.decode("utf-8", errors="replace")
        cleaned = strip_consecutive_newlines(decoded)
        # Rewind so the digest below is computed over the stream from the start.
        file.seek(0)
        document = Document(page_content=cleaned.strip())
        document.metadata["source"] = "p-1"
        upload_name = file.name
        digest = md5(file.read()).hexdigest()
        return cls(name=upload_name, id=digest, docs=[document])
def read_file(file: BytesIO) -> File:
    """Reads an uploaded file and returns a File object.

    The loader is picked from the case-insensitive suffix of ``file.name``:
    ``.docx``, ``.pdf`` or ``.txt``.

    Raises:
        NotImplementedError: If the name ends with none of those suffixes.
    """
    lowered_name = file.name.lower()
    loaders = (
        (".docx", DocxFile),
        (".pdf", PdfFile),
        (".txt", TxtFile),
    )
    for suffix, loader in loaders:
        if lowered_name.endswith(suffix):
            return loader.from_bytes(file)
    raise NotImplementedError(f"File type {file.name.split('.')[-1]} not supported")
|