Marina Pliusnina committed
Commit c774338
1 Parent(s): c8bd9ca

adding number of chunks and context

Files changed (2):
  1. app.py +10 -2
  2. rag.py +21 -13
app.py CHANGED
@@ -37,13 +37,14 @@ def generate(prompt, model_parameters):
         )
 
 
-def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
     if input_.strip() == "":
         gr.Warning("Not possible to inference an empty input")
         return None
 
 
     model_parameters = {
+        "NUM_CHUNKS": num_chunks,
         "MAX_NEW_TOKENS": max_new_tokens,
         "REPETITION_PENALTY": repetition_penalty,
         "TOP_K": top_k,
@@ -109,6 +110,13 @@ def gradio_app():
 
     with gr.Row(variant="panel"):
         with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+            num_chunks = Slider(
+                minimum=1,
+                maximum=6,
+                step=1,
+                value=4,
+                label="Number of chunks"
+            )
             max_new_tokens = Slider(
                 minimum=50,
                 maximum=1000,
@@ -154,7 +162,7 @@ def gradio_app():
                 label="Temperature"
             )
 
-            parameters_compontents = [max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]
+            parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]
 
         with gr.Column(variant="panel"):
             output = Textbox(
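The new slider only takes effect if the component list is passed to `submit_input` positionally, in the same order as the updated signature. The wiring itself is outside the diff, so the sketch below is an assumed, minimal reconstruction: `demo`, `submit_btn`, and the trimmed two-parameter signature are illustrative, not the app's actual code.

# Minimal sketch (assumptions: demo, submit_btn, trimmed signature).
# Gradio passes `inputs` to the callback positionally, so num_chunks must
# sit immediately after the input box, matching the new signature.
import gradio as gr

def submit_input(input_, num_chunks, max_new_tokens):  # trimmed for the sketch
    return f"{num_chunks} chunks, up to {max_new_tokens} new tokens for: {input_}"

with gr.Blocks() as demo:
    input_ = gr.Textbox(label="Input")
    num_chunks = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of chunks")
    max_new_tokens = gr.Slider(minimum=50, maximum=1000, value=200, label="Max new tokens")
    output = gr.Textbox(label="Output")
    submit_btn = gr.Button("Submit")
    parameters_compontents = [num_chunks, max_new_tokens]  # the real app lists 8 components
    submit_btn.click(fn=submit_input, inputs=[input_] + parameters_compontents, outputs=output)

demo.launch()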
rag.py CHANGED
@@ -24,19 +24,11 @@ class RAG:
 
         logging.info("RAG loaded!")
 
-    def get_context(self, instruction, number_of_contexts=3):
-
-        context = ""
-
+    def get_context(self, instruction, number_of_contexts=4):
 
         documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)
 
-
-        for doc in documentos:
-
-            context += doc[0].page_content
-
-        return context
+        return documentos
 
     def predict(self, instruction, context, model_parameters):
 
@@ -61,14 +53,30 @@ class RAG:
         response = requests.post(self.model_name, headers=headers, json=payload)
 
         return response.json()[0]["generated_text"].split("###")[-1][8:-1]
+
+    def beautiful_context(self, docs):
+
+        text_context = ""
+
+        full_context = ""
+
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
+
+        return text_context, full_context
 
     def get_response(self, prompt: str, model_parameters: dict) -> str:
 
-        context = self.get_context(prompt)
+        docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+        text_context, full_context = self.beautiful_context(docs)
+
+        del model_parameters["NUM_CHUNKS"]
 
-        response = self.predict(prompt, context, model_parameters)
+        response = self.predict(prompt, text_context, model_parameters)
 
         if not response:
             return self.NO_ANSWER_MESSAGE
 
-        return response
+        return response, full_context
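To make the new return shape concrete, here is a standalone sketch of what `beautiful_context` produces from `similarity_search_with_score`-style results, which are `(document, score)` tuples. `FakeDoc` and the sample chunks are invented for illustration; the Catalan metadata key `"Títol de la norma"` ("title of the regulation") is the real one used in the diff.

# Illustration only: FakeDoc and the sample data are made up; real docs
# come from the vector store as (document, score) tuples.
from dataclasses import dataclass, field

@dataclass
class FakeDoc:
    page_content: str
    metadata: dict = field(default_factory=dict)

docs = [
    (FakeDoc("Article 1: ...", {"Títol de la norma": "Llei 1/2000"}), 0.12),
    (FakeDoc("Article 7: ...", {"Títol de la norma": "Decret 3/2005"}), 0.34),
]

text_context = ""  # concatenated chunks, fed to the model prompt
full_context = ""  # chunks plus their source titles, shown to the user
for doc in docs:
    text_context += doc[0].page_content
    full_context += doc[0].page_content + "\n"
    full_context += doc[0].metadata["Títol de la norma"] + "\n\n"

print(text_context)  # "Article 1: ...Article 7: ..."
print(full_context)  # each chunk followed by the norm it came from

Two behaviors of the new `get_response` are worth noting for callers: it returns a `(response, full_context)` tuple on success but a bare string on the no-answer path, and `del model_parameters["NUM_CHUNKS"]` mutates the caller's dict in place.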