|
--- |
|
language: |
|
- ja |
|
base_model: |
|
- llm-jp/llm-jp-3-13b |
|
pipeline_tag: text-generation |
|
--- |
|
# **llm-jp-3-13b-finetune** |
|
|
|
## **Overview**

This model is a fine-tuned version of LLM-jp-3 13B, trained with LoRA (Low-Rank Adaptation).

It is tuned primarily for the **ELYZA-tasks-100-TV** task set and operates with the following instruction-response format.
|
|
|
**Input format**:

```
### 指示
<user instruction or question>
### 回答
```
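As a minimal sketch of filling in this template (the variable names and the sample instruction here are illustrative, not part of the evaluation data):

```python
# Build a prompt in the model's expected instruction-response format.
# `user_input` is a placeholder for an actual task instruction.
user_input = "日本の首都はどこですか?"

prompt = f"""### 指示
{user_input}
### 回答
"""

print(prompt)
```

The model is expected to continue the text after the `### 回答` marker.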
|
|
|
### Package versions used for execution
|
``` |
|
Package Version |
|
--------------------------------------- -------------------- |
|
absl-py 1.4.0 |
|
accelerate 1.2.1 |
|
aiohappyeyeballs 2.4.3 |
|
aiohttp 3.11.7 |
|
aiosignal 1.3.1 |
|
annotated-types 0.7.0 |
|
antlr4-python3-runtime 4.9.3 |
|
anyio 4.6.2.post1 |
|
apex 0.1 |
|
argon2-cffi 21.3.0 |
|
argon2-cffi-bindings 21.2.0 |
|
asttokens 2.2.1 |
|
astunparse 1.6.3 |
|
async-timeout 4.0.2 |
|
attrs 24.2.0 |
|
audioread 3.0.0 |
|
backcall 0.2.0 |
|
beautifulsoup4 4.12.3 |
|
bitsandbytes 0.45.0 |
|
bleach 6.0.0 |
|
blis 0.7.9 |
|
cachetools 5.3.1 |
|
catalogue 2.0.8 |
|
certifi 2023.5.7 |
|
cffi 1.15.1 |
|
charset-normalizer 3.1.0 |
|
click 8.1.7 |
|
cloudpickle 2.2.1 |
|
cmake 3.26.4 |
|
comm 0.1.3 |
|
confection 0.1.0 |
|
contourpy 1.1.0 |
|
cubinlinker 0.3.0+2.g155b525 |
|
cuda-python 12.1.0rc5+1.g8659927 |
|
cudf 23.6.0 |
|
cugraph 23.6.0 |
|
cugraph-dgl 23.6.0 |
|
cugraph-service-client 23.6.0 |
|
cugraph-service-server 23.6.0 |
|
cuml 23.6.0 |
|
cupy-cuda12x 12.1.0 |
|
cut-cross-entropy 24.11.4 |
|
cycler 0.11.0 |
|
cymem 2.0.7 |
|
Cython 0.29.36 |
|
dask 2023.3.2 |
|
dask-cuda 23.6.0 |
|
dask-cudf 23.6.0 |
|
dataclasses-json 0.6.7 |
|
datasets 3.2.0 |
|
debugpy 1.6.7 |
|
decorator 5.1.1 |
|
deepspeed 0.15.4 |
|
defusedxml 0.7.1 |
|
Deprecated 1.2.15 |
|
dill 0.3.8 |
|
dirtyjson 1.0.8 |
|
distributed 2023.3.2.1 |
|
distro 1.9.0 |
|
dm-tree 0.1.8 |
|
docker-pycreds 0.4.0 |
|
docstring_parser 0.16 |
|
einops 0.6.1 |
|
exceptiongroup 1.1.2 |
|
execnet 1.9.0 |
|
executing 1.2.0 |
|
expecttest 0.1.3 |
|
fastjsonschema 2.17.1 |
|
fastrlock 0.8.1 |
|
filelock 3.12.2 |
|
filetype 1.2.0 |
|
fire 0.7.0 |
|
flash-attn 1.0.7 |
|
fonttools 4.40.0 |
|
frozenlist 1.3.3 |
|
fsspec 2023.6.0 |
|
gast 0.5.4 |
|
gitdb 4.0.11 |
|
GitPython 3.1.43 |
|
google-auth 2.21.0 |
|
google-auth-oauthlib 0.4.6 |
|
graphsurgeon 0.4.6 |
|
greenlet 3.1.1 |
|
grpcio 1.56.0 |
|
h11 0.14.0 |
|
h5py 3.11.0 |
|
hdf5plugin 4.4.0 |
|
hf_transfer 0.1.8 |
|
hjson 3.1.0 |
|
httpcore 1.0.7 |
|
httpx 0.27.2 |
|
huggingface-hub 0.26.2 |
|
hydra-core 1.3.2 |
|
hypothesis 5.35.1 |
|
idna 3.4 |
|
imageio 2.34.1 |
|
importlib-metadata 6.7.0 |
|
iniconfig 2.0.0 |
|
intel-openmp 2021.4.0 |
|
ipykernel 6.24.0 |
|
ipython 8.14.0 |
|
ipython-genutils 0.2.0 |
|
jedi 0.18.2 |
|
Jinja2 3.1.2 |
|
jiter 0.7.1 |
|
joblib 1.3.0 |
|
json5 0.9.14 |
|
jsonpatch 1.33 |
|
jsonpointer 3.0.0 |
|
jsonschema 4.18.0 |
|
jsonschema-specifications 2023.6.1 |
|
jupyter_client 8.3.0 |
|
jupyter_core 5.3.1 |
|
jupyter-tensorboard 0.2.0 |
|
jupyterlab 2.2.9 |
|
jupyterlab-pygments 0.2.2 |
|
jupyterlab-server 1.2.0 |
|
jupytext 1.14.7 |
|
kiwisolver 1.4.4 |
|
langchain 0.3.8 |
|
langchain-core 0.3.21 |
|
langchain-text-splitters 0.3.2 |
|
langcodes 3.3.0 |
|
langsmith 0.1.145 |
|
lazy_loader 0.4 |
|
librosa 0.9.2 |
|
lightning-utilities 0.11.9 |
|
llama-cloud 0.1.5 |
|
llama-index 0.12.1 |
|
llama-index-agent-openai 0.4.0 |
|
llama-index-cli 0.4.0 |
|
llama-index-core 0.12.1 |
|
llama-index-embeddings-openai 0.3.0 |
|
llama-index-indices-managed-llama-cloud 0.6.2 |
|
llama-index-legacy 0.9.48.post4 |
|
llama-index-llms-openai 0.3.1 |
|
llama-index-multi-modal-llms-openai 0.3.0 |
|
llama-index-program-openai 0.3.0 |
|
llama-index-question-gen-openai 0.3.0 |
|
llama-index-readers-file 0.4.0 |
|
llama-index-readers-llama-parse 0.4.0 |
|
llama-parse 0.5.15 |
|
llvmlite 0.40.1 |
|
locket 1.0.0 |
|
Markdown 3.4.3 |
|
markdown-it-py 3.0.0 |
|
MarkupSafe 2.1.3 |
|
marshmallow 3.23.1 |
|
matplotlib 3.7.2 |
|
matplotlib-inline 0.1.6 |
|
mdit-py-plugins 0.4.0 |
|
mdurl 0.1.2 |
|
mistune 3.0.1 |
|
mkl 2021.1.1 |
|
mkl-devel 2021.1.1 |
|
mkl-include 2021.1.1 |
|
mne 1.6.0 |
|
mock 5.0.2 |
|
mpmath 1.3.0 |
|
msgpack 1.0.5 |
|
multidict 6.0.4 |
|
multiprocess 0.70.16 |
|
murmurhash 1.0.9 |
|
mypy-extensions 1.0.0 |
|
nbclient 0.8.0 |
|
nbconvert 7.6.0 |
|
nbformat 5.9.0 |
|
nest-asyncio 1.6.0 |
|
networkx 3.4.2 |
|
ninja 1.11.1 |
|
nltk 3.9.1 |
|
notebook 6.4.10 |
|
numba 0.57.1+1.gf851d279c |
|
numpy 1.26.4 |
|
nvidia-cublas-cu12 12.4.5.8 |
|
nvidia-cuda-cupti-cu12 12.4.127 |
|
nvidia-cuda-nvrtc-cu12 12.4.127 |
|
nvidia-cuda-runtime-cu12 12.4.127 |
|
nvidia-cudnn-cu12 9.1.0.70 |
|
nvidia-cufft-cu12 11.2.1.3 |
|
nvidia-curand-cu12 10.3.5.147 |
|
nvidia-cusolver-cu12 11.6.1.9 |
|
nvidia-cusparse-cu12 12.3.1.170 |
|
nvidia-dali-cuda120 1.27.0 |
|
nvidia-nccl-cu12 2.21.5 |
|
nvidia-nvjitlink-cu12 12.4.127 |
|
nvidia-nvtx-cu12 12.4.127 |
|
nvidia-pyindex 1.0.9 |
|
nvtx 0.2.5 |
|
oauthlib 3.2.2 |
|
omegaconf 2.3.0 |
|
onnx 1.14.0 |
|
openai 1.55.0 |
|
opencv 4.7.0 |
|
orjson 3.10.12 |
|
outcome 1.3.0.post0 |
|
packaging 24.2 |
|
pandas 1.5.2 |
|
pandocfilters 1.5.0 |
|
parso 0.8.3 |
|
partd 1.4.0 |
|
pathy 0.10.2 |
|
peft 0.14.0 |
|
pexpect 4.8.0 |
|
pickleshare 0.7.5 |
|
Pillow 9.2.0 |
|
pip 23.1.2 |
|
platformdirs 3.8.0 |
|
pluggy 1.2.0 |
|
ply 3.11 |
|
polygraphy 0.47.1 |
|
pooch 1.7.0 |
|
preshed 3.0.8 |
|
prettytable 3.8.0 |
|
prometheus-client 0.17.0 |
|
prompt-toolkit 3.0.39 |
|
propcache 0.2.0 |
|
protobuf 3.20.3 |
|
psutil 5.9.4 |
|
ptxcompiler 0.8.1+1.gb323413 |
|
ptyprocess 0.7.0 |
|
pure-eval 0.2.2 |
|
py-cpuinfo 9.0.0 |
|
pyarrow 18.0.0 |
|
pyasn1 0.5.0 |
|
pyasn1-modules 0.3.0 |
|
pybind11 2.10.4 |
|
pycocotools 2.0+nv0.7.3 |
|
pycparser 2.21 |
|
pydantic 2.9.2 |
|
pydantic_core 2.23.4 |
|
Pygments 2.15.1 |
|
pylibcugraph 23.6.0 |
|
pylibcugraphops 23.6.0 |
|
pylibraft 23.6.0 |
|
pynvml 11.4.1 |
|
pyparsing 3.0.9 |
|
pypdf 5.1.0 |
|
PySocks 1.7.1 |
|
pytest 7.4.0 |
|
pytest-flakefinder 1.1.0 |
|
pytest-rerunfailures 12.0 |
|
pytest-shard 0.1.2 |
|
pytest-xdist 3.3.1 |
|
python-dateutil 2.8.2 |
|
python-hostlist 1.23.0 |
|
pytorch-quantization 2.1.2 |
|
pytz 2023.3 |
|
PyYAML 6.0.2 |
|
pyzmq 25.1.0 |
|
raft-dask 23.6.0 |
|
referencing 0.29.1 |
|
regex 2023.6.3 |
|
requests 2.32.3 |
|
requests-oauthlib 1.3.1 |
|
requests-toolbelt 1.0.0 |
|
resampy 0.4.2 |
|
rich 13.9.4 |
|
rmm 23.6.0 |
|
rpds-py 0.8.8 |
|
rsa 4.9 |
|
safetensors 0.4.5 |
|
scikit-learn 1.2.0 |
|
scipy 1.11.0 |
|
selenium 4.26.1 |
|
Send2Trash 1.8.2 |
|
sentence-transformers 3.3.1 |
|
sentencepiece 0.2.0 |
|
sentry-sdk 2.19.0 |
|
setproctitle 1.3.4 |
|
setuptools 68.0.0 |
|
shtab 1.7.1 |
|
six 1.16.0 |
|
smart-open 6.3.0 |
|
smmap 5.0.1 |
|
sniffio 1.3.1 |
|
sortedcontainers 2.4.0 |
|
soundfile 0.12.1 |
|
soupsieve 2.4.1 |
|
spacy 3.5.4 |
|
spacy-legacy 3.0.12 |
|
spacy-loggers 1.0.4 |
|
sphinx-glpi-theme 0.3 |
|
SQLAlchemy 2.0.36 |
|
srsly 2.4.6 |
|
stack-data 0.6.2 |
|
striprtf 0.0.26 |
|
sympy 1.13.1 |
|
tabulate 0.9.0 |
|
tbb 2021.9.0 |
|
tblib 2.0.0 |
|
tenacity 8.5.0 |
|
tensorboard 2.9.0 |
|
tensorboard-data-server 0.6.1 |
|
tensorboard-plugin-wit 1.8.1 |
|
tensorrt 8.6.1 |
|
termcolor 2.5.0 |
|
terminado 0.17.1 |
|
thinc 8.1.10 |
|
threadpoolctl 3.1.0 |
|
thriftpy2 0.4.16 |
|
tiktoken 0.8.0 |
|
tinycss2 1.2.1 |
|
tokenizers 0.21.0 |
|
toml 0.10.2 |
|
tomli 2.0.1 |
|
toolz 0.12.0 |
|
torch 2.5.1 |
|
torch-tensorrt 1.5.0.dev0 |
|
torchaudio 2.5.1 |
|
torchdata 0.7.0a0 |
|
torchmetrics 1.6.0 |
|
torchtext 0.16.0a0 |
|
torchvision 0.20.1 |
|
tornado 6.3.2 |
|
tqdm 4.67.1 |
|
traitlets 5.9.0 |
|
transformer-engine 0.10.0+96ed6fc |
|
transformers 4.47.0 |
|
treelite 3.2.0 |
|
treelite-runtime 3.2.0 |
|
trio 0.27.0 |
|
trio-websocket 0.11.1 |
|
triton 3.1.0 |
|
trl 0.12.1 |
|
typeguard 4.4.1 |
|
typer 0.9.0 |
|
types-dataclasses 0.6.6 |
|
typing_extensions 4.12.2 |
|
typing-inspect 0.9.0 |
|
tyro 0.9.2 |
|
ucx-py 0.32.0 |
|
uff 0.6.9 |
|
unsloth 2024.11.9 |
|
unsloth_zoo 2024.11.7 |
|
urllib3 1.26.16 |
|
wandb 0.18.7 |
|
wasabi 1.1.2 |
|
wcwidth 0.2.6 |
|
webencodings 0.5.1 |
|
websocket-client 1.8.0 |
|
Werkzeug 2.3.6 |
|
wheel 0.45.1 |
|
wrapt 1.17.0 |
|
wsproto 1.2.0 |
|
xdoctest 1.0.2 |
|
xformers 0.0.28.post3 |
|
xgboost 1.7.5 |
|
xxhash 3.5.0 |
|
yarl 1.18.0 |
|
zict 3.0.0 |
|
zipp 3.15.0 |
|
``` |
|
|
|
Inference was run with the following code:

```python
import json
import re

import torch
from peft import PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

# Base model and the fine-tuned LoRA adapter.
model_id = "llm-jp/llm-jp-3-13b"
adapter_id = "kevineen/llm-jp-3-13b-finetune"  # Hugging Face repository ID.

# QLoRA configuration (4-bit NF4 quantization).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Apply the LoRA adapter to the base model.
model = PeftModel.from_pretrained(model, adapter_id)

# Load the dataset to run inference on.
# In the omnicampus development environment, drag and drop the task JSONL
# into the left pane before running.
# Prepare the target dataset (elyza-tasks-100-TV_0.jsonl).
datasets = []
with open("./elyza-tasks-100-TV_0.jsonl", "r") as f:
    item = ""
    for line in f:
        line = line.strip()
        item += line
        if item.endswith("}"):  # A complete JSON object has been accumulated.
            datasets.append(json.loads(item))
            item = ""

# Run inference with llm-jp on each task.
results = []
for data in tqdm(datasets):
    input_text = data["input"]

    prompt = f"""### 指示
{input_text}
### 回答
"""

    tokenized_input = tokenizer.encode(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)
    attention_mask = torch.ones_like(tokenized_input)

    with torch.no_grad():
        outputs = model.generate(
            tokenized_input,
            attention_mask=attention_mask,
            max_new_tokens=1024,
            do_sample=False,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )[0]

    # Decode only the newly generated tokens, skipping the prompt.
    output = tokenizer.decode(
        outputs[tokenized_input.size(1):], skip_special_tokens=True
    )

    results.append({"task_id": data["task_id"], "input": input_text, "output": output})

# Write the results as JSONL, named after the adapter ID.
jsonl_id = re.sub(".*/", "", adapter_id)
with open(f"./{jsonl_id}-outputs.jsonl", "w", encoding="utf-8") as f:
    for result in results:
        json.dump(result, f, ensure_ascii=False)  # ensure_ascii=False keeps non-ASCII characters readable.
        f.write("\n")
```
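The output file contains one JSON object per line. As a small round-trip sketch of that format (the filename and sample record here are illustrative, not the actual evaluation output), each line can be written and read back with the standard `json` module:

```python
import json

# Write one sample record in the same one-object-per-line JSONL format,
# then read it back and confirm it parses. The path is illustrative.
sample = {"task_id": 0, "input": "例の指示", "output": "例の回答"}
path = "./sample-outputs.jsonl"

with open(path, "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False)
    f.write("\n")

with open(path, "r", encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f if line.strip()]

print(loaded[0]["task_id"], loaded[0]["output"])
```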