gpt-tools / generate.py
AliMc2021's picture
add application files
d9b8e9c
raw
history blame
3 kB
import torch
seed = 0
def generate_text(model_data, input_text, max_new_token):
"""
Generate text using the given model and tokenizer.
"""
if "pipeline" in model_data:
# اگر مدل از pipeline پشتیبانی می‌کند
model_pipeline = model_data["pipeline"]
generated_text = model_pipeline(
input_text,
max_length=max_new_token,
do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
truncation=True # فعال کردن truncation
)[0]["generated_text"]
return generated_text
else:
# روش قدیمی برای مدل‌هایی که از pipeline پشتیبانی نمی‌کنند
model = model_data["model"]
tokenizer = model_data["tokenizer"]
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
encodings = tokenizer(
input_text,
return_tensors="pt",
padding=True,
truncation=True, # فعال کردن truncation
max_length=512
)
input_ids = encodings.input_ids
attention_mask = encodings.attention_mask
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_token,
do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
pad_token_id=tokenizer.eos_token_id,
repetition_penalty=1.2,
no_repeat_ngram_size=3,
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_code(model_data, prompt, max_new_tokens):
"""
Generate code based on the provided prompt using a code-specific model.
"""
model = model_data["model"]
tokenizer = model_data["tokenizer"]
# تنظیم seed برای خروجی ثابت
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# توکنایز کردن ورودی
input_ids = tokenizer.encode(prompt, return_tensors="pt")
# ایجاد attention mask
attention_mask = torch.ones(input_ids.shape, device=input_ids.device) # ایجاد یک ماسک توجه برای ورودی‌ها
# تولید کد
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask, # ارسال attention mask
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.eos_token_id, # تنظیم شناسه توکن پایان به عنوان پرکننده
repetition_penalty=1.2, # جلوگیری از تکرار
no_repeat_ngram_size=3, # جلوگیری از تکرار n-gram
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)