cmagganas committed on
Commit
a6e29fb
1 Parent(s): aa1c94a

Update build_langchain_vector_store.py

Files changed (1)
  1. build_langchain_vector_store.py +121 -0
build_langchain_vector_store.py CHANGED
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
"""
Builds and persists a LangChain vector store over the Pulze documentation using Chroma.
Source: https://github.com/Arize-ai/phoenix/blob/main/scripts/data/build_langchain_vector_store.py
"""

import argparse
import logging
import shutil
import sys
from functools import partial
from typing import List

from langchain.docstore.document import Document as LangChainDocument
from langchain.document_loaders import GitbookLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from tiktoken import Encoding, encoding_for_model


def load_gitbook_docs(docs_url: str) -> List[LangChainDocument]:
    """Loads documents from a Gitbook URL.

    Args:
        docs_url (str): URL to Gitbook docs.

    Returns:
        List[LangChainDocument]: List of documents in LangChain format.
    """
    loader = GitbookLoader(
        docs_url,
        load_all_paths=True,
    )
    return loader.load()


def tiktoken_len(text: str, tokenizer: Encoding) -> int:
    """Returns the length of a text in tokens.

    Args:
        text (str): The text to tokenize and count.
        tokenizer (tiktoken.Encoding): The tokenizer.

    Returns:
        int: The number of tokens in the text.
    """
    # disallowed_special=() treats text that looks like a special token
    # (e.g. "<|endoftext|>") as ordinary text instead of raising an error.
    tokens = tokenizer.encode(text, disallowed_special=())
    return len(tokens)
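
# Illustrative only (not part of the original script): measuring length in
# tokens rather than characters keeps chunk sizes aligned with how the
# embedding model counts its input. For example:
#
#     enc = encoding_for_model("text-embedding-ada-002")
#     n_tokens = tiktoken_len("How do I use Pulze?", tokenizer=enc)
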
def chunk_docs(
    documents: List[LangChainDocument],
    tokenizer: Encoding,
    chunk_size: int = 400,
    chunk_overlap: int = 20,
) -> List[LangChainDocument]:
    """Chunks the documents.

    The chunking strategy used in this function is from the following notebook and accompanying
    video:

    - https://github.com/pinecone-io/examples/blob/master/generation/langchain/handbook/
      xx-langchain-chunking.ipynb
    - https://www.youtube.com/watch?v=eqOfr4AGLk8

    Args:
        documents (List[LangChainDocument]): A list of input documents.

        tokenizer (tiktoken.Encoding): The tokenizer used to count the number of tokens in a text.

        chunk_size (int, optional): The size of the chunks in tokens.

        chunk_overlap (int, optional): The chunk overlap in tokens.

    Returns:
        List[LangChainDocument]: The chunked documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=partial(tiktoken_len, tokenizer=tokenizer),
        separators=["\n\n", "\n", " ", ""],
    )
    return text_splitter.split_documents(documents)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--persist-path",
        type=str,
        required=False,
        help="Path to persist index.",
        default="langchain-chroma-pulze-docs",
    )
    args = parser.parse_args()

    docs_url = "https://docs.pulze.ai/"
    embedding_model_name = "text-embedding-ada-002"
    langchain_documents = load_gitbook_docs(docs_url)
    chunked_langchain_documents = chunk_docs(
        langchain_documents,
        tokenizer=encoding_for_model(embedding_model_name),
        chunk_size=200,
    )

    embedding_model = OpenAIEmbeddings(model=embedding_model_name)
    # Remove any previously persisted index so the store is rebuilt from scratch.
    shutil.rmtree(args.persist_path, ignore_errors=True)
    vector_store = Chroma.from_documents(
        chunked_langchain_documents, embedding=embedding_model, persist_directory=args.persist_path
    )
    # Note: depending on the Chroma/LangChain version, an explicit
    # vector_store.persist() may be required to flush the index to disk.
    # Re-open the persisted store as a read-back sanity check.
    read_vector_store = Chroma(
        persist_directory=args.persist_path, embedding_function=embedding_model
    )
    # print(read_vector_store.similarity_search("How do I use Pulze?"))
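
# A minimal sketch (not part of the original commit) of querying the persisted
# store from a separate process; assumes the default persist path above and an
# OPENAI_API_KEY in the environment:
#
#     from langchain.embeddings import OpenAIEmbeddings
#     from langchain.vectorstores import Chroma
#
#     store = Chroma(
#         persist_directory="langchain-chroma-pulze-docs",
#         embedding_function=OpenAIEmbeddings(model="text-embedding-ada-002"),
#     )
#     for doc in store.similarity_search("How do I use Pulze?", k=2):
#         print(doc.page_content)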