This example uses the European AI Act regulation text as training data in three languages: English, Finnish, and Swedish. The dataset contains 9,175 training examples and 2,456 evaluation examples.
Python libraries needed:
pip install -U transformers
pip install torch torchvision torchaudio
pip install 'accelerate>=0.26.0'
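Before training, it is worth checking that a GPU is visible and whether it supports bfloat16, since that determines the fp16/bf16 flags used in the training arguments below. This quick sanity check is not part of the original card:

import torch

# Confirm a CUDA device is available and whether it supports bfloat16.
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
    print("bfloat16 supported:", torch.cuda.is_bf16_supported())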
The training arguments used are as follows:
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported  # Unsloth helper: True if the GPU supports bfloat16

training_args = TrainingArguments(
    per_device_train_batch_size=32,
    gradient_accumulation_steps=32,
    warmup_steps=20,
    max_steps=400,
    learning_rate=1.5e-5,
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    seed=3407,
    output_dir=output_dir,  # output directory defined earlier, e.g. "outputs"
    report_to="none",
    eval_strategy="steps",
    eval_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
)
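The full training script is not shown here. The following is a minimal sketch of how these arguments could be passed to TRL's SFTTrainer (the card notes the model was trained with Unsloth and TRL); model, tokenizer, train_dataset, and eval_dataset are assumed to have been prepared beforehand, and newer TRL releases rename some arguments (e.g. tokenizer becomes processing_class):

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,                  # Gemma 2 model prepared for fine-tuning (e.g. with Unsloth)
    tokenizer=tokenizer,
    train_dataset=train_dataset,  # the 9,175 training examples (assumed variable name)
    eval_dataset=eval_dataset,    # the 2,456 evaluation examples (assumed variable name)
    args=training_args,
)
trainer.train()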
Prediction is done with the standard Transformers Gemma setup:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "mlconvexai/gemma-2-2b-it-finetuned-EU-Act-v2"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=dtype,
)

chat = [
    {"role": "user", "content": "Mikä on EU:n tekoälyasetus?"},  # "What is the EU AI Act?"
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

outputs = model.generate(
    input_ids=inputs.to(model.device),
    max_new_tokens=1024,
    repetition_penalty=1.1,
    no_repeat_ngram_size=4,
)
print(tokenizer.decode(outputs[0]))
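Note that the decoded output echoes the prompt. To print only the generated answer, decode just the newly generated tokens, for example:

response = tokenizer.decode(
    outputs[0][inputs.shape[-1]:],  # skip the prompt tokens
    skip_special_tokens=True,
)
print(response)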
More detailed information about fine-tuning can be found on Medium.
Uploaded model
- Developed by: mlconvexai
- License: Gemma
- Finetuned from model: google/gemma-2-2b-it
This Gemma 2 model was trained 2x faster with Unsloth and Hugging Face's TRL library.