import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the tokenizer and model; "your_model_repo" is a placeholder for the
# checkpoint you want to use.
model_name = "your_model_repo"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

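# Inference-only setup (an addition, not in the original snippet): eval mode
# disables dropout, so generation varies only with the sampling RNG.
model.eval()
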
# Make sure the standard special tokens are registered. Note that
# `special_tokens_map` is a read-only property and is never None, so the
# original `tokenizer.special_tokens_map = {...}` assignment would raise an
# error; `add_special_tokens` is the supported API and is a no-op for tokens
# the tokenizer already has.
num_added = tokenizer.add_special_tokens({
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "sep_token": "</s>",
    "pad_token": "<pad>",
    "cls_token": "<s>",
    "mask_token": "<mask>",
})
if num_added > 0:
    # Newly added tokens enlarge the vocabulary, so the embedding matrix must
    # be resized to match before generation.
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.save_pretrained(model_name)  # writes to a local dir named after model_name

# Tokenization settings for encoding prompts; only max_length is consumed below.
preprocessor_config = {
    "do_lower_case": False,
    "max_length": 128,
    "truncation": True,
    "padding": "max_length",
}

def generate_code(prompt, max_length=128, temperature=0.7, top_p=0.9):
    # Encode the prompt, truncating to the configured maximum length.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=preprocessor_config["max_length"],
    )

    # Nucleus (top-p) sampling; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

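# Illustrative helper (an addition, not from the original): draw several
# independent samples for one prompt by calling generate_code repeatedly.
def generate_candidates(prompt, n=3, **kwargs):
    return [generate_code(prompt, **kwargs) for _ in range(n)]
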
if __name__ == "__main__":
    prompt = "def quicksort(arr):"
    generated_code = generate_code(prompt)
    print("Generated Code:\n", generated_code)