rchrdgwr commited on
Commit
0fbd1a9
1 Parent(s): 0614fbf

add recursive text splitter

Browse files
BuildingAChainlitApp.md CHANGED
@@ -123,7 +123,7 @@ def process_text_file(file: AskFileResponse):
123
 
124
  text_loader = TextFileLoader(temp_file_path)
125
  documents = text_loader.load_documents()
126
- texts = text_splitter.split_texts(documents)
127
  return texts
128
  ```
129
 
@@ -286,7 +286,7 @@ Code was modified to support pdf documents in the following areas:
286
  raise ValueError(
287
  f"Unsupported file type: {self.temp_file_path}"
288
  )
289
- return text_splitter.split_texts(self.documents)
290
  else:
291
  raise ValueError(
292
  "Not a file"
@@ -297,9 +297,8 @@ Code was modified to support pdf documents in the following areas:
297
  self.documents.append(f.read())
298
 
299
  def load_pdf_file(self):
300
- print("load_pdf_file()")
301
  pdf_document = fitz.open(self.temp_file_path)
302
- print(len(pdf_document))
303
  for page_num in range(len(pdf_document)):
304
  page = pdf_document.load_page(page_num)
305
  text = page.get_text()
 
123
 
124
  text_loader = TextFileLoader(temp_file_path)
125
  documents = text_loader.load_documents()
126
+ texts = text_splitter.split_text(documents)
127
  return texts
128
  ```
129
 
 
286
  raise ValueError(
287
  f"Unsupported file type: {self.temp_file_path}"
288
  )
289
+ return text_splitter.split_text(self.documents)
290
  else:
291
  raise ValueError(
292
  "Not a file"
 
297
  self.documents.append(f.read())
298
 
299
  def load_pdf_file(self):
300
+
301
  pdf_document = fitz.open(self.temp_file_path)
 
302
  for page_num in range(len(pdf_document)):
303
  page = pdf_document.load_page(page_num)
304
  text = page.get_text()
aimakerspace/text_utils.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from typing import List
3
 
4
-
5
  class TextFileLoader:
6
  def __init__(self, path: str, encoding: str = "utf-8"):
7
  self.documents = []
@@ -55,18 +54,17 @@ class CharacterTextSplitter:
55
  chunks.append(text[i : i + self.chunk_size])
56
  return chunks
57
 
58
- def split_texts(self, texts: List[str]) -> List[str]:
59
  chunks = []
60
  for text in texts:
61
  chunks.extend(self.split(text))
62
  return chunks
63
 
64
-
65
  if __name__ == "__main__":
66
  loader = TextFileLoader("data/KingLear.txt")
67
  loader.load()
68
  splitter = CharacterTextSplitter()
69
- chunks = splitter.split_texts(loader.documents)
70
  print(len(chunks))
71
  print(chunks[0])
72
  print("--------")
 
1
  import os
2
  from typing import List
3
 
 
4
  class TextFileLoader:
5
  def __init__(self, path: str, encoding: str = "utf-8"):
6
  self.documents = []
 
54
  chunks.append(text[i : i + self.chunk_size])
55
  return chunks
56
 
57
+ def split_text(self, texts: List[str]) -> List[str]:
58
  chunks = []
59
  for text in texts:
60
  chunks.extend(self.split(text))
61
  return chunks
62
 
 
63
  if __name__ == "__main__":
64
  loader = TextFileLoader("data/KingLear.txt")
65
  loader.load()
66
  splitter = CharacterTextSplitter()
67
+ chunks = splitter.split_text(loader.documents)
68
  print(len(chunks))
69
  print(chunks[0])
70
  print("--------")
app.py CHANGED
@@ -14,8 +14,10 @@ from richard.text_utils import FileLoader
14
  from richard.pipeline import RetrievalAugmentedQAPipeline
15
  # from richard.vector_database import QdrantDatabase
16
  from qdrant_client import QdrantClient
17
- from langchain.vectorstores import Qdrant
18
 
 
 
 
19
 
20
  system_template = """\
21
  Use the following context to answer a users question.
@@ -33,11 +35,6 @@ Question:
33
  """
34
  user_role_prompt = UserRolePrompt(user_prompt_template)
35
 
36
- def process_file(file: AskFileResponse):
37
- fileLoader = FileLoader()
38
- return fileLoader.load_file(file)
39
-
40
-
41
  @cl.on_chat_start
42
  async def on_chat_start():
43
  res = await cl.AskActionMessage(
@@ -65,6 +62,17 @@ async def on_chat_start():
65
  )
66
  await msg.send()
67
  use_qdrant = False
 
 
 
 
 
 
 
 
 
 
 
68
  files = None
69
  # Wait for the user to upload a file
70
  while not files:
@@ -82,8 +90,7 @@ async def on_chat_start():
82
  )
83
  await msg.send()
84
 
85
- # load the file
86
- texts = process_file(file)
87
 
88
  msg = cl.Message(
89
  content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
@@ -99,15 +106,11 @@ async def on_chat_start():
99
  else:
100
  embedding_model = EmbeddingModel()
101
  if use_qdrant_type == "Local":
102
- from qdrant_client.http.models import OptimizersConfig
103
- print("Using qdrant local")
104
  qdrant_client = QdrantClient(location=":memory:")
105
-
106
  vector_params = VectorParams(
107
  size=1536, # vector size
108
  distance="Cosine" # distance metric
109
  )
110
-
111
  qdrant_client.recreate_collection(
112
  collection_name="my_collection",
113
  vectors_config={"default": vector_params},
 
14
  from richard.pipeline import RetrievalAugmentedQAPipeline
15
  # from richard.vector_database import QdrantDatabase
16
  from qdrant_client import QdrantClient
 
17
 
18
+ def process_file(file, use_rct):
19
+ fileLoader = FileLoader()
20
+ return fileLoader.load_file(file, use_rct)
21
 
22
  system_template = """\
23
  Use the following context to answer a users question.
 
35
  """
36
  user_role_prompt = UserRolePrompt(user_prompt_template)
37
 
 
 
 
 
 
38
  @cl.on_chat_start
39
  async def on_chat_start():
40
  res = await cl.AskActionMessage(
 
62
  )
63
  await msg.send()
64
  use_qdrant = False
65
+ use_rct = False
66
+ res = await cl.AskActionMessage(
67
+ content="Do you want to use RecursiveCharacterTextSplitter?",
68
+ actions=[
69
+ cl.Action(name="yes", value="yes", label="✅ Yes"),
70
+ cl.Action(name="no", value="no", label="❌ No"),
71
+ ],
72
+ ).send()
73
+ if res and res.get("value") == "yes":
74
+ use_rct = True
75
+
76
  files = None
77
  # Wait for the user to upload a file
78
  while not files:
 
90
  )
91
  await msg.send()
92
 
93
+ texts = process_file(file, use_rct)
 
94
 
95
  msg = cl.Message(
96
  content=f"Resulted in {len(texts)} chunks", disable_human_feedback=True
 
106
  else:
107
  embedding_model = EmbeddingModel()
108
  if use_qdrant_type == "Local":
 
 
109
  qdrant_client = QdrantClient(location=":memory:")
 
110
  vector_params = VectorParams(
111
  size=1536, # vector size
112
  distance="Cosine" # distance metric
113
  )
 
114
  qdrant_client.recreate_collection(
115
  collection_name="my_collection",
116
  vectors_config={"default": vector_params},
requirements.txt CHANGED
@@ -2,4 +2,5 @@ numpy==1.26.4
2
  chainlit==0.7.700 # 1.1.402
3
  openai==1.3.5
4
  pymupdf==1.24.9
5
- qdrant-client==1.11.0
 
 
2
  chainlit==0.7.700 # 1.1.402
3
  openai==1.3.5
4
  pymupdf==1.24.9
5
+ qdrant-client==1.11.0
6
+ langchain-text-splitters
richard/text_utils.py CHANGED
@@ -1,8 +1,14 @@
1
  import os
 
2
  import fitz
3
  import tempfile
4
  from aimakerspace.text_utils import CharacterTextSplitter
 
5
 
 
 
 
 
6
  class FileLoader:
7
 
8
  def __init__(self, encoding: str = "utf-8"):
@@ -11,7 +17,11 @@ class FileLoader:
11
  self.temp_file_path = ""
12
 
13
 
14
- def load_file(self, file, text_splitter=CharacterTextSplitter()):
 
 
 
 
15
  file_extension = os.path.splitext(file.name)[1].lower()
16
  with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
17
  self.temp_file_path = temp_file.name
@@ -26,12 +36,14 @@ class FileLoader:
26
  raise ValueError(
27
  f"Unsupported file type: {self.temp_file_path}"
28
  )
29
- return text_splitter.split_texts(self.documents)
 
30
  else:
31
  raise ValueError(
32
  "Not a file"
33
  )
34
 
 
35
  def load_text_file(self):
36
  with open(self.temp_file_path, "r", encoding=self.encoding) as f:
37
  self.documents.append(f.read())
@@ -43,4 +55,51 @@ class FileLoader:
43
  for page_num in range(len(pdf_document)):
44
  page = pdf_document.load_page(page_num)
45
  text = page.get_text()
46
- self.documents.append(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import List
3
  import fitz
4
  import tempfile
5
  from aimakerspace.text_utils import CharacterTextSplitter
6
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
7
 
8
+ # load the file
9
+
10
+
11
+
12
  class FileLoader:
13
 
14
  def __init__(self, encoding: str = "utf-8"):
 
17
  self.temp_file_path = ""
18
 
19
 
20
+ def load_file(self, file, use_rct):
21
+ if use_rct:
22
+ text_splitter=MyRecursiveCharacterTextSplitter()
23
+ else:
24
+ text_splitter=CharacterTextSplitter()
25
  file_extension = os.path.splitext(file.name)[1].lower()
26
  with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=file_extension) as temp_file:
27
  self.temp_file_path = temp_file.name
 
36
  raise ValueError(
37
  f"Unsupported file type: {self.temp_file_path}"
38
  )
39
+ print(self.documents)
40
+ return text_splitter.split_text(self.documents)
41
  else:
42
  raise ValueError(
43
  "Not a file"
44
  )
45
 
46
+
47
  def load_text_file(self):
48
  with open(self.temp_file_path, "r", encoding=self.encoding) as f:
49
  self.documents.append(f.read())
 
55
  for page_num in range(len(pdf_document)):
56
  page = pdf_document.load_page(page_num)
57
  text = page.get_text()
58
+ self.documents.append(text)
59
+
60
+ class CharacterTextSplitter:
61
+ def __init__(
62
+ self,
63
+ chunk_size: int = 1000,
64
+ chunk_overlap: int = 200,
65
+ ):
66
+ assert (
67
+ chunk_size > chunk_overlap
68
+ ), "Chunk size must be greater than chunk overlap"
69
+
70
+ self.chunk_size = chunk_size
71
+ self.chunk_overlap = chunk_overlap
72
+
73
+ def split(self, text: str) -> List[str]:
74
+ chunks = []
75
+ for i in range(0, len(text), self.chunk_size - self.chunk_overlap):
76
+ chunks.append(text[i : i + self.chunk_size])
77
+ return chunks
78
+
79
+ def split_text(self, texts: List[str]) -> List[str]:
80
+ chunks = []
81
+ for text in texts:
82
+ chunks.extend(self.split(text))
83
+ return chunks
84
+
85
+
86
+
87
+ class MyRecursiveCharacterTextSplitter:
88
+ def __init__(
89
+ self
90
+ ):
91
+ self.RCTS = RecursiveCharacterTextSplitter(
92
+ chunk_size=1000,
93
+ chunk_overlap=20,
94
+ length_function=len,
95
+ separators=["\n\n", "\n", " ", ""]
96
+ )
97
+
98
+ def split_text(self, texts: List[str]) -> List[str]:
99
+ all_chunks = []
100
+ for doc in texts:
101
+ chunks = self.RCTS.split_text(doc)
102
+ all_chunks.extend(chunks)
103
+ return all_chunks
104
+
105
+