# PyCodeT5 / generation_fast.py
# S-Dreamer's picture
# Update generation_fast.py
# c59a42f verified
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load model and tokenizer
model_name = "your_model_repo"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Ensure special tokens and preprocessing settings are applied.
# `tokenizer.special_tokens_map` is a read-only property that returns a dict of
# the special tokens currently set (it is never None), so the fallbacks below
# are compared key by key and registered through `add_special_tokens` rather
# than by assigning to the property.
_fallback_special_tokens = {
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "sep_token": "</s>",
    "pad_token": "<pad>",
    "cls_token": "<s>",
    "mask_token": "<mask>",
}
_missing_special_tokens = {
    name: token
    for name, token in _fallback_special_tokens.items()
    if name not in tokenizer.special_tokens_map
}
if _missing_special_tokens:
    tokenizer.add_special_tokens(_missing_special_tokens)
    # Grow the embedding matrix only when the tokenizer now holds more ids than
    # the model has rows for; an existing matrix is never shrunk.
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))
# NOTE(review): the tokenizer is deliberately not written back with
# `tokenizer.save_pretrained(model_name)`. That call would create a local
# directory named after the repo id holding only tokenizer files; presumably a
# later `from_pretrained(model_name)` would then resolve to that directory
# instead of the hub repo -- confirm before re-adding a save step.

# Tokenization settings; `generate_code` reads "max_length" from this dict.
preprocessor_config = {
    "do_lower_case": False,
    "max_length": 128,
    "truncation": True,
    "padding": "max_length",
}
# Define a function for text generation
def generate_code(prompt, max_length=128, temperature=0.7, top_p=0.9):
    """Sample a completion for *prompt* from the module-level seq2seq model.

    The prompt is tokenized with the module-level ``tokenizer``, truncated to
    ``preprocessor_config["max_length"]`` tokens, and passed to
    ``model.generate`` with sampling enabled.

    Args:
        prompt: Text to condition the generation on.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        top_p: Forwarded to ``model.generate`` as ``top_p``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=preprocessor_config["max_length"],
    )
    sampling_options = {
        "max_length": max_length,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }
    # Inference only: no gradients are needed while generating.
    with torch.no_grad():
        sequences = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            **sampling_options,
        )
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
# Example usage
if __name__ == "__main__":
    # Sample a completion for a function header and print it.
    prompt = "def quicksort(arr):"
    print("Generated Code:\n", generate_code(prompt))