Update README.md
README.md CHANGED
@@ -1,66 +1,200 @@
Removed (the previous mergekit model card):

```
---
base_model:
- google/gemma-2-9b
- testmoto/gemma-2-9b-platypus-02
- testmoto/gemma-2-9b-synthetic_coding
- testmoto/gemma-2-9b-lora-0
library_name: transformers
tags:
- mergekit
- merge
---
```

The following models were included in the merge:

* ./fused_model
* [testmoto/gemma-2-9b-synthetic_coding](https://huggingface.co/testmoto/gemma-2-9b-synthetic_coding)
* [testmoto/gemma-2-9b-lora-0](https://huggingface.co/testmoto/gemma-2-9b-lora-0)
Added (the new README):

# Inference Code

This is code for obtaining ELYZA-tasks-100-TV outputs using the model uploaded to Hugging Face. The jsonl file this code generates is in a format that can be submitted as the assignment deliverable.
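For reference, each line of the generated jsonl holds one task record with the fields written by the script below; the values shown here are placeholders:

```
{"task_id": 0, "input": "...", "output": "..."}
```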
Install the dependencies (the leading `!` assumes a notebook environment such as Colab):

```
!pip install -U bitsandbytes
!pip install -U transformers
!pip install -U accelerate
!pip install -U datasets
!pip install -U peft
```

```
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from pathlib import Path
from typing import Dict, Any, Optional
from tqdm import tqdm
import time
from datetime import datetime

+
class GPUPredictions:
|
24 |
+
def __init__(self,
|
25 |
+
model_id="testmoto/gemma-2-llm2024-01",
|
26 |
+
adapter_path=None,
|
27 |
+
max_tokens=1024,
|
28 |
+
temp=0.0,
|
29 |
+
top_p=0.9,
|
30 |
+
seed=3407):
|
31 |
+
self.model_id = model_id
|
32 |
+
self.adapter_path = adapter_path
|
33 |
+
self.max_tokens = max_tokens
|
34 |
+
self.temp = temp
|
35 |
+
self.top_p = top_p
|
36 |
+
self.seed = seed
|
37 |
+
|
38 |
+
print(f"Loading model: {model_id}")
|
39 |
+
torch.cuda.empty_cache()
|
40 |
+
|
41 |
+
# GPU設定
|
42 |
+
n_gpus = torch.cuda.device_count()
|
43 |
+
max_memory = {i: "20GiB" for i in range(n_gpus)}
|
44 |
+
max_memory["cpu"] = "100GiB"
|
45 |
+
|
46 |
+
# トークナイザーの初期化
|
47 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
48 |
+
model_id,
|
49 |
+
trust_remote_code=True
|
50 |
+
)
|
51 |
+
if self.tokenizer.pad_token is None:
|
52 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
53 |
+
|
54 |
+
try:
|
55 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
56 |
+
model_id,
|
57 |
+
torch_dtype=torch.float16,
|
58 |
+
device_map="auto",
|
59 |
+
max_memory=max_memory,
|
60 |
+
low_cpu_mem_usage=True,
|
61 |
+
trust_remote_code=True
|
62 |
+
)
|
63 |
+
except Exception as e:
|
64 |
+
print(f"First loading attempt failed: {str(e)}")
|
65 |
+
print("Trying alternative loading method...")
|
66 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
67 |
+
model_id,
|
68 |
+
torch_dtype=torch.float16,
|
69 |
+
device_map="balanced",
|
70 |
+
low_cpu_mem_usage=True,
|
71 |
+
trust_remote_code=True
|
72 |
+
)
|
73 |
+
|
74 |
+
if adapter_path:
|
75 |
+
print(f"Loading adapter from {adapter_path}")
|
76 |
+
self.model.load_adapter(adapter_path)
|
77 |
|
78 |
+
# Generate設定
|
79 |
+
self.gen_config = {
|
80 |
+
"max_new_tokens": max_tokens,
|
81 |
+
"temperature": temp,
|
82 |
+
"top_p": top_p,
|
83 |
+
"do_sample": temp > 0,
|
84 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
85 |
+
"eos_token_id": self.tokenizer.eos_token_id
|
86 |
+
}
|
87 |
|
88 |
+
print("Model loaded successfully")
|
89 |
+
self.device = next(self.model.parameters()).device
|
90 |
+
print(f"Model is on device: {self.device}")
|
|
|
|
|
|
|
91 |
|
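    # Note: with the default temp=0.0, do_sample evaluates to False, so
    # model.generate() decodes greedily and the temperature/top_p entries
    # above are ignored; set temp > 0 to enable nucleus sampling.
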
    @torch.inference_mode()
    def generate_response(self, prompt: str) -> str:
        """Generate a response for a single prompt."""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.cuda.amp.autocast():
                outputs = self.model.generate(
                    **inputs,
                    **self.gen_config
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the echoed prompt from the decoded output
            if prompt in response:
                response = response[len(prompt):].strip()

            return response

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            raise

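    # Note: decode() above returns the prompt followed by the completion; an
    # alternative to the string match is to slice off the prompt tokens first:
    #     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    #     response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
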
    def load_tasks(self, file_path: str) -> list:
        """Load the ELYZA tasks from a jsonl file."""
        datasets = []
        with open(file_path, "r") as f:
            for line in f:
                if line.strip():
                    datasets.append(json.loads(line.strip()))
        return datasets

    def run_inference(self, input_file="elyza-tasks-100-TV_0.jsonl", output_file="gpu_results.jsonl"):
        """Run inference over all tasks, appending each result to the output jsonl."""
        tasks = self.load_tasks(input_file)
        results = []

        start_time = time.time()
        execution_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"Execution started at: {execution_date}")
        print(f"Total tasks: {len(tasks)}")
        print("-" * 50)

        for task in tqdm(tasks, desc="Processing tasks"):
            task_start_time = time.time()

            prompt = f"""### Instruction:
{task['input']}
<eos>
### Response: """

            try:
                response = self.generate_response(prompt)
                # Keep only the text after the response marker, if it survived generation
                answer = response.split('### Response: ')[-1]

                task_end_time = time.time()
                task_duration = task_end_time - task_start_time

                result = {
                    "task_id": task["task_id"],
                    "input": task["input"],
                    "output": answer
                }
                results.append(result)

                print(f"\nTask {task['task_id']} completed in {task_duration:.2f} seconds")
                print(f"Input: {task['input'][:100]}...")
                print(f"Output: {answer[:100]}...")
                print("-" * 50)

                # Append each result immediately so progress survives a crash
                with open(output_file, 'a', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False)
                    f.write('\n')

                if task["task_id"] % 5 == 0:
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing task {task['task_id']}: {str(e)}")
                continue

        total_time = time.time() - start_time
        avg_time = total_time / len(tasks)

        summary = {
            "execution_date": execution_date,
            "total_tasks": len(tasks),
            "total_time": round(total_time, 2),
            "average_time_per_task": round(avg_time, 2),
            "model_id": self.model_id,
            "adapter_used": self.adapter_path is not None
        }

        print("\nExecution Summary:")
        print(f"Total execution time: {total_time:.2f} seconds")
        print(f"Average time per task: {avg_time:.2f} seconds")
        print(f"Results saved to: {output_file}")

        summary_file = output_file.replace('.jsonl', '_summary.json')
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        return results
```
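A minimal sketch of running the pipeline end to end; the model id and file names are the defaults from the code above, and `adapter_path` is optional:

```
predictor = GPUPredictions(
    model_id="testmoto/gemma-2-llm2024-01",
    adapter_path=None,  # or the path to a PEFT adapter directory
)
results = predictor.run_inference(
    input_file="elyza-tasks-100-TV_0.jsonl",
    output_file="gpu_results.jsonl",
)
print(f"Completed {len(results)} tasks")
```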