PyroSama committed on
Commit
0bd4b9a
1 Parent(s): 7297579

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +55 -0
utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_huggingface import HuggingFaceEmbeddings
2
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
3
+ from langchain_community.vectorstores import Chroma
4
+ from langchain.schema import Document
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
7
+ import torch
8
+
9
# Hugging Face model id of the embedding model used for the vector store.
embedding_model_name = "nomic-ai/nomic-embed-text-v1.5"

# Run on the GPU when torch reports CUDA is available, otherwise on the CPU;
# trust_remote_code is forwarded to the model loader as-is.
model_kwargs = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "trust_remote_code": True,
}

# Shared embedding function; built once at import time.
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=model_kwargs,
)

# Module-level placeholder; no store exists until one is created.
vectorstore = None
19
+
20
+
21
+
22
def read_file(data: str) -> Document:
    """Read a text file and wrap its content in a ``Document``.

    Args:
        data: Path of the file to read.

    Returns:
        A ``Document`` whose ``page_content`` is the full file text and whose
        ``metadata["name"]`` is the last ``/``-separated component of ``data``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The context manager closes the handle even when read() raises.
    # The platform-default encoding is kept on purpose so existing callers
    # see the same decoded text as before.
    with open(data, 'r') as f:
        content = f.read()
    return Document(page_content=content, metadata={"name": data.split('/')[-1]})
28
+
29
# Shared splitter: chunks of up to 800 characters with a 100-character overlap.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=100,
)
30
+
31
def add_doc(data, vectorstore):
    """Split a file into chunks and index them in the vector store.

    Args:
        data: Path of the file to index (read via ``read_file``).
        vectorstore: Existing store to add the chunks to, or ``None`` to
            create a new Chroma store from the chunks.

    Returns:
        A ``(retriever, vectorstore)`` tuple; the retriever is configured to
        return a single result (``k=1``).
    """
    doc = read_file(data)
    splits = text_splitter.split_documents([doc])
    if vectorstore is None:
        # First document: build the store from scratch.
        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    elif splits:
        # Previously the passed-in store was ignored and replaced by a new
        # one; add to it instead so the caller's store is the one updated.
        vectorstore.add_documents(splits)
    retriever = vectorstore.as_retriever(search_kwargs={'k': 1})
    return retriever, vectorstore
37
+
38
def delete_doc(delete_name, vectorstore):
    """Delete every stored chunk whose metadata ``name`` equals ``delete_name``.

    Args:
        delete_name: Value to match against each entry's ``metadata["name"]``.
        vectorstore: Store exposing ``get()``, ``delete(ids=...)`` and
            ``as_retriever(...)``.

    Returns:
        A ``(retriever, vectorstore)`` tuple; the retriever is configured to
        return a single result (``k=1``).
    """
    # Fetch the stored entries once instead of once per matching entry;
    # 'ids' and 'metadatas' are paired by position.
    stored = vectorstore.get()
    delete_doc_ids = [
        doc_id
        for doc_id, metadata in zip(stored['ids'], stored['metadatas'])
        # Entries with no metadata or no 'name' key are never a match.
        if metadata and 'name' in metadata and metadata['name'] == delete_name
    ]
    # One batched delete; skipped when nothing matched so the store is never
    # asked to delete an empty id list.
    if delete_doc_ids:
        vectorstore.delete(ids=delete_doc_ids)
    retriever = vectorstore.as_retriever(search_kwargs={'k': 1})
    return retriever, vectorstore
48
+
49
def delete_all_doc(vectorstore):
    """Delete every entry currently held in the vector store.

    Args:
        vectorstore: Store exposing ``get()``, ``delete(ids=...)`` and
            ``as_retriever(...)``.

    Returns:
        A ``(retriever, vectorstore)`` tuple; the retriever is configured to
        return a single result (``k=1``).
    """
    delete_doc_ids = list(vectorstore.get()['ids'])
    # One batched delete; skipped for an empty store so the store is never
    # asked to delete an empty id list.
    if delete_doc_ids:
        vectorstore.delete(ids=delete_doc_ids)
    retriever = vectorstore.as_retriever(search_kwargs={'k': 1})
    return retriever, vectorstore