add recursive text splitter
Browse files- BuildingAChainlitApp.md +3 -4
- aimakerspace/text_utils.py +2 -4
- app.py +15 -12
- requirements.txt +2 -1
- richard/text_utils.py +62 -3
BuildingAChainlitApp.md
CHANGED
@@ -123,7 +123,7 @@ def process_text_file(file: AskFileResponse):
|
|
123 |
|
124 |
text_loader = TextFileLoader(temp_file_path)
|
125 |
documents = text_loader.load_documents()
|
126 |
-
texts = text_splitter.
|
127 |
return texts
|
128 |
```
|
129 |
|
@@ -286,7 +286,7 @@ Code was modified to support pdf documents in the following areas:
|
|
286 |
raise ValueError(
|
287 |
f"Unsupported file type: {self.temp_file_path}"
|
288 |
)
|
289 |
-
return text_splitter.
|
290 |
else:
|
291 |
raise ValueError(
|
292 |
"Not a file"
|
@@ -297,9 +297,8 @@ Code was modified to support pdf documents in the following areas:
|
|
297 |
self.documents.append(f.read())
|
298 |
|
299 |
def load_pdf_file(self):
|
300 |
-
|
301 |
pdf_document = fitz.open(self.temp_file_path)
|
302 |
-
print(len(pdf_document))
|
303 |
for page_num in range(len(pdf_document)):
|
304 |
page = pdf_document.load_page(page_num)
|
305 |
text = page.get_text()
|
|
|
123 |
|
124 |
text_loader = TextFileLoader(temp_file_path)
|
125 |
documents = text_loader.load_documents()
|
126 |
+
texts = text_splitter.split_text(documents)
|
127 |
return texts
|
128 |
```
|
129 |
|
|
|
286 |
raise ValueError(
|
287 |
f"Unsupported file type: {self.temp_file_path}"
|
288 |
)
|
289 |
+
return text_splitter.split_text(self.documents)
|
290 |
else:
|
291 |
raise ValueError(
|
292 |
"Not a file"
|
|
|
297 |
self.documents.append(f.read())
|
298 |
|
299 |
def load_pdf_file(self):
|
300 |
+
|
301 |
pdf_document = fitz.open(self.temp_file_path)
|
|
|
302 |
for page_num in range(len(pdf_document)):
|
303 |
page = pdf_document.load_page(page_num)
|
304 |
text = page.get_text()
|
aimakerspace/text_utils.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import os
|
2 |
from typing import List
|
3 |
|
4 |
-
|
5 |
class TextFileLoader:
|
6 |
def __init__(self, path: str, encoding: str = "utf-8"):
|
7 |
self.documents = []
|
@@ -55,18 +54,17 @@ class CharacterTextSplitter:
|
|
55 |
chunks.append(text[i : i + self.chunk_size])
|
56 |
return chunks
|
57 |
|
58 |
-
def
|
59 |
chunks = []
|
60 |
for text in texts:
|
61 |
chunks.extend(self.split(text))
|
62 |
return chunks
|
63 |
|
64 |
-
|
65 |
if __name__ == "__main__":
|
66 |
loader = TextFileLoader("data/KingLear.txt")
|
67 |
loader.load()
|
68 |
splitter = CharacterTextSplitter()
|
69 |
-
chunks = splitter.
|
70 |
print(len(chunks))
|
71 |
print(chunks[0])
|
72 |
print("--------")
|
|
|
1 |
import os
|
2 |
from typing import List
|
3 |
|
|
|
4 |
class TextFileLoader:
|
5 |
def __init__(self, path: str, encoding: str = "utf-8"):
|
6 |
self.documents = []
|
|
|
54 |
chunks.append(text[i : i + self.chunk_size])
|
55 |
return chunks
|
56 |
|
57 |
+
def split_text(self, texts: List[str]) -> List[str]:
|
58 |
chunks = []
|
59 |
for text in texts:
|
60 |
chunks.extend(self.split(text))
|
61 |
return chunks
|
62 |
|
|
|
63 |
if __name__ == "__main__":
|
64 |
loader = TextFileLoader("data/KingLear.txt")
|
65 |
loader.load()
|
66 |
splitter = CharacterTextSplitter()
|
67 |
+
chunks = splitter.split_text(loader.documents)
|
68 |
print(len(chunks))
|
69 |
print(chunks[0])
|
70 |
print("--------")
|
app.py
CHANGED
@@ -14,8 +14,10 @@ from richard.text_utils import FileLoader
|
|
14 |
from richard.pipeline import RetrievalAugmentedQAPipeline
|
15 |
# from richard.vector_database import QdrantDatabase
|
16 |
from qdrant_client import QdrantClient
|
17 |
-
from langchain.vectorstores import Qdrant
|
18 |
|
|
|
|
|
|
|
19 |
|
20 |
system_template = """\
|
21 |
Use the following context to answer a users question.
|
@@ -33,11 +35,6 @@ Question:
|
|
33 |
"""
|
34 |
user_role_prompt = UserRolePrompt(user_prompt_template)
|
35 |
|
36 |
-
def process_file(file: AskFileResponse):
|
37 |
-
fileLoader = FileLoader()
|
38 |
-
return fileLoader.load_file(file)
|
39 |
-
|
40 |
-
|
41 |
@cl.on_chat_start
|
42 |
async def on_chat_start():
|
43 |
res = await cl.AskActionMessage(
|
@@ -65,6 +62,17 @@ async def on_chat_start():
|
|
65 |
)
|
66 |
await msg.send()
|
67 |
use_qdrant = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
files = None
|
69 |
# Wait for the user to upload a file
|
70 |
while not files:
|
@@ -82,8 +90,7 @@ async def on_chat_start():
|
|
82 |
)
|
83 |
await msg.send()
|
84 |
|
85 |
-
|
86 |
-
texts = process_file(file)
|
87 |
|
88 |
msg = cl.Message(
|
89 |
content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
|
@@ -99,15 +106,11 @@ async def on_chat_start():
|
|
99 |
else:
|
100 |
embedding_model = EmbeddingModel()
|
101 |
if use_qdrant_type == "Local":
|
102 |
-
from qdrant_client.http.models import OptimizersConfig
|
103 |
-
print("Using qdrant local")
|
104 |
qdrant_client = QdrantClient(location=":memory:")
|
105 |
-
|
106 |
vector_params = VectorParams(
|
107 |
size=1536, # vector size
|
108 |
distance="Cosine" # distance metric
|
109 |
)
|
110 |
-
|
111 |
qdrant_client.recreate_collection(
|
112 |
collection_name="my_collection",
|
113 |
vectors_config={"default": vector_params},
|
|
|
14 |
from richard.pipeline import RetrievalAugmentedQAPipeline
|
15 |
# from richard.vector_database import QdrantDatabase
|
16 |
from qdrant_client import QdrantClient
|
|
|
17 |
|
18 |
+
def process_file(file, use_rct):
|
19 |
+
fileLoader = FileLoader()
|
20 |
+
return fileLoader.load_file(file, use_rct)
|
21 |
|
22 |
system_template = """\
|
23 |
Use the following context to answer a users question.
|
|
|
35 |
"""
|
36 |
user_role_prompt = UserRolePrompt(user_prompt_template)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
38 |
@cl.on_chat_start
|
39 |
async def on_chat_start():
|
40 |
res = await cl.AskActionMessage(
|
|
|
62 |
)
|
63 |
await msg.send()
|
64 |
use_qdrant = False
|
65 |
+
use_rct = False
|
66 |
+
res = await cl.AskActionMessage(
|
67 |
+
content="Do you want to use RecursiveCharacterTextSplitter?",
|
68 |
+
actions=[
|
69 |
+
cl.Action(name="yes", value="yes", label="✅ Yes"),
|
70 |
+
cl.Action(name="no", value="no", label="❌ No"),
|
71 |
+
],
|
72 |
+
).send()
|
73 |
+
if res and res.get("value") == "yes":
|
74 |
+
use_rct = True
|
75 |
+
|
76 |
files = None
|
77 |
# Wait for the user to upload a file
|
78 |
while not files:
|
|
|
90 |
)
|
91 |
await msg.send()
|
92 |
|
93 |
+
texts = process_file(file, use_rct)
|
|
|
94 |
|
95 |
msg = cl.Message(
|
96 |
content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
|
|
|
106 |
else:
|
107 |
embedding_model = EmbeddingModel()
|
108 |
if use_qdrant_type == "Local":
|
|
|
|
|
109 |
qdrant_client = QdrantClient(location=":memory:")
|
|
|
110 |
vector_params = VectorParams(
|
111 |
size=1536, # vector size
|
112 |
distance="Cosine" # distance metric
|
113 |
)
|
|
|
114 |
qdrant_client.recreate_collection(
|
115 |
collection_name="my_collection",
|
116 |
vectors_config={"default": vector_params},
|
requirements.txt
CHANGED
@@ -2,4 +2,5 @@ numpy==1.26.4
|
|
2 |
chainlit==0.7.700 # 1.1.402
|
3 |
openai==1.3.5
|
4 |
pymupdf==1.24.9
|
5 |
-
qdrant-client==1.11.0
|
|
|
|
2 |
chainlit==0.7.700 # 1.1.402
|
3 |
openai==1.3.5
|
4 |
pymupdf==1.24.9
|
5 |
+
qdrant-client==1.11.0
|
6 |
+
langchain-text-splitters
|
richard/text_utils.py
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
import os
|
|
|
2 |
import fitz
|
3 |
import tempfile
|
4 |
from aimakerspace.text_utils import CharacterTextSplitter
|
|
|
5 |
|
|
|
|
|
|
|
|
|
6 |
class FileLoader:
|
7 |
|
8 |
def __init__(self, encoding: str = "utf-8"):
|
@@ -11,7 +17,11 @@ class FileLoader:
|
|
11 |
self.temp_file_path = ""
|
12 |
|
13 |
|
14 |
-
def load_file(self, file,
|
|
|
|
|
|
|
|
|
15 |
file_extension = os.path.splitext(file.name)[1].lower()
|
16 |
with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
|
17 |
self.temp_file_path = temp_file.name
|
@@ -26,12 +36,14 @@ class FileLoader:
|
|
26 |
raise ValueError(
|
27 |
f"Unsupported file type: {self.temp_file_path}"
|
28 |
)
|
29 |
-
|
|
|
30 |
else:
|
31 |
raise ValueError(
|
32 |
"Not a file"
|
33 |
)
|
34 |
|
|
|
35 |
def load_text_file(self):
|
36 |
with open(self.temp_file_path, "r", encoding=self.encoding) as f:
|
37 |
self.documents.append(f.read())
|
@@ -43,4 +55,51 @@ class FileLoader:
|
|
43 |
for page_num in range(len(pdf_document)):
|
44 |
page = pdf_document.load_page(page_num)
|
45 |
text = page.get_text()
|
46 |
-
self.documents.append(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
from typing import List
|
3 |
import fitz
|
4 |
import tempfile
|
5 |
from aimakerspace.text_utils import CharacterTextSplitter
|
6 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
7 |
|
8 |
+
# load the file
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
class FileLoader:
|
13 |
|
14 |
def __init__(self, encoding: str = "utf-8"):
|
|
|
17 |
self.temp_file_path = ""
|
18 |
|
19 |
|
20 |
+
def load_file(self, file, use_rct):
|
21 |
+
if use_rct:
|
22 |
+
text_splitter=MyRecursiveCharacterTextSplitter()
|
23 |
+
else:
|
24 |
+
text_splitter=CharacterTextSplitter()
|
25 |
file_extension = os.path.splitext(file.name)[1].lower()
|
26 |
with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
|
27 |
self.temp_file_path = temp_file.name
|
|
|
36 |
raise ValueError(
|
37 |
f"Unsupported file type: {self.temp_file_path}"
|
38 |
)
|
39 |
+
print(self.documents)
|
40 |
+
return text_splitter.split_text(self.documents)
|
41 |
else:
|
42 |
raise ValueError(
|
43 |
"Not a file"
|
44 |
)
|
45 |
|
46 |
+
|
47 |
def load_text_file(self):
|
48 |
with open(self.temp_file_path, "r", encoding=self.encoding) as f:
|
49 |
self.documents.append(f.read())
|
|
|
55 |
for page_num in range(len(pdf_document)):
|
56 |
page = pdf_document.load_page(page_num)
|
57 |
text = page.get_text()
|
58 |
+
self.documents.append(text)
|
59 |
+
|
60 |
+
class CharacterTextSplitter:
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
chunk_size: int = 1000,
|
64 |
+
chunk_overlap: int = 200,
|
65 |
+
):
|
66 |
+
assert (
|
67 |
+
chunk_size > chunk_overlap
|
68 |
+
), "Chunk size must be greater than chunk overlap"
|
69 |
+
|
70 |
+
self.chunk_size = chunk_size
|
71 |
+
self.chunk_overlap = chunk_overlap
|
72 |
+
|
73 |
+
def split(self, text: str) -> List[str]:
|
74 |
+
chunks = []
|
75 |
+
for i in range(0, len(text), self.chunk_size - self.chunk_overlap):
|
76 |
+
chunks.append(text[i : i + self.chunk_size])
|
77 |
+
return chunks
|
78 |
+
|
79 |
+
def split_text(self, texts: List[str]) -> List[str]:
|
80 |
+
chunks = []
|
81 |
+
for text in texts:
|
82 |
+
chunks.extend(self.split(text))
|
83 |
+
return chunks
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
class MyRecursiveCharacterTextSplitter:
|
88 |
+
def __init__(
|
89 |
+
self
|
90 |
+
):
|
91 |
+
self.RCTS = RecursiveCharacterTextSplitter(
|
92 |
+
chunk_size=1000,
|
93 |
+
chunk_overlap=20,
|
94 |
+
length_function=len,
|
95 |
+
separators=["\n\n", "\n", " ", ""]
|
96 |
+
)
|
97 |
+
|
98 |
+
def split_text(self, texts: List[str]) -> List[str]:
|
99 |
+
all_chunks = []
|
100 |
+
for doc in texts:
|
101 |
+
chunks = self.RCTS.split_text(doc)
|
102 |
+
all_chunks.extend(chunks)
|
103 |
+
return all_chunks
|
104 |
+
|
105 |
+
|