kaikaidai committed
Commit 214129a · verified · 1 Parent(s): 281eda1

Update gen_api_answer.py

Files changed (1)
  1. gen_api_answer.py +71 -14
gen_api_answer.py CHANGED
@@ -10,6 +10,8 @@ from prompts import (
     JUDGE_SYSTEM_PROMPT,
     PROMETHEUS_PROMPT,
     PROMETHEUS_PROMPT_WITH_REFERENCE,
+    ATLA_PROMPT,
+    ATLA_PROMPT_WITH_REFERENCE,
 )
 
 # Initialize clients
@@ -18,10 +20,8 @@ openai_client = OpenAI()
 together_client = Together()
 hf_api_key = os.getenv("HF_API_KEY")
 cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
-huggingface_client = OpenAI(
-    base_url="https://otb7jglxy6r37af6.us-east-1.aws.endpoints.huggingface.cloud/v1/",
-    api_key=hf_api_key
-)
+
+
 
 def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
     """Get response from OpenAI API"""
@@ -70,7 +70,7 @@ def get_together_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT,
     except Exception as e:
         return f"Error with Together model {model_name}: {str(e)}"
 
-def get_hf_response(model_name, prompt, max_tokens=500):
+def get_prometheus_response(model_name, prompt, max_tokens=500, temperature=0.01): # temperature needs to be > 0 for hf to work
     """Get response from Hugging Face model"""
     try:
         headers = {
@@ -83,7 +83,8 @@ def get_hf_response(model_name, prompt, max_tokens=500):
         "inputs": prompt,
         "parameters": {
             "max_new_tokens": max_tokens,
-            "return_full_text": False
+            "return_full_text": False,
+            "temperature": temperature
         }
     }
 
@@ -96,6 +97,34 @@ def get_hf_response(model_name, prompt, max_tokens=500):
     except Exception as e:
         return f"Error with Hugging Face model {model_name}: {str(e)}"
 
+def get_atla_response(model_name, prompt, max_tokens=500, temperature=0.01):
+    """Get response from HF endpoint for Atla model"""
+    try:
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {hf_api_key}",
+            "Content-Type": "application/json"
+        }
+
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": max_tokens,
+                "return_full_text": False,
+                "temperature": temperature,
+                "seed": 42
+            }
+        }
+
+        response = requests.post(
+            "https://azk0vbxyrc64s2v2.us-east-1.aws.endpoints.huggingface.cloud",
+            headers=headers,
+            json=payload
+        )
+        return response.json()[0]["generated_text"]
+    except Exception as e:
+        return f"Error with Atla model {model_name}: {str(e)}"
+
 def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
     """Get response from Cohere API"""
     try:
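
Note: the new get_atla_response helper follows the standard Hugging Face Inference Endpoints text-generation API, which is to POST a JSON body with an "inputs" string and a "parameters" dict and read "generated_text" from the first element of the returned list. A minimal usage sketch, assuming the endpoint above is live, requests is imported, and HF_API_KEY is set; the prompt text is an invented placeholder:

# Hypothetical smoke test for the helper added above.
output = get_atla_response(
    model_name="atla-judge",                     # only used in the error message
    prompt="Evaluate the following response...", # placeholder; real prompts are built in get_model_response
    max_tokens=500,
    temperature=0.01,
)
print(output)  # raw generated text, e.g. "**Reasoning:** ...\n**Result:** 4"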
@@ -132,20 +161,23 @@ def get_model_response(
     api_model = model_info["api_model"]
     organization = model_info["organization"]
 
-    # Determine if model is Prometheus
+    # Determine if model is Prometheus or Atla
     is_prometheus = (organization == "Prometheus")
+    is_atla = (organization == "Atla")
 
-    # For non-Prometheus models, use the Judge system prompt
-    system_prompt = None if is_prometheus else JUDGE_SYSTEM_PROMPT
+    # For non-Prometheus/Atla models, use the Judge system prompt
+    system_prompt = None if (is_prometheus or is_atla) else JUDGE_SYSTEM_PROMPT
 
     # Select the appropriate base prompt
-    if use_reference:
+    if is_atla:
+        base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
+    elif use_reference:
         base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE
     else:
         base_prompt = PROMETHEUS_PROMPT
 
-    # For non-Prometheus models, replace the specific instruction
-    if not is_prometheus:
+    # For non-Prometheus/non-Atla models, replace the specific instruction
+    if not (is_prometheus or is_atla):
         base_prompt = base_prompt.replace(
             '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
             '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
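
The doubled braces in the swapped-in instruction ({{...}}) suggest the template is later passed through str.format, which collapses them to literal braces. A hypothetical parser for the requested JSON reply, not part of this commit:

import json

def parse_json_judgment(output):
    # Hypothetical counterpart to the JSON instruction above (this commit only
    # ships the Prometheus/Atla parsers); returns (score, feedback) like them.
    try:
        data = json.loads(output.strip())
        return str(data["result"]), data["feedback"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        return "Error", f"Exception during parsing: {str(e)}"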
@@ -177,8 +209,12 @@ def get_model_response(
             api_model, final_prompt, system_prompt, max_tokens, temperature
         )
     elif organization == "Prometheus":
-        return get_hf_response(
-            api_model, final_prompt, max_tokens
+        return get_prometheus_response(
+            api_model, final_prompt, max_tokens, temperature = 0.01
+        )
+    elif organization == "Atla":
+        return get_atla_response(
+            api_model, final_prompt, max_tokens, temperature = 0.01
         )
     elif organization == "Cohere":
         return get_cohere_response(
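
With this branch in place, adding a judge only requires an organization match plus its helper. A hypothetical call through the dispatch; only the "api_model" and "organization" keys of model_info are confirmed by this diff, and the remaining get_model_response arguments are assumptions:

# Hypothetical usage; argument names other than model_info/use_reference are assumed.
model_info = {"api_model": "atla-judge", "organization": "Atla"}
judgment = get_model_response(model_info, prompt="...", use_reference=False)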
@@ -269,4 +305,25 @@ def prometheus_parse_model_response(output):
 
     except Exception as e:
         print(f"Failed to parse response: {str(e)}")
-        return "Error", f"Exception during parsing: {str(e)}"
+        return "Error", f"Exception during parsing: {str(e)}"
+
+def atla_parse_model_response(output):
+    """Parse response from ATLA model"""
+    try:
+        print(f"Raw Atla model response: {output}")
+        output = output.strip()
+
+        # Look for the Reasoning and Result sections
+        reasoning_match = re.search(r'\*\*Reasoning:\*\*(.*?)(?=\*\*Result:|$)', output, re.DOTALL)
+        result_match = re.search(r'\*\*Result:\*\*\s*(\d+)', output)
+
+        if reasoning_match and result_match:
+            feedback = reasoning_match.group(1).strip()
+            score = result_match.group(1)
+            return str(score), feedback
+
+        return "Error", f"Failed to parse ATLA response format: {output}"
+
+    except Exception as e:
+        print(f"Failed to parse ATLA response: {str(e)}")
+        return "Error", f"Exception during parsing: {str(e)}"
 