ZoniaChatbot committed on
Commit
8730d86
1 Parent(s): 05e3b86

Upload chatpdf.py

Files changed (1)
  1. chatpdf.py +582 -0
chatpdf.py ADDED
@@ -0,0 +1,582 @@
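+ # chatpdf.py: retrieval-augmented document Q&A ("chat with your PDFs") in Spanish.
+ # Pipeline implemented below: documents are split into chunks (SentenceSplitter),
+ # indexed with an ensemble of dense (BertSimilarity) and sparse (BM25Similarity)
+ # retrieval, the top hits are optionally reranked with a sequence-classification
+ # reranker, and the selected chunks are inserted into a Spanish prompt template
+ # for a causal language model (LenguajeNaturalAI/leniachat-qwen2-1.5B-v0 by default).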
+ import argparse
+ import hashlib
+ import os
+ import re
+ from threading import Thread
+ from typing import Union, List
+
+ import jieba
+ import torch
+ from loguru import logger
+ from peft import PeftModel
+ from similarities import (
+     EnsembleSimilarity,
+     BertSimilarity,
+     BM25Similarity,
+ )
+ from similarities.similarity import SimilarityABC
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     GenerationConfig,
+     AutoModelForSequenceClassification,
+ )
+
+ jieba.setLogLevel("ERROR")
+
+ MODEL_CLASSES = {
+     "auto": (AutoModelForCausalLM, AutoTokenizer),
+ }
+
+ PROMPT_TEMPLATE1 = """Utiliza la siguiente información para responder a la pregunta del usuario.
+ Si no sabes la respuesta, di simplemente que no la sabes, no intentes inventarte una respuesta.
+
+ Contexto: {context_str}
+ Pregunta: {query_str}
+
+ Devuelve sólo la respuesta útil que aparece a continuación y nada más, y ésta debe estar en Español.
+ Respuesta útil:
+ """
+ PROMPT_TEMPLATE = """Basándose en la siguiente información conocida, responda a la pregunta del usuario de forma
+ concisa y profesional. Si no puede obtener una respuesta, diga «No se puede responder a la pregunta basándose en la
+ información conocida» o «No se proporciona suficiente información relevante», no está permitido añadir elementos
+ inventados en la respuesta.
+
+ Contenido conocido:
+ {context_str}
+
+ Pregunta:
+ {query_str}
+ """
+
+
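+ # SentenceSplitter chunks raw text before indexing: Chinese text is segmented with
+ # jieba and cut at sentence-ending punctuation, any other text is split into
+ # sentences with a regex; chunks are capped at chunk_size characters, with an
+ # optional chunk_overlap carried over between neighbouring chunks.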
+ class SentenceSplitter:
+     def __init__(self, chunk_size: int = 250, chunk_overlap: int = 50):
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+
+     def split_text(self, text: str) -> List[str]:
+         if self._is_has_chinese(text):
+             return self._split_chinese_text(text)
+         else:
+             return self._split_english_text(text)
+
+     def _split_chinese_text(self, text: str) -> List[str]:
+         sentence_endings = {'\n', '。', '!', '?', ';', '…'}  # sentence-ending punctuation
+         chunks, current_chunk = [], ''
+         for word in jieba.cut(text):
+             if len(current_chunk) + len(word) > self.chunk_size:
+                 chunks.append(current_chunk.strip())
+                 current_chunk = word
+             else:
+                 current_chunk += word
+                 if word[-1] in sentence_endings and len(current_chunk) > self.chunk_size - self.chunk_overlap:
+                     chunks.append(current_chunk.strip())
+                     current_chunk = ''
+         if current_chunk:
+             chunks.append(current_chunk.strip())
+         if self.chunk_overlap > 0 and len(chunks) > 1:
+             chunks = self._handle_overlap(chunks)
+         return chunks
+
+     def _split_english_text(self, text: str) -> List[str]:
+         # Split non-Chinese text into sentences with a regular expression
+         sentences = re.split(r'(?<=[.!?])\s+', text.replace('\n', ' '))
+         chunks, current_chunk = [], ''
+         for sentence in sentences:
+             if len(current_chunk) + len(sentence) <= self.chunk_size or not current_chunk:
+                 current_chunk += (' ' if current_chunk else '') + sentence
+             else:
+                 chunks.append(current_chunk)
+                 current_chunk = sentence
+         if current_chunk:  # Add the last chunk
+             chunks.append(current_chunk)
+
+         if self.chunk_overlap > 0 and len(chunks) > 1:
+             chunks = self._handle_overlap(chunks)
+
+         return chunks
+
+     def _is_has_chinese(self, text: str) -> bool:
+         # Check whether the text contains Chinese characters
+         if any("\u4e00" <= ch <= "\u9fff" for ch in text):
+             return True
+         else:
+             return False
+
+     def _handle_overlap(self, chunks: List[str]) -> List[str]:
+         # Handle the overlap between adjacent chunks
+         overlapped_chunks = []
+         for i in range(len(chunks) - 1):
+             chunk = chunks[i] + ' ' + chunks[i + 1][:self.chunk_overlap]
+             overlapped_chunks.append(chunk.strip())
+         overlapped_chunks.append(chunks[-1])
+         return overlapped_chunks
+
+
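+ # ChatPDF ties the pipeline together: it loads the generation model (optionally with
+ # a LoRA adapter and int8/int4 quantization), builds the similarity index over the
+ # corpus files, and exposes predict()/predict_stream() for retrieval-augmented answers.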
+ class ChatPDF:
+     def __init__(
+             self,
+             similarity_model: SimilarityABC = None,
+             generate_model_type: str = "auto",
+             generate_model_name_or_path: str = "LenguajeNaturalAI/leniachat-qwen2-1.5B-v0",
+             lora_model_name_or_path: str = None,
+             corpus_files: Union[str, List[str]] = None,
+             save_corpus_emb_dir: str = "corpus_embs/",
+             device: str = None,
+             int8: bool = False,
+             int4: bool = False,
+             chunk_size: int = 250,
+             chunk_overlap: int = 0,
+             rerank_model_name_or_path: str = None,
+             enable_history: bool = False,
+             num_expand_context_chunk: int = 2,
+             similarity_top_k: int = 10,
+             rerank_top_k: int = 3
+     ):
+
+         if torch.cuda.is_available():
+             default_device = torch.device(0)
+         elif torch.backends.mps.is_available():
+             default_device = torch.device('cpu')
+         else:
+             default_device = torch.device('cpu')
+         self.device = device or default_device
+         if num_expand_context_chunk > 0 and chunk_overlap > 0:
+             logger.warning(f" 'num_expand_context_chunk' and 'chunk_overlap' cannot both be greater than zero. "
+                            f" 'chunk_overlap' has been set to zero by default.")
+             chunk_overlap = 0
+         self.text_splitter = SentenceSplitter(chunk_size, chunk_overlap)
+         if similarity_model is not None:
+             self.sim_model = similarity_model
+         else:
+             m1 = BertSimilarity(model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", device=self.device)
+             m2 = BM25Similarity()
+             default_sim_model = EnsembleSimilarity(similarities=[m1, m2], weights=[0.5, 0.5], c=2)
+             self.sim_model = default_sim_model
+         self.gen_model, self.tokenizer = self._init_gen_model(
+             generate_model_type,
+             generate_model_name_or_path,
+             peft_name=lora_model_name_or_path,
+             int8=int8,
+             int4=int4,
+         )
+         self.history = []
+         self.corpus_files = corpus_files
+         if corpus_files:
+             self.add_corpus(corpus_files)
+         self.save_corpus_emb_dir = save_corpus_emb_dir
+         if rerank_model_name_or_path is None:
+             rerank_model_name_or_path = "maidalun1020/bce-reranker-base_v1"
+         if rerank_model_name_or_path:
+             self.rerank_tokenizer = AutoTokenizer.from_pretrained(rerank_model_name_or_path)
+             self.rerank_model = AutoModelForSequenceClassification.from_pretrained(rerank_model_name_or_path)
+             self.rerank_model.to(self.device)
+             self.rerank_model.eval()
+         else:
+             self.rerank_model = None
+             self.rerank_tokenizer = None
+         self.enable_history = enable_history
+         self.similarity_top_k = similarity_top_k
+         self.num_expand_context_chunk = num_expand_context_chunk
+         self.rerank_top_k = rerank_top_k
+
+     def __str__(self):
+         return f"Similarity model: {self.sim_model}, Generate model: {self.gen_model}"
+
+     def _init_gen_model(
+             self,
+             gen_model_type: str,
+             gen_model_name_or_path: str,
+             peft_name: str = None,
+             int8: bool = False,
+             int4: bool = False,
+     ):
+         """Init generate model."""
+         if int8 or int4:
+             device_map = None
+         else:
+             device_map = "auto"
+         model_class, tokenizer_class = MODEL_CLASSES[gen_model_type]
+         tokenizer = tokenizer_class.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
+         model = model_class.from_pretrained(
+             gen_model_name_or_path,
+             load_in_8bit=int8 if gen_model_type not in ['baichuan', 'chatglm'] else False,
+             load_in_4bit=int4 if gen_model_type not in ['baichuan', 'chatglm'] else False,
+             torch_dtype="auto",
+             device_map=device_map,
+             trust_remote_code=True,
+         )
+         if self.device == torch.device('cpu'):
+             model.float()
+         if gen_model_type in ['baichuan', 'chatglm']:
+             if int4:
+                 model = model.quantize(4).cuda()
+             elif int8:
+                 model = model.quantize(8).cuda()
+         try:
+             model.generation_config = GenerationConfig.from_pretrained(gen_model_name_or_path, trust_remote_code=True)
+         except Exception as e:
+             logger.warning(f"No se pudo cargar la configuración de generación desde {gen_model_name_or_path}, {e}")
+         if peft_name:
+             model = PeftModel.from_pretrained(
+                 model,
+                 peft_name,
+                 torch_dtype="auto",
+             )
+             logger.info(f"Modelo peft cargado desde {peft_name}")
+         model.eval()
+         return model, tokenizer
+
+     def _get_chat_input(self):
+         messages = []
+         for conv in self.history:
+             if conv and len(conv) > 0 and conv[0]:
+                 messages.append({'role': 'user', 'content': conv[0]})
+             if conv and len(conv) > 1 and conv[1]:
+                 messages.append({'role': 'assistant', 'content': conv[1]})
+         input_ids = self.tokenizer.apply_chat_template(
+             conversation=messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors='pt'
+         )
+         return input_ids.to(self.gen_model.device)
+
+     @torch.inference_mode()
+     def stream_generate_answer(
+             self,
+             max_new_tokens=512,
+             temperature=0.7,
+             repetition_penalty=1.0,
+             context_len=2048
+     ):
+         streamer = TextIteratorStreamer(self.tokenizer, timeout=520.0, skip_prompt=True, skip_special_tokens=True)
+         input_ids = self._get_chat_input()
+         max_src_len = context_len - max_new_tokens - 8
+         # Truncate to the last max_src_len tokens along the sequence dimension
+         input_ids = input_ids[:, -max_src_len:]
+         generation_kwargs = dict(
+             input_ids=input_ids,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             do_sample=True,
+             repetition_penalty=repetition_penalty,
+             streamer=streamer,
+         )
+         thread = Thread(target=self.gen_model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         yield from streamer
+
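+     # Corpus loading: add_corpus() extracts text per file type (.pdf/.docx/.md/other),
+     # splits it into chunks and feeds them to the similarity model; get_file_hash()
+     # hashes the first 1MB of each file and is used to name the on-disk embedding cache.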
+     def add_corpus(self, files: Union[str, List[str]]):
+         """Load document files."""
+         if isinstance(files, str):
+             files = [files]
+         for doc_file in files:
+             if doc_file.endswith('.pdf'):
+                 corpus = self.extract_text_from_pdf(doc_file)
+             elif doc_file.endswith('.docx'):
+                 corpus = self.extract_text_from_docx(doc_file)
+             elif doc_file.endswith('.md'):
+                 corpus = self.extract_text_from_markdown(doc_file)
+             else:
+                 corpus = self.extract_text_from_txt(doc_file)
+             full_text = '\n'.join(corpus)
+             chunks = self.text_splitter.split_text(full_text)
+             self.sim_model.add_corpus(chunks)
+         self.corpus_files = files
+         logger.debug(f"files: {files}, corpus size: {len(self.sim_model.corpus)}, top3: "
+                      f"{list(self.sim_model.corpus.values())[:3]}")
+
+     @staticmethod
+     def get_file_hash(fpaths):
+         hasher = hashlib.md5()
+         target_file_data = bytes()
+         if isinstance(fpaths, str):
+             fpaths = [fpaths]
+         for fpath in fpaths:
+             with open(fpath, 'rb') as file:
+                 chunk = file.read(1024 * 1024)  # read only first 1MB
+                 hasher.update(chunk)
+                 target_file_data += chunk
+
+         hash_name = hasher.hexdigest()[:32]
+         return hash_name
+
+     @staticmethod
+     def extract_text_from_pdf(file_path: str):
+         """Extract text content from a PDF file."""
+         import PyPDF2
+         contents = []
+         with open(file_path, 'rb') as f:
+             pdf_reader = PyPDF2.PdfReader(f)
+             for page in pdf_reader.pages:
+                 page_text = page.extract_text().strip()
+                 raw_text = [text.strip() for text in page_text.splitlines() if text.strip()]
+                 new_text = ''
+                 for text in raw_text:
+                     # Add a space before concatenating if new_text is not empty
+                     if new_text:
+                         new_text += ' '
+                     new_text += text
+                     if text[-1] in ['.', '!', '?', '。', '!', '?', '…', ';', ';', ':', ':', '”', '’', ')', '】', '》', '」',
+                                     '』', '〕', '〉', '》', '〗', '〞', '〟', '»', '"', "'", ')', ']', '}']:
+                         contents.append(new_text)
+                         new_text = ''
+                 if new_text:
+                     contents.append(new_text)
+         return contents
+
+     @staticmethod
+     def extract_text_from_txt(file_path: str):
+         """Extract text content from a TXT file."""
+         with open(file_path, 'r', encoding='utf-8') as f:
+             contents = [text.strip() for text in f.readlines() if text.strip()]
+         return contents
+
+     @staticmethod
+     def extract_text_from_docx(file_path: str):
+         """Extract text content from a DOCX file."""
+         import docx
+         document = docx.Document(file_path)
+         contents = [paragraph.text.strip() for paragraph in document.paragraphs if paragraph.text.strip()]
+         return contents
+
+     @staticmethod
+     def extract_text_from_markdown(file_path: str):
+         """Extract text content from a Markdown file."""
+         import markdown
+         from bs4 import BeautifulSoup
+         with open(file_path, 'r', encoding='utf-8') as f:
+             markdown_text = f.read()
+         html = markdown.markdown(markdown_text)
+         soup = BeautifulSoup(html, 'html.parser')
+         contents = [text.strip() for text in soup.get_text().splitlines() if text.strip()]
+         return contents
+
+     @staticmethod
+     def _add_source_numbers(lst):
+         """Add source numbers to a list of strings."""
+         return [f'[{idx + 1}]\t "{item}"' for idx, item in enumerate(lst)]
+
+     def _get_reranker_score(self, query: str, reference_results: List[str]):
+         """Get reranker score."""
+         pairs = []
+         for reference in reference_results:
+             pairs.append([query, reference])
+         with torch.no_grad():
+             inputs = self.rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
+             inputs_on_device = {k: v.to(self.rerank_model.device) for k, v in inputs.items()}
+             scores = self.rerank_model(**inputs_on_device, return_dict=True).logits.view(-1, ).float()
+
+         return scores
+
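+     # Retrieval: fetch the similarity_top_k most similar chunks, optionally keep the
+     # rerank_top_k best ones according to the reranker, then expand each hit with the
+     # previous chunk and the next num_expand_context_chunk chunks from the corpus.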
+     def get_reference_results(self, query: str):
+         """
+         Get reference results.
+         1. Retrieve similar chunks with the similarity model
+         2. Rerank the similar chunks
+         3. Expand each hit chunk with its neighbouring context chunks
+         :param query:
+         :return:
+         """
+         reference_results = []
+         sim_contents = self.sim_model.most_similar(query, topn=self.similarity_top_k)
+         # Get reference results from corpus
+         hit_chunk_dict = dict()
+         for query_id, id_score_dict in sim_contents.items():
+             for corpus_id, s in id_score_dict.items():
+                 hit_chunk = self.sim_model.corpus[corpus_id]
+                 reference_results.append(hit_chunk)
+                 hit_chunk_dict[corpus_id] = hit_chunk
+
+         if reference_results:
+             if self.rerank_model is not None:
+                 # Rerank reference results
+                 rerank_scores = self._get_reranker_score(query, reference_results)
+                 logger.debug(f"rerank_scores: {rerank_scores}")
+                 # Get rerank top k chunks
+                 reference_results = [reference for reference, score in sorted(
+                     zip(reference_results, rerank_scores), key=lambda x: x[1], reverse=True)][:self.rerank_top_k]
+                 hit_chunk_dict = {corpus_id: hit_chunk for corpus_id, hit_chunk in hit_chunk_dict.items() if
+                                   hit_chunk in reference_results}
+             # Expand reference context chunk
+             if self.num_expand_context_chunk > 0:
+                 new_reference_results = []
+                 for corpus_id, hit_chunk in hit_chunk_dict.items():
+                     expanded_reference = self.sim_model.corpus.get(corpus_id - 1, '') + hit_chunk
+                     for i in range(self.num_expand_context_chunk):
+                         expanded_reference += self.sim_model.corpus.get(corpus_id + i + 1, '')
+                     new_reference_results.append(expanded_reference)
+                 reference_results = new_reference_results
+         return reference_results
+
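+     # Answer generation: both predict() and predict_stream() build the Spanish RAG
+     # prompt from the retrieved chunks (or fall back to the raw query when no corpus
+     # is loaded) and generate with stream_generate_answer(); predict_stream() yields
+     # partial responses, predict() returns the final response plus its references.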
+     def predict_stream(
+             self,
+             query: str,
+             max_length: int = 512,
+             context_len: int = 2048,
+             temperature: float = 0.7,
+     ):
+         """Generate predictions stream."""
+         stop_str = self.tokenizer.eos_token if self.tokenizer.eos_token else "</s>"
+         if not self.enable_history:
+             self.history = []
+         if self.sim_model.corpus:
+             reference_results = self.get_reference_results(query)
+             if not reference_results:
+                 yield 'No se ha proporcionado suficiente información relevante', reference_results
+                 return
+             reference_results = self._add_source_numbers(reference_results)
+             context_str = '\n'.join(reference_results)
+             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
+             logger.debug(f"prompt: {prompt}")
+         else:
+             prompt = query
+             logger.debug(prompt)
+         self.history.append([prompt, ''])
+         response = ""
+         for new_text in self.stream_generate_answer(
+                 max_new_tokens=max_length,
+                 temperature=temperature,
+                 context_len=context_len,
+         ):
+             if new_text != stop_str:
+                 response += new_text
+                 yield response
+
+     def predict(
+             self,
+             query: str,
+             max_length: int = 512,
+             context_len: int = 2048,
+             temperature: float = 0.7,
+     ):
+         """Query from corpus."""
+         reference_results = []
+         if not self.enable_history:
+             self.history = []
+         if self.sim_model.corpus:
+             reference_results = self.get_reference_results(query)
+
+             if not reference_results:
+                 return 'No se ha proporcionado suficiente información relevante', reference_results
+             reference_results = self._add_source_numbers(reference_results)
+             context_str = '\n'.join(reference_results)
+             prompt = PROMPT_TEMPLATE.format(context_str=context_str, query_str=query)
+             logger.debug(f"prompt: {prompt}")
+         else:
+             prompt = query
+         self.history.append([prompt, ''])
+         response = ""
+         for new_text in self.stream_generate_answer(
+                 max_new_tokens=max_length,
+                 temperature=temperature,
+                 context_len=context_len,
+         ):
+             response += new_text
+         response = response.strip()
+         self.history[-1][1] = response
+         return response, reference_results
+
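+     # Persistence helpers: corpus embeddings are cached under save_corpus_emb_dir in a
+     # folder named after the corpus file hash, and the chunked corpus text is written
+     # to corpus_embs/corpus_text.txt so it can be reloaded without re-parsing the files.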
+     def save_corpus_emb(self):
+         dir_name = self.get_file_hash(self.corpus_files)
+         save_dir = os.path.join(self.save_corpus_emb_dir, dir_name)
+         if hasattr(self.sim_model, 'save_corpus_embeddings'):
+             self.sim_model.save_corpus_embeddings(save_dir)
+             logger.debug(f"Saving corpus embeddings to {save_dir}")
+         return save_dir
+
+     def load_corpus_emb(self, emb_dir: str):
+         if hasattr(self.sim_model, 'load_corpus_embeddings'):
+             logger.debug(f"Loading corpus embeddings from {emb_dir}")
+             self.sim_model.load_corpus_embeddings(emb_dir)
+
+     def save_corpus_text(self):
+         if not self.corpus_files:
+             logger.warning("No hay archivos de corpus para guardar.")
+             return
+
+         corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
+
+         with open(corpus_text_file, 'w', encoding='utf-8') as f:
+             for chunk in self.sim_model.corpus.values():
+                 f.write(chunk + "\n\n")  # Two newlines between chunks for readability
+
+         logger.info(f"Texto del corpus guardado en: {corpus_text_file}")
+         return corpus_text_file
+
+     def load_corpus_text(self, emb_dir: str):
+         corpus_text_file = os.path.join("corpus_embs/", "corpus_text.txt")
+         if os.path.exists(corpus_text_file):
+             with open(corpus_text_file, 'r', encoding='utf-8') as f:
+                 corpus_text = f.read().split("\n\n")  # Assumes two newlines as the chunk separator
+             self.sim_model.corpus = {i: chunk.strip() for i, chunk in enumerate(corpus_text) if chunk.strip()}
+             logger.info(f"Texto del corpus cargado desde: {corpus_text_file}")
+         else:
+             logger.warning(f"No se encontró el archivo de texto del corpus en: {corpus_text_file}")
+
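+ # Illustrative invocation (the corpus path below is a placeholder, not a file shipped
+ # with this repo): python chatpdf.py --corpus_files docs/sample.pdf --chunk_size 220
+ # Embeddings are cached on the first run and reloaded on later runs with the same corpus.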
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--sim_model_name", type=str, default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
+     parser.add_argument("--gen_model_type", type=str, default="auto")
+     parser.add_argument("--gen_model_name", type=str, default="LenguajeNaturalAI/leniachat-qwen2-1.5B-v0")
+     parser.add_argument("--lora_model", type=str, default=None)
+     parser.add_argument("--rerank_model_name", type=str, default="maidalun1020/bce-reranker-base_v1")
+     parser.add_argument("--corpus_files", type=str, default="docs/corpus.txt")
+     parser.add_argument("--device", type=str, default=None)
+     parser.add_argument("--int4", action='store_true', help="use int4 quantization")
+     parser.add_argument("--int8", action='store_true', help="use int8 quantization")
+     parser.add_argument("--chunk_size", type=int, default=220)
+     parser.add_argument("--chunk_overlap", type=int, default=50)
+     parser.add_argument("--num_expand_context_chunk", type=int, default=2)
+     args = parser.parse_args()
+     print(args)
+     sim_model = BertSimilarity(model_name_or_path=args.sim_model_name, device=args.device)
+     m = ChatPDF(
+         similarity_model=sim_model,
+         generate_model_type=args.gen_model_type,
+         generate_model_name_or_path=args.gen_model_name,
+         lora_model_name_or_path=args.lora_model,
+         device=args.device,
+         int4=args.int4,
+         int8=args.int8,
+         chunk_size=args.chunk_size,
+         chunk_overlap=args.chunk_overlap,
+         corpus_files=args.corpus_files.split(','),
+         num_expand_context_chunk=args.num_expand_context_chunk,
+         rerank_model_name_or_path=args.rerank_model_name,
+     )
+     logger.info(f"chatpdf model: {m}")
+
+     # Check whether saved embeddings already exist
+     dir_name = m.get_file_hash(args.corpus_files.split(','))
+     save_dir = os.path.join(m.save_corpus_emb_dir, dir_name)
+
+     if os.path.exists(save_dir):
+         # Load the saved embeddings
+         m.load_corpus_emb(save_dir)
+         print(f"Incrustaciones del corpus cargadas desde: {save_dir}")
+     else:
+         # Process the corpus and save the embeddings
+         m.add_corpus(args.corpus_files.split(','))
+         save_dir = m.save_corpus_emb()
+         # Save the corpus text
+         m.save_corpus_text()
+         print(f"Las incrustaciones del corpus se han guardado en: {save_dir}")
+
+     while True:
+         query = input("\nEnter a query: ")
+         if query == "exit":
+             break
+         if query.strip() == "":
+             continue
+         r, refs = m.predict(query)
+         print(r, refs)
+         print("\nRespuesta: ", r)