t5-base-korean-chit-chat

This model is paust/pko-t5-base fine-tuned on the AIHUB "한국어 SNS" (Korean SNS) dataset. Given a conversation of the kind found on social media, it predicts the next turn of the conversation.

Usage


from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import nltk

# sent_tokenize needs the Punkt data (newer NLTK releases may ask for 'punkt_tab' instead)
nltk.download('punkt')


model_dir = "lcw99/t5-base-korean-chit-chat"

max_input_length = 1024

# Prompt in English: "A: Shall we go shopping? B: Sure, sounds good. A: When shall we go? B:"
text = """
A: 쇼핑하러 갈까? B: 응 좋아. A: 언제 갈까? B:
"""

inputs = [text]

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
# Beam search with sampling; return three candidate replies
output = model.generate(**inputs, num_beams=3, do_sample=True, min_length=20, max_length=500, num_return_sequences=3)
for i in range(3):
    print("---", i)
    decoded_output = tokenizer.decode(output[i], skip_special_tokens=True)
    # Split the decoded reply into sentences
    sentences = nltk.sent_tokenize(decoded_output)
    print(sentences)

import torch

chat_history = []
# Chat for up to 100 turns; only the 5 most recent lines are kept as context
for step in range(100):
    print("")
    user_input = input(">> User: ")
    chat_history.append("A: " + user_input)
    while len(chat_history) > 5:
        chat_history.pop(0)
    hist = ""
    for chat in chat_history:
        hist += "\n" + chat
    hist += "\nB: "
    new_user_input_ids = tokenizer.encode(hist, return_tensors='pt')

    bot_input_ids = new_user_input_ids

    # Sample a response of at most 200 tokens
    chat_history_ids = model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        # Optional sampling controls:
        #top_k=100,
        #top_p=0.7,
        #temperature = 0.1
    )

    # "#@이름#" ("name") looks like an anonymized-name placeholder from the training data; mask it
    bot_text = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True).replace("#@이름#", "OOO")
    bot_text = bot_text.replace("\n", " / ")
    chat_history.append("B: " + bot_text)

    # print the bot's reply
    print("Bot: {}".format(bot_text))
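
The chat loop above builds its prompt by keeping the five most recent lines, joining them with newlines, and ending with an open "B: " turn. The sketch below isolates that string handling so it can be checked without downloading the model. The names build_prompt, clean_reply, and NAME_PLACEHOLDER are hypothetical helpers introduced here for illustration; they are not part of the model or of transformers, and the placeholder is assumed to be the string "#@이름#".

```python
from typing import List

# Assumption: the anonymized-name token that the usage code masks with "OOO".
NAME_PLACEHOLDER = "#@이름#"


def build_prompt(history: List[str], max_lines: int = 5) -> str:
    """Join the most recent lines and end with an open "B: " turn."""
    # Guard max_lines <= 0, because history[-0:] would return the whole list.
    recent = history[-max_lines:] if max_lines > 0 else []
    return "".join("\n" + line for line in recent) + "\nB: "


def clean_reply(text: str) -> str:
    """Mask the name placeholder and flatten line breaks, as the loop does."""
    return text.replace(NAME_PLACEHOLDER, "OOO").replace("\n", " / ")


history = ["A: hi", "B: hello", "A: how are you?"]
print(repr(build_prompt(history)))
print(clean_reply("line one\nline two"))
```

The string returned by build_prompt could be passed to tokenizer.encode in place of the inline hist construction, and clean_reply applied to the decoded output.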

Framework versions

  • Transformers 4.22.1
  • TensorFlow 2.10.0
  • Datasets 2.5.1
  • Tokenizers 0.12.1
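
To compare a local environment with the versions listed above, a minimal sketch using the standard library's importlib.metadata might look like this. The distribution names in EXPECTED are assumed to match the PyPI package names, and installed_versions is a hypothetical helper, not part of any of these libraries.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this card; the keys are assumed PyPI distribution names.
EXPECTED = {
    "transformers": "4.22.1",
    "tensorflow": "2.10.0",
    "datasets": "2.5.1",
    "tokenizers": "0.12.1",
}


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, have in installed_versions(EXPECTED).items():
    print(f"{name}: installed={have}, card={EXPECTED[name]}")
```

Note that the usage code requests PyTorch tensors (return_tensors="pt"), so PyTorch must also be installed even though it is not listed above.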

Model size: 276M params (Safetensors, tensor type F32)