PierreBrunelle committed · verified
Commit 75470e4 · 1 Parent(s): 02840b1

Update app.py

Files changed (1):
  1. app.py  +46 -49
app.py CHANGED
@@ -26,21 +26,6 @@ if 'FIREWORKS_API_KEY' not in os.environ:
 if 'MISTRAL_API_KEY' not in os.environ:
     os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')
 
-# Create prompt function
-@pxt.udf
-def create_prompt(top_k_list: list[dict], question: str) -> str:
-    concat_top_k = '\n\n'.join(
-        elt['text'] for elt in reversed(top_k_list)
-    )
-    return f'''
-    PASSAGES:
-
-    {concat_top_k}
-
-    QUESTION:
-
-    {question}'''
-
 """Gradio Application"""
 def process_files(ground_truth_file, pdf_files, chunk_limit, chunk_separator, show_question, show_correct_answer, show_gpt4omini, show_llamav3p23b, show_mistralsmall, progress=gr.Progress()):
     # Ensure a clean slate for the demo by removing and recreating the 'rag_demo' directory
@@ -86,6 +71,17 @@ def process_files(ground_truth_file, pdf_files, chunk_limit, chunk_separator, sh
         string_embed=sentence_transformer.using(model_id='sentence-transformers/all-MiniLM-L12-v2')
     )
 
+    # Create prompt function
+    @pxt.udf
+    def create_prompt(top_k_list: list[dict], question: str) -> str:
+        if not top_k_list:
+            return f"QUESTION:\n{question}"
+
+        concat_top_k = '\n\n'.join(
+            elt['text'] for elt in reversed(top_k_list) if elt and 'text' in elt
+        )
+        return f'''PASSAGES:\n{concat_top_k}\n\nQUESTION:\n{question}'''
+
     # Define a query function to retrieve the top-k most similar chunks for a given question
     @chunks_t.query
     def top_k(query_text: str):
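
Note: a quick way to sanity-check the new create_prompt is to run the same formatting as plain Python outside Pixeltable. A minimal sketch with made-up passages (this only mirrors the UDF body; inside the app, Pixeltable applies the UDF per row):

    # Mirrors the body of create_prompt with sample inputs.
    top_k_list = [
        {'text': 'Pixeltable chunks documents into passages.'},
        {'text': 'Passages are ranked by embedding similarity.'},
    ]
    question = 'How are passages ranked?'

    concat_top_k = '\n\n'.join(
        elt['text'] for elt in reversed(top_k_list) if elt and 'text' in elt
    )
    print(f'''PASSAGES:\n{concat_top_k}\n\nQUESTION:\n{question}''')
    # PASSAGES:
    # Passages are ranked by embedding similarity.
    #
    # Pixeltable chunks documents into passages.
    #
    # QUESTION:
    # How are passages ranked?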
@@ -96,52 +92,53 @@ def process_files(ground_truth_file, pdf_files, chunk_limit, chunk_separator, sh
             .limit(5)
         )
 
-    # Add computed columns to the queries table for context retrieval and prompt creation
+    # Then modify the messages structure to use a UDF
+    @pxt.udf
+    def create_messages(prompt: str) -> list[dict]:
+        return [
+            {
+                'role': 'system',
+                'content': 'Read the following passages and answer the question based on their contents.'
+            },
+            {
+                'role': 'user',
+                'content': prompt
+            }
+        ]
+
+    # First add the context and prompt columns
     queries_t.add_computed_column(question_context=chunks_t.queries.top_k(queries_t.question))
     queries_t.add_computed_column(prompt=create_prompt(
-        queries_t.question_context, queries_t.question
+        queries_t.question_context,
+        queries_t.question
     ))
-
-    # Prepare messages for the OpenAI API, including system instructions and user prompt
-    msgs = [
-        {
-            'role': 'system',
-            'content': 'Read the following passages and answer the question based on their contents.'
-        },
-        {
-            'role': 'user',
-            'content': queries_t.prompt
-        }
-    ]
-
-    progress(0.6, desc="Querying models...")
-
-    # Add OpenAI response column
+
+    # Add the messages column
+    queries_t.add_computed_column(messages=create_messages(queries_t.prompt))
+
+    # Then add the response columns using the messages
     queries_t.add_computed_column(response=openai.chat_completions(
         model='gpt-4o-mini-2024-07-18',
-        messages=msgs,
+        messages=queries_t.messages,
         max_tokens=300,
         top_p=0.9,
         temperature=0.7
     ))
-
-    # Create a table in Pixeltable and pick a model hosted on Anthropic with some parameters
+
     queries_t.add_computed_column(response_2=f_chat_completions(
-        messages=msgs,
-        model='accounts/fireworks/models/llama-v3p2-3b-instruct',
-        # These parameters are optional and can be used to tune model behavior:
-        max_tokens=300,
-        top_p=0.9,
-        temperature=0.7
+        messages=queries_t.messages,
+        model='accounts/fireworks/models/llama-v3p2-3b-instruct',
+        max_tokens=300,
+        top_p=0.9,
+        temperature=0.7
     ))
-
+
     queries_t.add_computed_column(response_3=chat_completions(
-        messages=msgs,
-        model='mistral-small-latest',
-        # These parameters are optional and can be used to tune model behavior:
-        max_tokens=300,
-        top_p=0.9,
-        temperature=0.7
+        messages=queries_t.messages,
+        model='mistral-small-latest',
+        max_tokens=300,
+        top_p=0.9,
+        temperature=0.7
     ))
 
     # Extract the answer text from the API response
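
Note: building the messages in a @pxt.udf, instead of as a plain msgs list referencing queries_t.prompt as the old code did, gives Pixeltable a per-row expression it can evaluate and store in a messages column shared by all three providers. A minimal sketch of inspecting the staged columns, assuming Pixeltable's select()/collect() API and that the first result row is addressable as a dict (hypothetical inspection code, not part of the commit):

    # Hypothetical: inspect each stage of the computed pipeline per row.
    res = queries_t.select(
        queries_t.question,   # raw question from the ground-truth file
        queries_t.prompt,     # output of create_prompt
        queries_t.messages,   # output of create_messages
    ).collect()
    print(res[0]['messages'][0]['role'])  # expected: 'system'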
 
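Note: the unchanged lines after the last hunk extract the answer text from each response. With the providers' standard chat-completions JSON shape, that extraction plausibly looks like the sketch below; the answer column names are illustrative, not taken from the diff:

    # Illustrative sketch: pull the generated text out of each response JSON
    # using Pixeltable's JSON path expressions.
    queries_t.add_computed_column(answer_gpt4omini=queries_t.response.choices[0].message.content)
    queries_t.add_computed_column(answer_llamav3p23b=queries_t.response_2.choices[0].message.content)
    queries_t.add_computed_column(answer_mistralsmall=queries_t.response_3.choices[0].message.content)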