Spaces:
Running
Running
from openai import OpenAI | |
import anthropic | |
from together import Together | |
import cohere | |
import json | |
import re | |
import os | |
import requests | |
from prompts import ( | |
JUDGE_SYSTEM_PROMPT, | |
PROMETHEUS_PROMPT, | |
PROMETHEUS_PROMPT_WITH_REFERENCE, | |
ATLA_PROMPT, | |
ATLA_PROMPT_WITH_REFERENCE, | |
) | |
# Shared API clients, created once when the module is imported.
# The Anthropic, OpenAI and Together clients are constructed without
# arguments (presumably they pick up credentials from their SDK defaults --
# confirm against each SDK). The Hugging Face key and the Cohere key are read
# from the HF_API_KEY and CO_API_KEY environment variables respectively.
anthropic_client = anthropic.Anthropic()
openai_client = OpenAI()
together_client = Together()
hf_api_key = os.environ.get("HF_API_KEY")
cohere_client = cohere.ClientV2(os.environ.get("CO_API_KEY"))
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Send a system + user prompt to an OpenAI chat model and return its reply.

    Args:
        model_name: Model identifier passed to the chat completions call.
        prompt: Content of the user message.
        system_prompt: Content of the system message.
        max_tokens: Forwarded as ``max_completion_tokens``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The message content of the first choice, or a string starting with
        "Error with OpenAI model" if anything in the call raised.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = openai_client.chat.completions.create(
            model=model_name,
            messages=conversation,
            max_completion_tokens=max_tokens,
            temperature=temperature,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        # Failures are reported in-band as text rather than raised.
        return f"Error with OpenAI model {model_name}: {exc!s}"
def get_anthropic_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Send a prompt to an Anthropic model and return the text of its reply.

    Args:
        model_name: Model identifier passed to the messages call.
        prompt: Text placed in a single user message as one text content part.
        system_prompt: Passed through the ``system`` argument.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The ``text`` of the first content item of the reply, or a string
        starting with "Error with Anthropic model" if anything raised.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}],
    }
    try:
        reply = anthropic_client.messages.create(
            model=model_name,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=[user_turn],
        )
        first_part = reply.content[0]
        return first_part.text
    except Exception as exc:
        # Failures are reported in-band as text rather than raised.
        return f"Error with Anthropic model {model_name}: {exc!s}"
def get_together_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Send a system + user prompt to a Together-hosted model and return its reply.

    Args:
        model_name: Model identifier passed to the chat completions call.
        prompt: Content of the user message.
        system_prompt: Content of the system message.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The message content of the first choice (streaming is disabled), or a
        string starting with "Error with Together model" if anything raised.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = together_client.chat.completions.create(
            model=model_name,
            messages=conversation,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=False,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as exc:
        # Failures are reported in-band as text rather than raised.
        return f"Error with Together model {model_name}: {exc!s}"
def get_prometheus_response(model_name, prompt, max_tokens=500, temperature=0.01):
    """Query the Prometheus Hugging Face inference endpoint and return the generated text.

    Args:
        model_name: Used only to label the error message; the endpoint URL is
            fixed inside this function.
        prompt: Sent as the ``inputs`` field of the request payload.
        max_tokens: Forwarded as ``max_new_tokens``.
        temperature: Sampling temperature. Per the original author's note it
            needs to be > 0 for the Hugging Face endpoint to work, hence the
            0.01 default.

    Returns:
        The ``generated_text`` of the first item in the JSON response, or a
        string starting with "Error with Hugging Face model" if the request
        failed, timed out, returned a non-success HTTP status, or the payload
        did not have the expected shape.
    """
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_tokens,
            "return_full_text": False,
            "temperature": temperature,
        },
    }
    try:
        response = requests.post(
            "https://otb7jglxy6r37af6.us-east-1.aws.endpoints.huggingface.cloud",
            headers=headers,
            json=payload,
            # Without a timeout requests waits forever on a stalled endpoint.
            # The value is a generous guess for a non-streamed generation.
            timeout=120,
        )
        # Surface HTTP errors with their status line instead of letting the
        # indexing below fail with an uninformative KeyError/TypeError.
        response.raise_for_status()
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with Hugging Face model {model_name}: {str(e)}"
def get_atla_response(model_name, prompt, max_tokens=500, temperature=0.01):
    """Query the Atla Hugging Face inference endpoint and return the generated text.

    Args:
        model_name: Used only to label the error message; the endpoint URL is
            fixed inside this function.
        prompt: Sent as the ``inputs`` field of the request payload.
        max_tokens: Forwarded as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to the endpoint.

    Returns:
        The ``generated_text`` of the first item in the JSON response, or a
        string starting with "Error with Atla model" if the request failed,
        timed out, returned a non-success HTTP status, or the payload did not
        have the expected shape.
    """
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_tokens,
            "return_full_text": False,
            "temperature": temperature,
            # A fixed seed is sent with every request.
            "seed": 42,
        },
    }
    try:
        response = requests.post(
            "https://azk0vbxyrc64s2v2.us-east-1.aws.endpoints.huggingface.cloud",
            headers=headers,
            json=payload,
            # Without a timeout requests waits forever on a stalled endpoint.
            # The value is a generous guess for a non-streamed generation.
            timeout=120,
        )
        # Surface HTTP errors with their status line instead of letting the
        # indexing below fail with an uninformative KeyError/TypeError.
        response.raise_for_status()
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with Atla model {model_name}: {str(e)}"
def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Send a system + user prompt to a Cohere chat model and return its reply.

    Args:
        model_name: Model identifier passed to the chat call.
        prompt: Content of the user message.
        system_prompt: Content of the system message.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The ``text`` of the first content item when the reply content is a
        list, otherwise ``str()`` of the content; or a string starting with
        "Error with Cohere model" if anything raised.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    try:
        reply = cohere_client.chat(
            model=model_name,
            messages=chat_messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        body = reply.message.content
        if not isinstance(body, list):
            # Content that is not a list is returned as its string form.
            return str(body)
        return body[0].text
    except Exception as exc:
        # Failures are reported in-band as text rather than raised.
        return f"Error with Cohere model {model_name}: {exc!s}"
def get_model_response(
    model_name,
    model_info,
    prompt_data,
    use_reference=False,
    max_tokens=500,
    temperature=0
):
    """Build the judge prompt for a model and route it to the matching API helper.

    Args:
        model_name: Display name, used only in the catch-all error message.
        model_info: Mapping with "api_model" and "organization" keys; a falsy
            value short-circuits with a "not found" message.
        prompt_data: Mapping supplying the template fields ("human_input",
            "ai_response", "eval_criteria", "score1_desc" .. "score5_desc",
            and optionally "ground_truth_input").
        use_reference: Selects the template variant that includes a reference.
        max_tokens: Forwarded to the API helper.
        temperature: Forwarded to the API helper, except for Prometheus and
            Atla which are always called with 0.01.

    Returns:
        The helper's response text, or an error string if the model info is
        missing, a required prompt field is absent, or the dispatch raised.
    """
    if not model_info:
        return "Model not found or unsupported."

    api_model = model_info["api_model"]
    organization = model_info["organization"]

    # Prometheus and Atla keep their own output format and get no system prompt.
    keeps_native_format = organization in ("Prometheus", "Atla")
    system_prompt = JUDGE_SYSTEM_PROMPT if not keeps_native_format else None

    # Atla has dedicated templates; every other organization starts from the
    # Prometheus ones.
    if organization == "Atla":
        base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
    else:
        base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT

    if not keeps_native_format:
        # General-purpose models are asked for JSON instead of the
        # "Feedback: ... [RESULT] n" format. The doubled braces survive the
        # .format() call below as literal braces.
        base_prompt = base_prompt.replace(
            '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
            '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
        )

    try:
        # Only ground_truth_input is optional; every other field is required.
        template_fields = {
            "human_input": prompt_data['human_input'],
            "ai_response": prompt_data['ai_response'],
            "ground_truth_input": prompt_data.get('ground_truth_input', ''),
            "eval_criteria": prompt_data['eval_criteria'],
            "score1_desc": prompt_data['score1_desc'],
            "score2_desc": prompt_data['score2_desc'],
            "score3_desc": prompt_data['score3_desc'],
            "score4_desc": prompt_data['score4_desc'],
            "score5_desc": prompt_data['score5_desc'],
        }
        final_prompt = base_prompt.format(**template_fields)
    except KeyError as exc:
        return f"Error formatting prompt: Missing required field {exc!s}"

    try:
        if organization == "Prometheus":
            return get_prometheus_response(
                api_model, final_prompt, max_tokens, temperature=0.01
            )
        if organization == "Atla":
            return get_atla_response(
                api_model, final_prompt, max_tokens, temperature=0.01
            )

        # Chat-style providers share one call signature; any organization not
        # listed here is served through the Together API.
        chat_backends = {
            "OpenAI": get_openai_response,
            "Anthropic": get_anthropic_response,
            "Cohere": get_cohere_response,
        }
        backend = chat_backends.get(organization, get_together_response)
        return backend(api_model, final_prompt, system_prompt, max_tokens, temperature)
    except Exception as exc:
        return f"Error with {organization} model {model_name}: {exc!s}"
def parse_model_response(response):
    """Extract ``(score, feedback)`` from a model reply expected to contain JSON.

    The whole reply is tried as JSON first. If that does not yield an object,
    the reply is scanned left to right and the first position at which a JSON
    object can be decoded is used. Unlike a greedy ``{.*}`` match, this copes
    with stray braces before or after the object (code fences, trailing
    commentary, and so on).

    Args:
        response: Raw text returned by the model.

    Returns:
        A ``(score, feedback)`` tuple. ``score`` is ``str()`` of the object's
        "result" value and ``feedback`` is its "feedback" value; either falls
        back to "N/A" when the key is missing. On failure the score is
        "Error" and the feedback is a message containing the raw response.
    """
    try:
        # Debug print
        print(f"Raw model response: {response}")

        # First try to parse the entire response as JSON.
        try:
            data = json.loads(response)
        except json.JSONDecodeError:
            data = None
        if isinstance(data, dict):
            return str(data.get("result", "N/A")), data.get("feedback", "N/A")

        # Nothing that even looks like a braced object: report a format error.
        if not re.search(r"{.*}", response, re.DOTALL):
            return "Error", f"Invalid response format returned - here is the raw model response: {response}"

        # Typically for smaller models: find JSON embedded in surrounding text.
        # raw_decode stops at the end of the first complete value, so text
        # after the object (including further braces) cannot break parsing.
        decoder = json.JSONDecoder()
        start = response.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(response, start)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict):
                return str(candidate.get("result", "N/A")), candidate.get("feedback", "N/A")
            start = response.find("{", start + 1)

        print("Failed to parse response: no JSON object could be decoded")
        return "Error", f"Failed to parse response: {response}"
    except Exception as e:
        # Debug print for error case (e.g. a non-string response)
        print(f"Failed to parse response: {str(e)}")
        return "Error", f"Failed to parse response: {response}"
def prometheus_parse_model_response(output):
    """Extract ``(score, feedback)`` from a Prometheus-style free-text reply.

    After stripping whitespace and a leading "Feedback:" label, the following
    layouts are tried in order and the first match wins:

    1. ``[RESULT] n`` at the very start, followed by the feedback.
    2. feedback followed by ``[RESULT] n`` (digit optionally bracketed).
    3. feedback followed by ``Score: n`` / ``Result: n``, optionally wrapped
       in an opening bracket or parenthesis such as ``[Score: n]``.
    4. feedback followed by ``[Score n]`` / ``[Result n]`` at the end.
    5. any trailing number, optionally bracketed.

    Args:
        output: Raw text returned by the model.

    Returns:
        A ``(score, feedback)`` tuple of strings. On failure the score is
        "Error" and the feedback describes what went wrong.
    """
    try:
        print(f"Raw model response: {output}")
        output = output.strip()
        # Remove "Feedback:" prefix if present (case insensitive)
        output = re.sub(r'^feedback:\s*', '', output, flags=re.IGNORECASE)

        # Score first, feedback afterwards.
        begin_match = re.search(
            r'^\[RESULT\]\s*(\d+)\s*\n*(.*?)$', output, re.DOTALL | re.IGNORECASE
        )
        if begin_match:
            return str(int(begin_match.group(1))), begin_match.group(2).strip()

        # Feedback first (group 1), score afterwards (group 2).
        feedback_then_score = (
            # "... [RESULT] 4"
            (r"(.*?)\s*\[RESULT\]\s*[\(\[]?(\d+)[\)\]]?", re.DOTALL | re.IGNORECASE),
            # "... Score: 4" / "... [Score: 4]". The optional opening bracket
            # is consumed here so it does not end up dangling in the feedback.
            (
                r"(.*?)\s*[\(\[]?\s*(?:Score|Result)\s*:\s*[\(\[]?(\d+)[\)\]]?",
                re.DOTALL | re.IGNORECASE,
            ),
            # "... [Score 4]" at the end
            (r"(.*?)\s*\[(?:Score|Result)\s*[\(\[]?(\d+)[\)\]]?\]$", re.DOTALL),
        )
        for pattern, flags in feedback_then_score:
            match = re.search(pattern, output, flags)
            if match:
                return str(int(match.group(2))), match.group(1).strip()

        # Final fallback attempt: a bare trailing number.
        match = re.search(r"[\(\[]?(\d+)[\)\]]?\s*\]?$", output)
        if match:
            score = int(match.group(1))
            feedback = output[:match.start()].rstrip()
            # Remove any trailing unclosed bracket from feedback
            feedback = re.sub(r'\s*\[[^\]]*$', '', feedback).strip()
            return str(score), feedback

        return "Error", f"Failed to parse response: {output}"
    except Exception as e:
        # Reached e.g. when output is not a string.
        print(f"Failed to parse response: {str(e)}")
        return "Error", f"Exception during parsing: {str(e)}"
def atla_parse_model_response(output):
    """Extract ``(score, feedback)`` from an Atla reply with bold section headers.

    The reply is expected to contain a ``**Reasoning:**`` section and a
    ``**Result:**`` section holding an integer. Headers are matched
    case-insensitively, and the colon may sit inside or outside the bold
    markers (``**Result:**`` or ``**Result**:``), so small formatting
    variations do not turn a usable reply into an error.

    Args:
        output: Raw text returned by the model.

    Returns:
        A ``(score, feedback)`` tuple of strings: the digits following the
        Result header and the stripped text of the Reasoning section (which
        runs up to the Result header or the end of the text). If either
        section is missing, or parsing raises, the score is "Error" and the
        feedback describes the problem.
    """
    try:
        print(f"Raw Atla model response: {output}")
        output = output.strip()

        # Look for the Reasoning and Result sections
        reasoning_match = re.search(
            r'\*\*Reasoning(?::\*\*|\*\*:)(.*?)(?=\*\*Result(?::|\*\*:)|$)',
            output,
            re.DOTALL | re.IGNORECASE,
        )
        result_match = re.search(
            r'\*\*Result(?::\*\*|\*\*:)\s*(\d+)', output, re.IGNORECASE
        )

        if reasoning_match and result_match:
            feedback = reasoning_match.group(1).strip()
            score = result_match.group(1)
            return str(score), feedback

        return "Error", f"Failed to parse ATLA response format: {output}"
    except Exception as e:
        # Reached e.g. when output is not a string.
        print(f"Failed to parse ATLA response: {str(e)}")
        return "Error", f"Exception during parsing: {str(e)}"