Spaces:
Runtime error
Runtime error
Galuh Sahid
commited on
Commit
•
a4af9d2
1
Parent(s):
bcdac5a
Init
Browse files- .gitignore +2 -0
- SessionState.py +107 -0
- app.py +134 -0
- huggingwayang.png +0 -0
- lid.176.ftz +3 -0
- prompts.py +36 -0
- requirements.txt +9 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
gpt2-demo
|
2 |
+
__pycache__
|
SessionState.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Hack to add per-session state to Streamlit.
|
2 |
+
Usage
|
3 |
+
-----
|
4 |
+
>>> import SessionState
|
5 |
+
>>>
|
6 |
+
>>> session_state = SessionState.get(user_name='', favorite_color='black')
|
7 |
+
>>> session_state.user_name
|
8 |
+
''
|
9 |
+
>>> session_state.user_name = 'Mary'
|
10 |
+
>>> session_state.favorite_color
|
11 |
+
'black'
|
12 |
+
Since you set user_name above, next time your script runs this will be the
|
13 |
+
result:
|
14 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
15 |
+
>>> session_state.user_name
|
16 |
+
'Mary'
|
17 |
+
"""
|
18 |
+
try:
|
19 |
+
import streamlit.ReportThread as ReportThread
|
20 |
+
from streamlit.server.Server import Server
|
21 |
+
except Exception:
|
22 |
+
# Streamlit >= 0.65.0
|
23 |
+
import streamlit.report_thread as ReportThread
|
24 |
+
from streamlit.server.server import Server
|
25 |
+
|
26 |
+
|
27 |
+
class SessionState(object):
|
28 |
+
def __init__(self, **kwargs):
|
29 |
+
"""A new SessionState object.
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
**kwargs : any
|
33 |
+
Default values for the session state.
|
34 |
+
Example
|
35 |
+
-------
|
36 |
+
>>> session_state = SessionState(user_name='', favorite_color='black')
|
37 |
+
>>> session_state.user_name = 'Mary'
|
38 |
+
''
|
39 |
+
>>> session_state.favorite_color
|
40 |
+
'black'
|
41 |
+
"""
|
42 |
+
for key, val in kwargs.items():
|
43 |
+
setattr(self, key, val)
|
44 |
+
|
45 |
+
|
46 |
+
def get(**kwargs):
|
47 |
+
"""Gets a SessionState object for the current session.
|
48 |
+
Creates a new object if necessary.
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
**kwargs : any
|
52 |
+
Default values you want to add to the session state, if we're creating a
|
53 |
+
new one.
|
54 |
+
Example
|
55 |
+
-------
|
56 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
57 |
+
>>> session_state.user_name
|
58 |
+
''
|
59 |
+
>>> session_state.user_name = 'Mary'
|
60 |
+
>>> session_state.favorite_color
|
61 |
+
'black'
|
62 |
+
Since you set user_name above, next time your script runs this will be the
|
63 |
+
result:
|
64 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
65 |
+
>>> session_state.user_name
|
66 |
+
'Mary'
|
67 |
+
"""
|
68 |
+
# Hack to get the session object from Streamlit.
|
69 |
+
|
70 |
+
ctx = ReportThread.get_report_ctx()
|
71 |
+
|
72 |
+
this_session = None
|
73 |
+
|
74 |
+
current_server = Server.get_current()
|
75 |
+
if hasattr(current_server, '_session_infos'):
|
76 |
+
# Streamlit < 0.56
|
77 |
+
session_infos = Server.get_current()._session_infos.values()
|
78 |
+
else:
|
79 |
+
session_infos = Server.get_current()._session_info_by_id.values()
|
80 |
+
|
81 |
+
for session_info in session_infos:
|
82 |
+
s = session_info.session
|
83 |
+
if (
|
84 |
+
# Streamlit < 0.54.0
|
85 |
+
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
|
86 |
+
or
|
87 |
+
# Streamlit >= 0.54.0
|
88 |
+
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
|
89 |
+
or
|
90 |
+
# Streamlit >= 0.65.2
|
91 |
+
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
|
92 |
+
):
|
93 |
+
this_session = s
|
94 |
+
|
95 |
+
if this_session is None:
|
96 |
+
raise RuntimeError(
|
97 |
+
"Oh noes. Couldn't get your Streamlit Session object. "
|
98 |
+
'Are you doing something fancy with threads?')
|
99 |
+
|
100 |
+
# Got the session object! Now let's attach some state into it.
|
101 |
+
|
102 |
+
if not hasattr(this_session, '_custom_session_state'):
|
103 |
+
this_session._custom_session_state = SessionState(**kwargs)
|
104 |
+
|
105 |
+
return this_session._custom_session_state
|
106 |
+
|
107 |
+
__all__ = ['get']
|
app.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import requests
|
3 |
+
from mtranslate import translate
|
4 |
+
from prompts import PROMPT_LIST
|
5 |
+
import streamlit as st
|
6 |
+
import random
|
7 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
8 |
+
import fasttext
|
9 |
+
import SessionState
|
10 |
+
|
11 |
+
LOGO = "huggingwayang.png"
|
12 |
+
|
13 |
+
MODELS = {
|
14 |
+
"GPT-2 Small": "flax-community/gpt2-small-indonesian",
|
15 |
+
"GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
|
16 |
+
"GPT-2 Small Finetuned on Indonesian Journals": "Galuh/id-journal-gpt2"
|
17 |
+
}
|
18 |
+
|
19 |
+
headers = {}
|
20 |
+
|
21 |
+
@st.cache(show_spinner=False, persist=True)
|
22 |
+
def load_gpt(model_type, text):
|
23 |
+
print("Loading model...")
|
24 |
+
model = GPT2LMHeadModel.from_pretrained(MODELS[model_type])
|
25 |
+
tokenizer = GPT2Tokenizer.from_pretrained(MODELS[model_type])
|
26 |
+
|
27 |
+
return model, tokenizer
|
28 |
+
|
29 |
+
def get_image(text: str):
|
30 |
+
url = "https://wikisearch.uncool.ai/get_image/"
|
31 |
+
try:
|
32 |
+
payload = {
|
33 |
+
"text": text,
|
34 |
+
"image_width": 400
|
35 |
+
}
|
36 |
+
data = json.dumps(payload)
|
37 |
+
response = requests.request("POST", url, headers=headers, data=data)
|
38 |
+
print(response.content)
|
39 |
+
image = json.loads(response.content.decode("utf-8"))["url"]
|
40 |
+
except:
|
41 |
+
image = ""
|
42 |
+
return image
|
43 |
+
|
44 |
+
st.set_page_config(page_title="Indonesian GPT-2 Demo")
|
45 |
+
|
46 |
+
st.title("Indonesian GPT-2")
|
47 |
+
|
48 |
+
# ft_model = fasttext.load_model('lid.176.ftz')
|
49 |
+
|
50 |
+
# Sidebar
|
51 |
+
st.sidebar.image(LOGO)
|
52 |
+
st.sidebar.subheader("Configurable parameters")
|
53 |
+
|
54 |
+
max_len = st.sidebar.number_input(
|
55 |
+
"Maximum length",
|
56 |
+
value=100,
|
57 |
+
help="The maximum length of the sequence to be generated."
|
58 |
+
)
|
59 |
+
|
60 |
+
temp = st.sidebar.slider(
|
61 |
+
"Temperature",
|
62 |
+
value=1.0,
|
63 |
+
min_value=0.0,
|
64 |
+
max_value=100.0,
|
65 |
+
help="The value used to module the next token probabilities."
|
66 |
+
)
|
67 |
+
|
68 |
+
top_k = st.sidebar.number_input(
|
69 |
+
"Top k",
|
70 |
+
value=50,
|
71 |
+
help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
|
72 |
+
)
|
73 |
+
|
74 |
+
top_p = st.sidebar.number_input(
|
75 |
+
"Top p",
|
76 |
+
value=1.0,
|
77 |
+
help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
|
78 |
+
)
|
79 |
+
|
80 |
+
st.markdown(
|
81 |
+
"""
|
82 |
+
This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
|
83 |
+
[medium](https://huggingface.co/flax-community/gpt2-medium-indonesian) Indonesian GPT2 model
|
84 |
+
trained on the Indonesian [Oscar](https://huggingface.co/datasets/oscar), [MC4](https://huggingface.co/datasets/mc4)
|
85 |
+
and [Wikipedia](https://huggingface.co/datasets/wikipedia) dataset. We created it as part of the
|
86 |
+
[Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
|
87 |
+
|
88 |
+
The demo supports "multi language" ;-), feel free to try a prompt on your language. We are also experimenting with
|
89 |
+
the sentence based image search using Wikipedia passages encoded with distillbert, and search the encoded sentence
|
90 |
+
in the encoded passages using Facebook's Faiss.
|
91 |
+
"""
|
92 |
+
)
|
93 |
+
|
94 |
+
model_name = st.selectbox('Model',(['GPT-2 Small', 'GPT-2 Medium', 'GPT-2 Small Finetuned on Indonesian Journals']))
|
95 |
+
|
96 |
+
if model_name in ["GPT-2 Small", "GPT-2 Medium"]:
|
97 |
+
prompt_group_name = "GPT-2"
|
98 |
+
elif model_name in ["GPT-2 Small Finetuned on Indonesian Journals"]:
|
99 |
+
prompt_group_name = "Indonesian Journals"
|
100 |
+
|
101 |
+
ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
|
102 |
+
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
|
103 |
+
|
104 |
+
session_state = SessionState.get(prompt_box=None)
|
105 |
+
|
106 |
+
if prompt == "Custom":
|
107 |
+
prompt_box = "Enter your text here"
|
108 |
+
else:
|
109 |
+
prompt_box = random.choice(PROMPT_LIST[prompt_group_name][prompt])
|
110 |
+
|
111 |
+
session_state.prompt_box = prompt_box
|
112 |
+
|
113 |
+
text = st.text_area("Enter text", session_state.prompt_box)
|
114 |
+
|
115 |
+
if st.button("Run"):
|
116 |
+
with st.spinner(text="Getting results..."):
|
117 |
+
|
118 |
+
st.subheader("Result")
|
119 |
+
model, tokenizer = load_gpt(model_name, text)
|
120 |
+
|
121 |
+
input_ids = tokenizer.encode(text, return_tensors='pt')
|
122 |
+
output = model.generate(input_ids=input_ids,
|
123 |
+
max_length=max_len,
|
124 |
+
temperature=temp,
|
125 |
+
top_k=top_k,
|
126 |
+
top_p=top_p,
|
127 |
+
repetition_penalty=2.0)
|
128 |
+
|
129 |
+
text = tokenizer.decode(output[0],
|
130 |
+
skip_special_tokens=True)
|
131 |
+
st.write(text.replace("\n", " \n"))
|
132 |
+
|
133 |
+
st.text("Translation")
|
134 |
+
translation = translate(text, "en", "id")
|
huggingwayang.png
ADDED
lid.176.ftz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f3472cfe8738a7b6099e8e999c3cbfae0dcd15696aac7d7738a8039db603e83
|
3 |
+
size 938013
|
prompts.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PROMPT_LIST = {
|
2 |
+
"GPT-2": {
|
3 |
+
"Resep masakan (recipe)": [
|
4 |
+
"Berikut adalah cara memasak sate ayam: ",
|
5 |
+
"Langkah-langkah membuat kastengel: ",
|
6 |
+
"Berikut adalah bahan-bahan membuat nastar: "
|
7 |
+
],
|
8 |
+
"Puisi (poetry)": [
|
9 |
+
"Aku ingin jadi merpati\nTerbang di langit yang damai\nBernyanyi-nyanyi tentang masa depan\n",
|
10 |
+
"Terdiam aku satu persatu dengan tatapan binar\nSenyawa merasuk dalam sukma membuat lara\nKefanaan membentuk kelemahan"
|
11 |
+
],
|
12 |
+
"Cerpen (short story)": [
|
13 |
+
"Putri memakai sepatunya dengan malas. Kalau bisa, selama seminggu ini ia bolos sekolah saja. Namun, Mama pasti akan marah. Ulangan tengah semester telah selesai. Minggu ini, di sekolah sedang berlangsung pekan olahraga.",
|
14 |
+
"\"Wah, hari ini cerah sekali ya,\" ucap Budi ketika ia keluar rumah.",
|
15 |
+
"Sewindu sudah kita tak berjumpa, rinduku padamu sudah tak terkira."
|
16 |
+
],
|
17 |
+
"Sejarah (history)": [
|
18 |
+
"Mohammad Natsir adalah seorang ulama, politisi, dan pejuang kemerdekaan Indonesia.",
|
19 |
+
"Ir. H. Soekarno adalah Presiden pertama Republik Indonesia. Ia adalah seorang tokoh perjuangan yang memainkan peranan penting dalam memerdekakan bangsa Indonesia",
|
20 |
+
"Borobudur adalah sebuah candi Buddha yang terletak di sebelah barat laut Yogyakarta. Monumen ini merupakan model alam semesta dan dibangun sebagai tempat suci untuk memuliakan Buddha"
|
21 |
+
],
|
22 |
+
"English": [
|
23 |
+
"Deoxyribonucleic acid is a molecule composed of two polynucleotide chains that coil around each other",
|
24 |
+
"Javanese is the largest of the Austronesian languages in number of native speakers"
|
25 |
+
],
|
26 |
+
"German": [
|
27 |
+
"Eine Meerjungfrau, auch Seejungfrau oder Fischweib, ist ein weibliches Fabelwesen, ein Mischwesen aus Frauen- und Fischkörper",
|
28 |
+
"Der Mond ist der einzige natürliche Satellit der Erde"
|
29 |
+
]},
|
30 |
+
"Indonesian Journals": {
|
31 |
+
"Biologi (biology)": ["Tujuan penelitian ini untuk menentukan keanekaragaman Arthropoda pada lahan pertanian kacang", "Identifikasi spesies secara molekuler sangat diperlukan dalam mempelajari taksonomi", "Penelitian ini bertujuan untuk menentukan identitas invertebrata laut dari Perairan Papua dengan teknik DNA barcoding"],
|
32 |
+
"Psikologi (psychology)": ["Penelitian ini bertujuan untuk mengetahui perilaku wirausaha remaja yang diprediksi dari motivasi intrinsik", "Tujuan dari penelitian ini adalah untuk mendapatkan data empiris mengenai gambaran peta bakat mahasiswa Fakultas Psikologi Unjani"],
|
33 |
+
"Ekonomi (economics)": ["Faktor kepuasan dan kepercayaan konsumen merupakan dua faktor kunci dalam meningkatkan penetrasi e-commerce. Peneltiian yang dilakukan", "Penelitian ini bertujuan untuk menganalisis pola konsumsi pangan di Indonesia", "Model GTAP diimplementasikan untuk melihat dampak yang ditimbulkan pada PDB"],
|
34 |
+
"Teknologi Informasi (IT)": ["pembuatan aplikasi ini menggunakan pengembangan metode Waterfall dan dirancang mengguynakan Unified Modeling Language (UML) dengan bahasa pemrograman", "Berdasarkan masalah tersebut, maka penulis termotivasi untuk membangun Pengembangan Sistem Informasi Manajemen"]
|
35 |
+
}
|
36 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
streamlit
|
3 |
+
requests==2.24.0
|
4 |
+
requests-toolbelt==0.9.1
|
5 |
+
mtranslate
|
6 |
+
-f https://download.pytorch.org/whl/torch_stable.html
|
7 |
+
torch==1.7.1+cpu; sys_platform == 'linux'
|
8 |
+
torch==1.7.1; sys_platform == 'darwin'
|
9 |
+
fasttext
|