Spaces:
Sleeping
Sleeping
import json | |
import requests | |
# Grader preamble, written in Japanese. It tells the model that it is a grader,
# that it will receive a problem, grading criteria and an answer, and that it
# must reply with a single score from 1 to 5. It then lists the base rubric
# (1 = wrong / instructions not followed ... 5 = helpful) and two -1 point
# deductions (unnatural Japanese, partially non-factual content).
# NOTE(review): nothing in the visible part of this module references this
# constant - confirm whether it is meant to be sent along with the prompts.
PREREQUISITE_PROMPT = """\
あなたは採点者です。
問題, 採点基準, 回答 が与えられます。
回答を1,2,3,4,5の5段階で採点し、数字のみを出力してください。
# 採点基準
基本的な採点基準
- 1点: 誤っている、 指示に従えていない
- 2点: 誤っているが、方向性は合っている
- 3点: 部分的に誤っている、 部分的に合っている
- 4点: 合っている
- 5点: 役に立つ
基本的な減点項目
- 不自然な日本語: -1点
- 部分的に事実と異なる内容を述べている: -1点
"""
def evaluation_prompt( | |
input: str, output: str, eval_aspect: str | None, target: str | None | |
) -> str: | |
return f"""\ | |
回答を1,2,3,4,5の5段階で採点し、数字のみを出力してください。 | |
# 問題: {input} | |
{f"# 正解例: {target}" if target is not None else ""} | |
{f"# 採点基準: {eval_aspect}" if eval_aspect is not None else ""} | |
# 回答: {output} | |
""" | |
# Calls through the Gemini SDK never returned when made from Gradio, so the
# REST API is used directly instead.
def evaluate(
    results: list[dict],
    api_key: str,
    batch_size: int = 10,
    timeout: float = 120.0,
) -> list[dict]:
    """Score model answers with Gemini, ``batch_size`` answers per request.

    Each batch is turned into one text prompt (the per-item prompts joined by
    newlines) and the model is asked for a JSON array of numbers, one per
    item, which is then paired with the batch items in order.

    Args:
        results: Items to grade. Each dict must contain ``"input"`` and
            ``"output"``; ``"eval_aspect"`` and ``"target"`` are optional.
        api_key: Gemini API key. It is sent in the ``x-goog-api-key`` header
            rather than the URL so it cannot leak through logged URLs or
            ``requests`` exception messages.
        batch_size: Number of items graded per API request (must be >= 1).
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        One dict per input item, in order, with the keys ``input``,
        ``output``, ``eval_aspect``, ``target`` and ``score``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        RuntimeError: If the API answers with a non-200 status, the response
            body does not have the expected shape, or the number of returned
            scores differs from the number of items in the batch.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}

    # NOTE(review): PREREQUISITE_PROMPT (the grading rubric defined in this
    # module) is not included in the request - confirm whether it should be.
    evaluations: list[dict] = []
    for start in range(0, len(results), batch_size):
        batch_results = results[start : start + batch_size]
        prompts = [
            evaluation_prompt(
                result["input"],
                result["output"],
                result.get("eval_aspect"),
                result.get("target"),
            )
            for result in batch_results
        ]
        data = {
            "contents": [{"parts": [{"text": "\n".join(prompts)}]}],
            "generationConfig": {
                "response_mime_type": "application/json",
                "response_schema": {"type": "ARRAY", "items": {"type": "NUMBER"}},
            },
        }

        # Without a timeout a stalled connection would block forever.
        response = requests.post(url, headers=headers, json=data, timeout=timeout)
        if response.status_code != 200:
            raise RuntimeError(
                f"API request failed with status code {response.status_code}: {response.text}"
            )

        # The scores arrive as a JSON document embedded in the first text part.
        try:
            scores = json.loads(
                response.json()["candidates"][0]["content"]["parts"][0]["text"]
            )
        except (KeyError, IndexError, TypeError, ValueError) as err:
            raise RuntimeError(
                f"Unexpected API response format: {response.text}"
            ) from err

        # zip() would silently truncate and misalign scores, so fail loudly.
        if not isinstance(scores, list) or len(scores) != len(batch_results):
            raise RuntimeError(
                f"Expected {len(batch_results)} scores from the API but got: {scores!r}"
            )

        evaluations.extend(
            {
                "input": result["input"],
                "output": result["output"],
                "eval_aspect": result.get("eval_aspect"),
                "target": result.get("target"),
                "score": score,
            }
            for result, score in zip(batch_results, scores)
        )
    return evaluations
def _script_safe_json(value) -> str:
    """Serialize *value* to JSON that is safe inside an HTML script element."""
    # Escaping <, > and & keeps sequences such as a closing script tag or an
    # HTML comment opener inside the data from being seen by the HTML parser.
    # The \uXXXX forms are decoded back to the same characters by JavaScript.
    return (
        json.dumps(value)
        .replace("&", "\\u0026")
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
    )


def report(tasks: list[dict]) -> str:
    """Render *tasks* as a self-contained HTML report page.

    The tasks are embedded as a JavaScript array literal and rendered in the
    browser: one section per task with a divider label (``task_id`` when
    present, otherwise the index), the ``input`` text and the ``output`` text
    (labelled with ``score`` when it is truthy).

    Task text is treated as untrusted: the embedded JSON is escaped so it
    cannot terminate the script element, and the text is inserted as plain
    text (line breaks are kept via CSS) so markup in an answer is displayed
    literally instead of being interpreted by the browser.

    Args:
        tasks: JSON-serializable task dicts; the page reads the keys
            ``task_id``, ``input``, ``output`` and ``score``.

    Returns:
        The complete HTML document as a string.
    """
    return (
        """\
<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>レポート</title>
    <style>
        body {
            background-color: #f8f9fa;
        }
        .container {
            width: 80%; /* 可変幅 */
            margin: 20px auto;
            background-color: #ffffff;
            border-radius: 8px;
        }
        .divider {
            position: relative;
            padding: 16px 0;
            align-items: center;
            justify-content: center;
        }
        .divider .line {
            height: 1px;
            background-color: #ddd;
        }
        .divider .taskName {
            position: absolute;
            margin: -8px;
            left: 50%;
            transform: translateX(-50%);
            padding: 0 10px;
            font-size: 14px;
            font-weight: 900;
            text-align: center;
            border: 1px solid #ddd;
            border-radius: 9999px;
            background-color: #ffffff;
            white-space: nowrap;
        }
        .message {
            padding: 8px;
        }
        .content {
            font-size: 14px;
            font-weight: 400;
            white-space: pre-wrap;
        }
        .from {
            font-size: 14px;
            font-weight: 900;
        }
    </style>
</head>
<body>
    <div class="container" id="container"></div>
    <script>
        const messages = """
        + _script_safe_json(tasks)
        + """;
        // taskName: string
        // return: HTMLDivElement
        const createDivider = (taskName) => {
            const divider = document.createElement('div');
            divider.classList.add('divider');
            const line = document.createElement('div');
            line.classList.add('line');
            const taskNameLabel = document.createElement('div');
            taskNameLabel.classList.add('taskName');
            taskNameLabel.textContent = taskName;
            divider.appendChild(line);
            divider.appendChild(taskNameLabel);
            return divider;
        };
        // text: message body, name: label shown above the body
        // return: HTMLDivElement
        const createMessage = (text, name) => {
            const message = document.createElement('div');
            message.classList.add('message');
            const from = document.createElement('div');
            from.classList.add('from');
            from.textContent = name;
            const content = document.createElement('div');
            content.classList.add('content');
            // Plain-text insertion; newlines are rendered by the pre-wrap style.
            content.textContent = text == null ? '' : String(text);
            message.appendChild(from);
            message.appendChild(content);
            return message;
        };
        const container = document.getElementById('container');
        messages.forEach((message, i) => {
            const task = document.createElement('div');
            task.classList.add('task');
            task.appendChild(createDivider(message.task_id ? `Task ID: ${message.task_id}` : `Task Index ${i}`));
            task.appendChild(createMessage(message.input, 'input'));
            task.appendChild(createMessage(message.output, 'output' + (message.score ? ` (score: ${message.score})` : '')));
            container.appendChild(task);
        });
    </script>
</body>
</html>
"""
    )