---
base_model:
- google/gemma-2-9b
tags:
- text-generation-inference
- transformers
- unsloth
- gemma2
- trl
license: gemma
language:
- en
- ja
datasets:
- kanhatakeyama/wizardlm8x22b-logical-math-coding-sft_additional-ja
- kanhatakeyama/AutoMultiTurnByCalm3-22B
- kanhatakeyama/ramdom-to-fixed-multiturn-Calm3
---

# Model Card for bay-llm/gemma-9b-SFT-1020-large-16bit

This model is an instruction-tuned version of google/gemma-2-9b, fine-tuned for Japanese and English instruction following on the datasets listed in the metadata above.

# Usage

```python
!pip install vllm==0.6.4.post1 --force-reinstall

import json

import vllm  # NOTE: requires packaging==24.1, otherwise this import raises an error

print(vllm.__version__)

MODEL_NAME = "bay-llm/gemma-9b-SFT-1020-large-16bit"  # replace with the model you want to submit to the competition
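
# Initialize an in-process vLLM engine for offline batch inference on a single
# GPU: 95% of GPU memory is reserved and sequences are capped at 1024 tokens.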
llm = vllm.LLM(
    model=MODEL_NAME,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.95,
    trust_remote_code=True,
    max_model_len=1024,
)
tokenizer = llm.get_tokenizer()
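
# get_tokenizer() returns the Hugging Face tokenizer shipped with the model,
# so the chat template applied below matches the model's expected prompt format.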
# Load ELYZA-tasks-100-TV. Upload the file in advance.
# In the omnicampus development environment, drag and drop the task jsonl
# into the pane on the left before running this cell.
datasets = []
with open("../elyza-tasks-100-TV_0.jsonl", "r") as f:
    item = ""
    for line in f:
        line = line.strip()
        item += line
        if item.endswith("}"):
            datasets.append(json.loads(item))
            item = ""

print(datasets[0])
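
# Each task input becomes a single-turn chat message. apply_chat_template
# (tokenize=True by default) returns token IDs that llm.generate accepts directly.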
messages_list = [
    [{"role": "user", "content": datasets[i]["input"]}] for i in range(len(datasets))
]

prompts = [messages[0]["content"] for messages in messages_list]
prompt_token_ids = [
    tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    for messages in messages_list
]
sampling_params = vllm.SamplingParams(
    temperature=0.5,
    max_tokens=512,
)
outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
for prompt, response in zip(prompts, outputs):
    print("prompt:", prompt)
    print("output:", response.outputs[0].text.strip())
    print("-" * 80)
data = [{
    "task_id": i,
    "input": prompts[i],
    "output": outputs[i].outputs[0].text.strip()
} for i in range(len(datasets))]

file_path = 'submit.jsonl'
with open(file_path, 'w', encoding='utf-8') as file:
    for entry in data:
        json.dump(entry, file, ensure_ascii=False)
        file.write('\n')
```
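
To sanity-check what the model actually sees, the chat template can also be rendered as plain text instead of token IDs. A minimal sketch, reusing `tokenizer` and `messages_list` from the snippet above:

```python
# Render the first task's prompt as text (rather than token IDs) for inspection.
preview = tokenizer.apply_chat_template(
    messages_list[0],
    tokenize=False,
    add_generation_prompt=True,
)
print(preview)
```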
# Uploaded model

- **Developed by:** bay-llm
- **License:** gemma
- **Finetuned from model:** unsloth/gemma-2-9b-bnb-4bit

This gemma2 model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
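
For reference, a minimal sketch of the kind of Unsloth + TRL SFT setup described above. All hyperparameters, the LoRA configuration, and the dummy dataset are illustrative assumptions, not the settings actually used for this model, and the sketch assumes a TRL version whose `SFTTrainer` still accepts `dataset_text_field` and `max_seq_length` directly:

```python
from unsloth import FastLanguageModel
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Load the 4-bit base model that this model was fine-tuned from.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-9b-bnb-4bit",
    max_seq_length=1024,
    load_in_4bit=True,
)

# Attach LoRA adapters; rank, alpha, and target modules are illustrative.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

# Tiny dummy corpus so the sketch runs end to end; a real run would render the
# SFT datasets listed in the metadata into a chat-formatted "text" column.
dataset = Dataset.from_dict({
    "text": [
        "<start_of_turn>user\nHello<end_of_turn>\n"
        "<start_of_turn>model\nHi there!<end_of_turn>\n"
    ]
})

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=2e-4,
        output_dir="outputs",
    ),
)
trainer.train()
```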