Woocy commited on
Commit
1c44e70
1 Parent(s): 4263807

Update llama_func.py

Browse files
Files changed (1) hide show
  1. llama_func.py +110 -136
llama_func.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import logging
3
 
4
- from llama_index import GPTSimpleVectorIndex
5
  from llama_index import download_loader
6
  from llama_index import (
7
  Document,
@@ -10,181 +9,156 @@ from llama_index import (
10
  QuestionAnswerPrompt,
11
  RefinePrompt,
12
  )
13
- from langchain.llms import OpenAI
14
  import colorama
 
 
15
 
 
 
 
16
 
17
- from presets import *
18
- from utils import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  def get_documents(file_src):
22
  documents = []
23
- index_name = ""
24
  logging.debug("Loading documents...")
25
  logging.debug(f"file_src: {file_src}")
26
  for file in file_src:
27
- logging.debug(f"file: {file.name}")
28
- index_name += file.name
29
- if os.path.splitext(file.name)[1] == ".pdf":
30
- logging.debug("Loading PDF...")
31
- CJKPDFReader = download_loader("CJKPDFReader")
32
- loader = CJKPDFReader()
33
- documents += loader.load_data(file=file.name)
34
- elif os.path.splitext(file.name)[1] == ".docx":
35
- logging.debug("Loading DOCX...")
36
- DocxReader = download_loader("DocxReader")
37
- loader = DocxReader()
38
- documents += loader.load_data(file=file.name)
39
- elif os.path.splitext(file.name)[1] == ".epub":
40
- logging.debug("Loading EPUB...")
41
- EpubReader = download_loader("EpubReader")
42
- loader = EpubReader()
43
- documents += loader.load_data(file=file.name)
44
- else:
45
- logging.debug("Loading text file...")
46
- with open(file.name, "r", encoding="utf-8") as f:
47
- text = add_space(f.read())
48
- documents += [Document(text)]
49
- index_name = sha1sum(index_name)
50
- return documents, index_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
  def construct_index(
54
  api_key,
55
  file_src,
56
  max_input_size=4096,
57
- num_outputs=1,
58
  max_chunk_overlap=20,
59
  chunk_size_limit=600,
60
  embedding_limit=None,
61
  separator=" ",
62
- num_children=10,
63
- max_keywords_per_chunk=10,
64
  ):
65
- os.environ["OPENAI_API_KEY"] = api_key
 
 
 
 
 
 
 
 
66
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
67
  embedding_limit = None if embedding_limit == 0 else embedding_limit
68
  separator = " " if separator == "" else separator
69
 
70
- llm_predictor = LLMPredictor(
71
- llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
72
- )
73
  prompt_helper = PromptHelper(
74
- max_input_size,
75
- num_outputs,
76
- max_chunk_overlap,
77
- embedding_limit,
78
- chunk_size_limit,
79
  separator=separator,
80
  )
81
- documents, index_name = get_documents(file_src)
82
  if os.path.exists(f"./index/{index_name}.json"):
83
- logging.info("找到了缓存的索引文件,加载中…")
84
  return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
85
  else:
86
  try:
87
- logging.debug("构建索引中……")
88
- index = GPTSimpleVectorIndex(
89
- documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
91
  os.makedirs("./index", exist_ok=True)
92
  index.save_to_disk(f"./index/{index_name}.json")
 
93
  return index
 
94
  except Exception as e:
 
95
  print(e)
96
  return None
97
 
98
 
99
- def chat_ai(
100
- api_key,
101
- index,
102
- question,
103
- context,
104
- chatbot,
105
- ):
106
- os.environ["OPENAI_API_KEY"] = api_key
107
-
108
- logging.info(f"Question: {question}")
109
-
110
- response, chatbot_display, status_text = ask_ai(
111
- api_key,
112
- index,
113
- question,
114
- replace_today(PROMPT_TEMPLATE),
115
- REFINE_TEMPLATE,
116
- SIM_K,
117
- INDEX_QUERY_TEMPRATURE,
118
- context,
119
- )
120
- if response is None:
121
- status_text = "查询失败,请换个问法试试!"
122
- return context, chatbot
123
- response = response
124
-
125
- context.append({"role": "user", "content": question})
126
- context.append({"role": "assistant", "content": response})
127
- chatbot.append((question, chatbot_display))
128
-
129
- os.environ["OPENAI_API_KEY"] = ""
130
- return context, chatbot, status_text
131
-
132
-
133
- def ask_ai(
134
- api_key,
135
- index,
136
- question,
137
- prompt_tmpl,
138
- refine_tmpl,
139
- sim_k=1,
140
- temprature=0,
141
- prefix_messages=[],
142
- ):
143
- os.environ["OPENAI_API_KEY"] = api_key
144
-
145
- logging.debug("Index file found")
146
- logging.debug("Querying index...")
147
- llm_predictor = LLMPredictor(
148
- llm=OpenAI(
149
- temperature=temprature,
150
- model_name="gpt-3.5-turbo-0301",
151
- prefix_messages=prefix_messages,
152
- )
153
- )
154
-
155
- response = None # Initialize response variable to avoid UnboundLocalError
156
- qa_prompt = QuestionAnswerPrompt(prompt_tmpl)
157
- rf_prompt = RefinePrompt(refine_tmpl)
158
- response = index.query(
159
- question,
160
- llm_predictor=llm_predictor,
161
- similarity_top_k=sim_k,
162
- text_qa_template=qa_prompt,
163
- refine_template=rf_prompt,
164
- response_mode="compact",
165
- )
166
-
167
- if response is not None:
168
- logging.info(f"Response: {response}")
169
- ret_text = response.response
170
- nodes = []
171
- for index, node in enumerate(response.source_nodes):
172
- brief = node.source_text[:25].replace("\n", "")
173
- nodes.append(
174
- f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
175
- )
176
- new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
177
- logging.info(
178
- f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
179
- )
180
- os.environ["OPENAI_API_KEY"] = ""
181
- return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
182
- else:
183
- logging.warning("No response found, returning None")
184
- os.environ["OPENAI_API_KEY"] = ""
185
- return None
186
-
187
-
188
  def add_space(text):
189
  punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
190
  for cn_punc, en_punc in punctuations.items():
 
1
  import os
2
  import logging
3
 
 
4
  from llama_index import download_loader
5
  from llama_index import (
6
  Document,
 
9
  QuestionAnswerPrompt,
10
  RefinePrompt,
11
  )
 
12
  import colorama
13
+ import PyPDF2
14
+ from tqdm import tqdm
15
 
16
+ from modules.presets import *
17
+ from modules.utils import *
18
+ from modules.config import local_embedding
19
 
20
+
21
+ def get_index_name(file_src):
22
+ file_paths = [x.name for x in file_src]
23
+ file_paths.sort(key=lambda x: os.path.basename(x))
24
+
25
+ md5_hash = hashlib.md5()
26
+ for file_path in file_paths:
27
+ with open(file_path, "rb") as f:
28
+ while chunk := f.read(8192):
29
+ md5_hash.update(chunk)
30
+
31
+ return md5_hash.hexdigest()
32
+
33
+
34
+ def block_split(text):
35
+ blocks = []
36
+ while len(text) > 0:
37
+ blocks.append(Document(text[:1000]))
38
+ text = text[1000:]
39
+ return blocks
40
 
41
 
42
  def get_documents(file_src):
43
  documents = []
 
44
  logging.debug("Loading documents...")
45
  logging.debug(f"file_src: {file_src}")
46
  for file in file_src:
47
+ filepath = file.name
48
+ filename = os.path.basename(filepath)
49
+ file_type = os.path.splitext(filepath)[1]
50
+ logging.info(f"loading file: {filename}")
51
+ try:
52
+ if file_type == ".pdf":
53
+ logging.debug("Loading PDF...")
54
+ try:
55
+ from modules.pdf_func import parse_pdf
56
+ from modules.config import advance_docs
57
+
58
+ two_column = advance_docs["pdf"].get("two_column", False)
59
+ pdftext = parse_pdf(filepath, two_column).text
60
+ except:
61
+ pdftext = ""
62
+ with open(filepath, "rb") as pdfFileObj:
63
+ pdfReader = PyPDF2.PdfReader(pdfFileObj)
64
+ for page in tqdm(pdfReader.pages):
65
+ pdftext += page.extract_text()
66
+ text_raw = pdftext
67
+ elif file_type == ".docx":
68
+ logging.debug("Loading Word...")
69
+ DocxReader = download_loader("DocxReader")
70
+ loader = DocxReader()
71
+ text_raw = loader.load_data(file=filepath)[0].text
72
+ elif file_type == ".epub":
73
+ logging.debug("Loading EPUB...")
74
+ EpubReader = download_loader("EpubReader")
75
+ loader = EpubReader()
76
+ text_raw = loader.load_data(file=filepath)[0].text
77
+ elif file_type == ".xlsx":
78
+ logging.debug("Loading Excel...")
79
+ text_list = excel_to_string(filepath)
80
+ for elem in text_list:
81
+ documents.append(Document(elem))
82
+ continue
83
+ else:
84
+ logging.debug("Loading text file...")
85
+ with open(filepath, "r", encoding="utf-8") as f:
86
+ text_raw = f.read()
87
+ except Exception as e:
88
+ logging.error(f"Error loading file: {filename}")
89
+ pass
90
+ text = add_space(text_raw)
91
+ # text = block_split(text)
92
+ # documents += text
93
+ documents += [Document(text)]
94
+ logging.debug("Documents loaded.")
95
+ return documents
96
 
97
 
98
  def construct_index(
99
  api_key,
100
  file_src,
101
  max_input_size=4096,
102
+ num_outputs=5,
103
  max_chunk_overlap=20,
104
  chunk_size_limit=600,
105
  embedding_limit=None,
106
  separator=" ",
 
 
107
  ):
108
+ from langchain.chat_models import ChatOpenAI
109
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
110
+ from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
111
+
112
+ if api_key:
113
+ os.environ["OPENAI_API_KEY"] = api_key
114
+ else:
115
+ # 由于一个依赖的愚蠢的设计,这里必须要有一个API KEY
116
+ os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
117
  chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
118
  embedding_limit = None if embedding_limit == 0 else embedding_limit
119
  separator = " " if separator == "" else separator
120
 
 
 
 
121
  prompt_helper = PromptHelper(
122
+ max_input_size=max_input_size,
123
+ num_output=num_outputs,
124
+ max_chunk_overlap=max_chunk_overlap,
125
+ embedding_limit=embedding_limit,
126
+ chunk_size_limit=600,
127
  separator=separator,
128
  )
129
+ index_name = get_index_name(file_src)
130
  if os.path.exists(f"./index/{index_name}.json"):
131
+ logging.info("找到了缓存的索引文件,加载中……")
132
  return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
133
  else:
134
  try:
135
+ documents = get_documents(file_src)
136
+ if local_embedding:
137
+ embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name = "sentence-transformers/distiluse-base-multilingual-cased-v2"))
138
+ else:
139
+ embed_model = OpenAIEmbedding()
140
+ logging.info("构建索引中……")
141
+ with retrieve_proxy():
142
+ service_context = ServiceContext.from_defaults(
143
+ prompt_helper=prompt_helper,
144
+ chunk_size_limit=chunk_size_limit,
145
+ embed_model=embed_model,
146
+ )
147
+ index = GPTSimpleVectorIndex.from_documents(
148
+ documents, service_context=service_context
149
+ )
150
+ logging.debug("索引构建完成!")
151
  os.makedirs("./index", exist_ok=True)
152
  index.save_to_disk(f"./index/{index_name}.json")
153
+ logging.debug("索引已保存至本地!")
154
  return index
155
+
156
  except Exception as e:
157
+ logging.error("索引构建失败!", e)
158
  print(e)
159
  return None
160
 
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def add_space(text):
163
  punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
164
  for cn_punc, en_punc in punctuations.items():