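# Header captured from the repository file view (committer, commit message,
# commit hash, viewer links, file size):
# kevin-yang
# Add application file
# 4f76eaa
# raw | history blame
# 2.11 kB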
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel, PreTrainedTokenizerFast
import numpy as np
# Hugging Face Hub id of the fine-tuned checkpoint; the tokenizer is loaded
# from the same repository.
MODEL_NAME = "jason9693/soongsil-univ-gpt-v1"


def _cache_across_reruns(func):
    """Memoize *func* so its result is reused across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the checkpoint would be loaded again on each rerun.
    ``st.cache_resource`` is used when the installed Streamlit provides it;
    older releases only have ``st.cache``, where ``allow_output_mutation=True``
    skips hashing the returned (large, mutable) model objects.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_across_reruns
def _load_model_and_tokenizer():
    """Load the fine-tuned GPT-2 LM head model and its fast tokenizer."""
    loaded_model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
    loaded_tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
    return loaded_model, loaded_tokenizer


# Module-level names kept so the rest of the script can keep using them.
model, tokenizer = _load_model_and_tokenizer()
# Board label shown in the "Category" selectbox (insertion order is the
# display order) -> special token prepended to the user's seed text before
# generation. Labels: Soongsil "Everytime" board, "Everyone's dating",
# and the college students' chat room.
category_map = {
    "숭실대 에타": "<unused5>",
    "모두의 연애": "<unused3>",
    "대학생 잡담방": "<unused4>",
}
st.markdown("""# University Community KoGPT2 : ์ˆญ์‹ค๋Œ€ ์—๋ธŒ๋ฆฌํƒ€์ž„๋ด‡
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1p6DIxsesi3eJNPwFwvMw0MeM5LkSGoPW?usp=sharing) [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jason9693/UCK-GPT2/issues) ![GitHub](https://img.shields.io/github/license/jason9693/UCK-GPT2)
## ๋Œ€ํ•™ ์ปค๋ฎค๋‹ˆํ‹ฐ ๊ฒŒ์‹œ๊ธ€ ์ƒ์„ฑ๊ธฐ
SKT-AI์—์„œ ๊ณต๊ฐœํ•œ [KoGPT2](https://github.com/SKT-AI/KoGPT2) ๋ชจ๋ธ์„ ํŒŒ์ธํŠœ๋‹ํ•˜์—ฌ ๋Œ€ํ•™ ์ปค๋ฎค๋‹ˆํ‹ฐ ๊ฒŒ์‹œ๊ธ€์„ ์ƒ์„ฑํ•˜๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. ์ด ์—๋ธŒ๋ฆฌํƒ€์ž„, ์บ ํผ์Šคํ”ฝ ๋ฐ์ดํ„ฐ 22๋งŒ๊ฐœ๋ฅผ ์ด์šฉํ•ด์„œ ํ•™์Šต์„ ์ง„ํ–‰ํ–ˆ์œผ๋ฉฐ, ํ•™์Šต์—๋Š” ๋Œ€๋žต **3์ผ**์ •๋„ ์†Œ์š”๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
* [GPT ๋…ผ๋ฌธ ๋ฆฌ๋ทฐ ๋งํฌ](https://www.notion.so/Improve-Language-Understanding-by-Generative-Pre-Training-GPT-afb4b5ef6e984961ac022b700c152b6b)
## ์‹œ์—ฐํ•˜๊ธฐ
""")
# Input widgets: free-text seed (default "조만식 기념관", the Cho Man-sik
# Memorial Hall), the board whose token is prepended to the seed, and the
# button that triggers generation. The "generated result" heading is always
# rendered; the output (if any) appears below it.
seed = st.text_input("Seed", "조만식 기념관")
category = st.selectbox("Category", list(category_map.keys()))
go = st.button("Generate")
st.markdown("## 생성 결과")
# Generate only on the rerun triggered by clicking "Generate".
if go:
    # The board token is concatenated directly in front of the seed text.
    input_context = category_map[category] + seed
    encoded = tokenizer(input_context, return_tensors="pt")
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        # Pass the mask explicitly rather than letting generate() infer it
        # (which it warns about); .get() tolerates tokenizers configured
        # not to return one.
        attention_mask=encoded.get("attention_mask"),
        max_length=250,
        num_return_sequences=1,
        no_repeat_ngram_size=3,
        repetition_penalty=2.0,
        do_sample=True,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Exactly one sequence was requested, so decode only the first row.
    generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # NOTE(review): "<unused2>" appears to be used as a line-break token; if
    # the tokenizer registers it as a special token, skip_special_tokens has
    # already removed it and this replace is a no-op — confirm.
    # NOTE(review): st.write renders strings as Markdown, where a single "\n"
    # may render as a soft break (space); consider "  \n" if breaks are lost.
    st.write(generated.replace("<unused2>", "\n"))