
# Example
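The snippet below loads the tokenizer and model from the `punct_wrapper-related_words-overfit` revision (a `punct_wrapper-related_words-minevalloss` revision is also available), wraps a Korean prompt in the `@` delimiters the wrapper expects, and samples a continuation on a CUDA device.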

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer together with the special tokens used in training
tokenizer = AutoTokenizer.from_pretrained(
    "CheonggyeMountain-Sherpa/kogpt-trinity-punct-wrapper",
    revision="punct_wrapper-related_words-overfit",  # or punct_wrapper-related_words-minevalloss
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)
# Load the model and move it to the GPU
model = AutoModelForCausalLM.from_pretrained(
    "CheonggyeMountain-Sherpa/kogpt-trinity-punct-wrapper",
    revision="punct_wrapper-related_words-overfit",  # or punct_wrapper-related_words-minevalloss
    pad_token_id=tokenizer.eos_token_id,
).to(device="cuda")
model.eval()  # inference only: disables dropout

prompt = "석양이 보이는 경치"  # "A scenic view of the sunset"
wrapped_prompt = f"@{prompt}@"  # the punct wrapper expects prompts delimited by "@"
with torch.no_grad():
    tokens = tokenizer.encode(wrapped_prompt, return_tensors="pt").to(device="cuda")
    gen_tokens = model.generate(
        tokens,
        max_length=64,
        repetition_penalty=2.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=True,  # required for top_k/top_p to take effect
        top_k=16,
        top_p=0.8,
    )
    # Decode only the newly generated tokens, excluding the prompt
    generated = tokenizer.decode(gen_tokens[0][len(tokens[0]):])

print(generated)
```
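
For repeated generation, the wrap/encode/generate/decode steps can be folded into a small helper. The sketch below is illustrative only: it reuses the `model` and `tokenizer` objects created above, and the name `generate_wrapped` is not part of this repository.

```python
def generate_wrapped(prompt: str, max_length: int = 64) -> str:
    # Illustrative helper (not part of the repository): wraps the prompt in the
    # "@" delimiters the model expects and returns only the sampled continuation.
    tokens = tokenizer.encode(f"@{prompt}@", return_tensors="pt").to(device="cuda")
    with torch.no_grad():
        gen_tokens = model.generate(
            tokens,
            max_length=max_length,
            repetition_penalty=2.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            do_sample=True,
            top_k=16,
            top_p=0.8,
        )
    return tokenizer.decode(gen_tokens[0][len(tokens[0]):])

print(generate_wrapped("석양이 보이는 경치"))  # "A scenic view of the sunset"
```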