# NOTE: the original page header ("Spaces: Sleeping") is a Hugging Face
# Spaces status artifact from scraping, not part of the program.
import json
from pathlib import Path

import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
class BankingModelTrainer:
    """Fine-tune a Llama-2 chat model on banking question/answer pairs.

    Loads the base model 8-bit-quantized across available devices, turns a
    CSV/JSON file of ``question``/``answer`` rows into a tokenized dataset
    formatted with the Llama-2 ``[INST] ... [/INST]`` template, runs a
    causal-LM ``Trainer`` loop, and generates answers for new prompts.
    """

    def __init__(
        self,
        base_model_name="meta-llama/Llama-2-13b-chat-hf",
        output_dir="./fine_tuned_model",
        max_length=512
    ):
        """Load the tokenizer and the quantized base model.

        Args:
            base_model_name: Hub id or local path of the base checkpoint.
            output_dir: Directory for checkpoints and the final model.
            max_length: Token budget used for both training truncation and
                generation.
        """
        self.base_model_name = base_model_name
        self.output_dir = Path(output_dir)
        self.max_length = max_length
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Llama-2 loading configuration: 8-bit weights plus a per-GPU memory
        # cap so the 13B model fits on a single ~10GB device.
        # NOTE(review): torch_dtype is effectively ignored when load_in_8bit
        # is set, and a fully 8-bit model cannot be trained directly without
        # a PEFT adapter (e.g. LoRA) — confirm the intended training setup.
        model_config = {
            "device_map": "auto",
            "torch_dtype": torch.bfloat16,
            "low_cpu_mem_usage": True,
            "max_memory": {0: "10GB"},
            "load_in_8bit": True
        }

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Fix: the Llama-2 tokenizer ships without a pad token, so any call
        # with padding=True (see prepare_data) raises. Reuse EOS as padding,
        # which is the standard workaround for causal-LM fine-tuning.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            **model_config
        )

    def prepare_data(self, data_path):
        """Read a Q&A file and return a tokenized ``datasets.Dataset``.

        Args:
            data_path: Path to a ``.csv`` or ``.json`` file containing
                ``question`` and ``answer`` fields.

        Returns:
            A tokenized dataset ready for the causal-LM collator.

        Raises:
            ValueError: If the file extension is unsupported or the required
                columns are missing.
        """
        # Load the raw records from disk.
        if data_path.endswith('.csv'):
            df = pd.read_csv(data_path)
        elif data_path.endswith('.json'):
            with open(data_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            df = pd.DataFrame(data)
        else:
            raise ValueError("فرمت فایل باید CSV یا JSON باشد")

        # Fail fast with a clear message instead of a KeyError deep inside
        # dataset.map().
        missing = {'question', 'answer'} - set(df.columns)
        if missing:
            raise ValueError(
                f"data file is missing required columns: {sorted(missing)}"
            )

        def prepare_examples(examples):
            # Render each pair with the Llama-2 single-turn chat template.
            conversations = []
            for q, a in zip(examples['question'], examples['answer']):
                conv = f"[INST] {q} [/INST] {a}"
                conversations.append(conv)
            # Tokenize the whole batch; padding is valid because __init__
            # guarantees a pad token exists.
            encodings = self.tokenizer(
                conversations,
                truncation=True,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            return encodings

        dataset = Dataset.from_pandas(df)
        tokenized_dataset = dataset.map(
            prepare_examples,
            batched=True,
            remove_columns=dataset.column_names
        )
        return tokenized_dataset

    def train(self, dataset, epochs=3, batch_size=4):
        """Run the fine-tuning loop and save model + tokenizer.

        Args:
            dataset: Tokenized dataset from :meth:`prepare_data`.
            epochs: Number of passes over the data.
            batch_size: Per-device train batch size (effective batch is
                ``batch_size * 4`` via gradient accumulation).
        """
        training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            save_steps=500,
            logging_steps=100,
            learning_rate=2e-5,  # conservative LR for Llama-2 fine-tuning
            warmup_steps=100,
            # Fix: the model is loaded with bfloat16 weights; fp16=True makes
            # the GradScaler try to unscale non-fp16 gradients and crash.
            bf16=True,
            save_total_limit=2,
            logging_dir=str(self.output_dir / "logs"),
            gradient_checkpointing=True  # trade compute for activation memory
        )

        # Fix: the KV cache is incompatible with gradient checkpointing;
        # disable it for the duration of training.
        self.model.config.use_cache = False

        # mlm=False -> standard causal-LM objective (labels = shifted inputs).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator
        )
        trainer.train()

        # Persist both artifacts so the output dir is self-contained.
        self.model.save_pretrained(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)

    def generate_response(self, prompt):
        """Generate an answer for *prompt* using the Llama-2 chat template.

        Args:
            prompt: User question in plain text.

        Returns:
            The generated answer with the prompt tokens stripped.
        """
        formatted_prompt = f"[INST] {prompt} [/INST]"
        # Fix: with device_map="auto" the embeddings may not live on
        # self.device; self.model.device is the dispatcher's entry device.
        inputs = self.tokenizer.encode(
            formatted_prompt,
            return_tensors="pt"
        ).to(self.model.device)

        # NOTE(review): max_length bounds prompt + completion, so a prompt
        # near the budget leaves little room to generate — consider
        # max_new_tokens if that becomes a problem.
        outputs = self.model.generate(
            inputs,
            max_length=self.max_length,
            num_return_sequences=1,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.2  # discourage verbatim loops
        )

        # Fix: decode only the newly generated tokens. String replacement of
        # the prompt is unreliable because decoding does not round-trip the
        # template text exactly.
        prompt_length = inputs.shape[-1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        ).strip()
        return response
def main():
    """End-to-end smoke run: fine-tune on the banking Q&A file, then answer
    one sample question."""
    model_trainer = BankingModelTrainer()
    train_dataset = model_trainer.prepare_data("banking_qa.json")
    model_trainer.train(train_dataset)
    answer = model_trainer.generate_response("شرایط وام مسکن چیست؟")
    print(answer)


if __name__ == "__main__":
    main()