Woocy committed on
Commit
163b572
1 Parent(s): 710639d

Update llama_func.py

Browse files
Files changed (1) hide show
  1. llama_func.py +134 -98
llama_func.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import logging
3
 
 
4
  from llama_index import download_loader
5
  from llama_index import (
6
  Document,
@@ -9,146 +10,181 @@ from llama_index import (
9
  QuestionAnswerPrompt,
10
  RefinePrompt,
11
  )
 
12
  import colorama
13
- # import PyPDF2
14
- from tqdm import tqdm
15
 
16
  from presets import *
17
  from utils import *
18
- from config import local_embedding
19
-
20
-
21
def get_index_name(file_src):
    """Derive a deterministic cache key for a set of uploaded files.

    The key is the MD5 hex digest of the concatenated file contents,
    hashed in a stable order (paths sorted by basename) so the same set
    of files always maps to the same index name.
    """
    paths = sorted((item.name for item in file_src), key=os.path.basename)

    digest = hashlib.md5()
    for path in paths:
        with open(path, "rb") as stream:
            # Stream in 8 KiB pieces so large files don't load into memory.
            for piece in iter(lambda: stream.read(8192), b""):
                digest.update(piece)

    return digest.hexdigest()
32
-
33
-
34
def block_split(text):
    """Chop *text* into Document objects of at most 1000 characters each.

    An empty string yields an empty list; the final block may be shorter
    than 1000 characters.
    """
    return [
        Document(text[start:start + 1000])
        for start in range(0, len(text), 1000)
    ]
40
 
41
 
42
def get_documents(file_src):
    """Load each uploaded file into llama_index ``Document`` objects.

    Dispatches on the file extension: .pdf / .docx / .epub go through the
    matching ``download_loader`` reader, .xlsx is expanded via
    ``excel_to_string`` (one Document per returned fragment), and anything
    else is read as UTF-8 text.  Text passes through ``add_space`` before
    wrapping.

    Files that fail to load are logged and skipped.

    Returns:
        list[Document]: the loaded documents.
    """
    documents = []
    logging.debug("Loading documents...")
    logging.debug(f"file_src: {file_src}")
    for file in file_src:
        filepath = file.name
        filename = os.path.basename(filepath)
        file_type = os.path.splitext(filepath)[1]
        logging.info(f"loading file: {filename}")
        try:
            if file_type == ".pdf":
                logging.debug("Loading PDF...")
                CJKPDFReader = download_loader("CJKPDFReader")
                loader = CJKPDFReader()
                text_raw = loader.load_data(file=filepath)[0].text
            elif file_type == ".docx":
                logging.debug("Loading Word...")
                DocxReader = download_loader("DocxReader")
                loader = DocxReader()
                text_raw = loader.load_data(file=filepath)[0].text
            elif file_type == ".epub":
                logging.debug("Loading EPUB...")
                EpubReader = download_loader("EpubReader")
                loader = EpubReader()
                text_raw = loader.load_data(file=filepath)[0].text
            elif file_type == ".xlsx":
                logging.debug("Loading Excel...")
                text_list = excel_to_string(filepath)
                for elem in text_list:
                    documents.append(Document(elem))
                continue
            else:
                logging.debug("Loading text file...")
                with open(filepath, "r", encoding="utf-8") as f:
                    text_raw = f.read()
        except Exception as e:
            # BUG FIX: the original fell through with `pass` and then used
            # `text_raw`, which is unbound when loading failed (NameError).
            # Skip the broken file instead.
            logging.error(f"Error loading file {filename}: {e}")
            continue
        text = add_space(text_raw)
        documents += [Document(text)]
    logging.debug("Documents loaded.")
    return documents
86
 
87
 
88
def construct_index(
    api_key,
    file_src,
    max_input_size=4096,
    num_outputs=5,
    max_chunk_overlap=20,
    chunk_size_limit=600,
    embedding_limit=None,
    separator=" ",
):
    """Build a ``GPTSimpleVectorIndex`` over *file_src*, with an on-disk cache.

    The index is cached under ``./index/<content-hash>.json`` (key from
    ``get_index_name``); a cache hit is loaded and returned directly.
    Embeddings come from HuggingFace when ``local_embedding`` is set,
    otherwise from OpenAI.

    Returns:
        The built (or cached) index, or ``None`` when construction fails.
    """
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding

    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    else:
        # A dependency insists on an API key being present even when it is
        # never used, so install a placeholder.
        os.environ["OPENAI_API_KEY"] = "sk-xxxxxxx"
    # Treat 0 / "" as "use the library default".
    chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
    embedding_limit = None if embedding_limit == 0 else embedding_limit
    separator = " " if separator == "" else separator

    prompt_helper = PromptHelper(
        max_input_size=max_input_size,
        num_output=num_outputs,
        max_chunk_overlap=max_chunk_overlap,
        embedding_limit=embedding_limit,
        # BUG FIX: this was hard-coded to 600, silently ignoring the
        # (already normalized) chunk_size_limit parameter.
        chunk_size_limit=chunk_size_limit,
        separator=separator,
    )
    index_name = get_index_name(file_src)
    if os.path.exists(f"./index/{index_name}.json"):
        logging.info("找到了缓存的索引文件,加载中……")
        return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
    else:
        try:
            documents = get_documents(file_src)
            if local_embedding:
                embed_model = LangchainEmbedding(
                    HuggingFaceEmbeddings(
                        model_name="sentence-transformers/distiluse-base-multilingual-cased-v2"
                    )
                )
            else:
                embed_model = OpenAIEmbedding()
            logging.info("构建索引中……")
            # Embedding calls go out through the configured proxy.
            with retrieve_proxy():
                service_context = ServiceContext.from_defaults(
                    prompt_helper=prompt_helper,
                    chunk_size_limit=chunk_size_limit,
                    embed_model=embed_model,
                )
                index = GPTSimpleVectorIndex.from_documents(
                    documents, service_context=service_context
                )
            logging.debug("索引构建完成!")
            os.makedirs("./index", exist_ok=True)
            index.save_to_disk(f"./index/{index_name}.json")
            logging.debug("索引已保存至本地!")
            return index
        except Exception as e:
            # BUG FIX: logging.error("...", e) passed the exception as a
            # %-format argument with no placeholder, which breaks log
            # formatting; use a proper lazy %s argument instead.
            logging.error("索引构建失败!%s", e)
            print(e)
            return None
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def add_space(text):
153
  punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
154
  for cn_punc, en_punc in punctuations.items():
 
1
  import os
2
  import logging
3
 
4
+ from llama_index import GPTSimpleVectorIndex
5
  from llama_index import download_loader
6
  from llama_index import (
7
  Document,
 
10
  QuestionAnswerPrompt,
11
  RefinePrompt,
12
  )
13
+ from langchain.llms import OpenAI
14
  import colorama
15
+
 
16
 
17
  from presets import *
18
  from utils import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
def get_documents(file_src):
    """Load uploaded files into ``Document`` objects and derive a cache key.

    .pdf / .docx / .epub files go through the matching ``download_loader``
    reader; anything else is read as UTF-8 text and passed through
    ``add_space``.  The cache key is the SHA-1 of the concatenated file
    paths.

    Returns:
        tuple: (documents, index_name).
    """
    # Extension -> (download_loader reader name, log label).
    reader_by_ext = {
        ".pdf": ("CJKPDFReader", "PDF"),
        ".docx": ("DocxReader", "DOCX"),
        ".epub": ("EpubReader", "EPUB"),
    }
    documents = []
    index_name = ""
    logging.debug("Loading documents...")
    logging.debug(f"file_src: {file_src}")
    for file in file_src:
        logging.debug(f"file: {file.name}")
        index_name += file.name
        ext = os.path.splitext(file.name)[1]
        if ext in reader_by_ext:
            loader_name, label = reader_by_ext[ext]
            logging.debug(f"Loading {label}...")
            loader = download_loader(loader_name)()
            documents += loader.load_data(file=file.name)
        else:
            logging.debug("Loading text file...")
            with open(file.name, "r", encoding="utf-8") as f:
                text = add_space(f.read())
            documents += [Document(text)]
    index_name = sha1sum(index_name)
    return documents, index_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
def construct_index(
    api_key,
    file_src,
    max_input_size=4096,
    num_outputs=1,
    max_chunk_overlap=20,
    chunk_size_limit=600,
    embedding_limit=None,
    separator=" ",
    num_children=10,
    max_keywords_per_chunk=10,
):
    """Build a ``GPTSimpleVectorIndex`` over *file_src*, with an on-disk cache.

    The index is cached under ``./index/<sha1-of-paths>.json`` (key from
    ``get_documents``); a cache hit is loaded and returned directly.

    Note: ``num_children`` and ``max_keywords_per_chunk`` are currently
    unused; they are kept for caller compatibility.

    Returns:
        The built (or cached) index, or ``None`` when construction fails.
    """
    os.environ["OPENAI_API_KEY"] = api_key
    # Treat 0 / "" as "use the library default".
    chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
    embedding_limit = None if embedding_limit == 0 else embedding_limit
    separator = " " if separator == "" else separator

    llm_predictor = LLMPredictor(
        llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
    )
    prompt_helper = PromptHelper(
        max_input_size,
        num_outputs,
        max_chunk_overlap,
        embedding_limit,
        chunk_size_limit,
        separator=separator,
    )
    documents, index_name = get_documents(file_src)
    if os.path.exists(f"./index/{index_name}.json"):
        logging.info("找到了缓存的索引文件,加载中……")
        return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
    try:
        logging.debug("构建索引中……")
        index = GPTSimpleVectorIndex(
            documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
        )
        os.makedirs("./index", exist_ok=True)
        index.save_to_disk(f"./index/{index_name}.json")
        return index
    except Exception as e:
        # BUG FIX: the failure was only print()ed, so it never reached the
        # application log; record it properly while keeping the original
        # console output for interactive use.
        logging.error("构建索引失败:%s", e)
        print(e)
        return None
97
 
98
 
99
def chat_ai(
    api_key,
    index,
    question,
    context,
    chatbot,
):
    """Answer *question* against *index* and append the turn to the history.

    Appends the user question and assistant answer to *context* (OpenAI
    message dicts) and the (question, rendered answer) pair to *chatbot*.

    Returns:
        tuple: (context, chatbot, status_text).  On query failure the
        history is left untouched and status_text carries the error
        message.  (BUG FIX: the failure path used to return only two
        values while the success path returned three.)
    """
    os.environ["OPENAI_API_KEY"] = api_key

    logging.info(f"Question: {question}")

    # Defensive unpack: older ask_ai builds return bare None on failure,
    # which would crash a direct 3-way unpack with a TypeError.
    result = ask_ai(
        api_key,
        index,
        question,
        replace_today(PROMPT_TEMPLATE),
        REFINE_TEMPLATE,
        SIM_K,
        INDEX_QUERY_TEMPRATURE,
        context,
    )
    if result is None:
        result = (None, None, "")
    response, chatbot_display, status_text = result
    if response is None:
        status_text = "查询失败,请换个问法试试"
        os.environ["OPENAI_API_KEY"] = ""
        return context, chatbot, status_text

    context.append({"role": "user", "content": question})
    context.append({"role": "assistant", "content": response})
    chatbot.append((question, chatbot_display))

    os.environ["OPENAI_API_KEY"] = ""
    return context, chatbot, status_text
131
+
132
+
133
def ask_ai(
    api_key,
    index,
    question,
    prompt_tmpl,
    refine_tmpl,
    sim_k=1,
    temprature=0,
    prefix_messages=None,
):
    """Query the llama_index *index* with *question*.

    ``prompt_tmpl`` / ``refine_tmpl`` are the QA and refine prompt
    template strings; ``sim_k`` is the number of source nodes retrieved.
    Note: ``temprature`` (sic) keeps its original misspelled name for
    caller compatibility.

    Returns:
        tuple: (answer_text, answer_with_sources_html, status_text).
        On failure the first two elements are None.  (BUG FIX: a bare
        ``None`` was returned before, crashing callers that unpack three
        values.)
    """
    # BUG FIX: a mutable default ([]) is shared across calls; use None.
    if prefix_messages is None:
        prefix_messages = []
    os.environ["OPENAI_API_KEY"] = api_key

    logging.debug("Index file found")
    logging.debug("Querying index...")
    llm_predictor = LLMPredictor(
        llm=OpenAI(
            temperature=temprature,
            model_name="gpt-3.5-turbo-0301",
            prefix_messages=prefix_messages,
        )
    )

    qa_prompt = QuestionAnswerPrompt(prompt_tmpl)
    rf_prompt = RefinePrompt(refine_tmpl)
    response = index.query(
        question,
        llm_predictor=llm_predictor,
        similarity_top_k=sim_k,
        text_qa_template=qa_prompt,
        refine_template=rf_prompt,
        response_mode="compact",
    )

    if response is None:
        logging.warning("No response found, returning None")
        os.environ["OPENAI_API_KEY"] = ""
        return None, None, ""

    logging.info(f"Response: {response}")
    ret_text = response.response
    nodes = []
    # BUG FIX: the loop variable was named `index`, shadowing the index
    # parameter; renamed to avoid the clobber.
    for i, node in enumerate(response.source_nodes):
        brief = node.source_text[:25].replace("\n", "")
        nodes.append(
            f"<details><summary>[{i+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
        )
    new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
    logging.info(
        f"Response: {colorama.Fore.BLUE}{ret_text}{colorama.Style.RESET_ALL}"
    )
    os.environ["OPENAI_API_KEY"] = ""
    return ret_text, new_response, f"查询消耗了{llm_predictor.last_token_usage} tokens"
186
+
187
+
188
  def add_space(text):
189
  punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
190
  for cn_punc, en_punc in punctuations.items():