Galuh Sahid committed on
Commit
a4af9d2
1 Parent(s): bcdac5a
Files changed (7)
  1. .gitignore +2 -0
  2. SessionState.py +107 -0
  3. app.py +134 -0
  4. huggingwayang.png +0 -0
  5. lid.176.ftz +3 -0
  6. prompts.py +36 -0
  7. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ gpt2-demo
+ __pycache__
SessionState.py ADDED
@@ -0,0 +1,107 @@
+ """Hack to add per-session state to Streamlit.
+ Usage
+ -----
+ >>> import SessionState
+ >>>
+ >>> session_state = SessionState.get(user_name='', favorite_color='black')
+ >>> session_state.user_name
+ ''
+ >>> session_state.user_name = 'Mary'
+ >>> session_state.favorite_color
+ 'black'
+ Since you set user_name above, next time your script runs this will be the
+ result:
+ >>> session_state = get(user_name='', favorite_color='black')
+ >>> session_state.user_name
+ 'Mary'
+ """
+ try:
+     import streamlit.ReportThread as ReportThread
+     from streamlit.server.Server import Server
+ except Exception:
+     # Streamlit >= 0.65.0
+     import streamlit.report_thread as ReportThread
+     from streamlit.server.server import Server
+
+
+ class SessionState(object):
+     def __init__(self, **kwargs):
+         """A new SessionState object.
+         Parameters
+         ----------
+         **kwargs : any
+             Default values for the session state.
+         Example
+         -------
+         >>> session_state = SessionState(user_name='', favorite_color='black')
+ >>> session_state.user_name = 'Mary'
38
+ ''
39
+ >>> session_state.favorite_color
40
+ 'black'
41
+ """
42
+ for key, val in kwargs.items():
43
+ setattr(self, key, val)
44
+
45
+
46
+ def get(**kwargs):
47
+ """Gets a SessionState object for the current session.
48
+ Creates a new object if necessary.
49
+ Parameters
50
+ ----------
51
+ **kwargs : any
52
+ Default values you want to add to the session state, if we're creating a
53
+ new one.
54
+ Example
55
+ -------
56
+ >>> session_state = get(user_name='', favorite_color='black')
57
+ >>> session_state.user_name
58
+ ''
59
+ >>> session_state.user_name = 'Mary'
60
+ >>> session_state.favorite_color
61
+ 'black'
62
+ Since you set user_name above, next time your script runs this will be the
63
+ result:
64
+ >>> session_state = get(user_name='', favorite_color='black')
65
+ >>> session_state.user_name
66
+ 'Mary'
67
+ """
68
+ # Hack to get the session object from Streamlit.
69
+
70
+ ctx = ReportThread.get_report_ctx()
71
+
72
+ this_session = None
73
+
74
+ current_server = Server.get_current()
75
+ if hasattr(current_server, '_session_infos'):
76
+ # Streamlit < 0.56
77
+ session_infos = Server.get_current()._session_infos.values()
78
+ else:
79
+ session_infos = Server.get_current()._session_info_by_id.values()
80
+
81
+ for session_info in session_infos:
82
+ s = session_info.session
83
+ if (
84
+ # Streamlit < 0.54.0
85
+ (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
86
+ or
87
+ # Streamlit >= 0.54.0
88
+ (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
89
+ or
90
+ # Streamlit >= 0.65.2
91
+ (not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
92
+ ):
93
+ this_session = s
94
+
95
+ if this_session is None:
96
+ raise RuntimeError(
97
+ "Oh noes. Couldn't get your Streamlit Session object. "
98
+ 'Are you doing something fancy with threads?')
99
+
100
+ # Got the session object! Now let's attach some state into it.
101
+
102
+ if not hasattr(this_session, '_custom_session_state'):
103
+ this_session._custom_session_state = SessionState(**kwargs)
104
+
105
+ return this_session._custom_session_state
106
+
107
+ __all__ = ['get']
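SessionState.py above is the widely shared community workaround for per-session state that predates Streamlit's built-in st.session_state (available from Streamlit 0.84 onward). A minimal usage sketch, assuming a Streamlit release old enough to still need the hack; the run_count field is illustrative only and not something app.py defines:

    # Minimal usage sketch; on Streamlit >= 0.84 the built-in st.session_state
    # makes this module unnecessary.
    import streamlit as st

    import SessionState

    # Defaults apply only the first time the session state object is created.
    state = SessionState.get(prompt_box=None, run_count=0)

    state.run_count += 1  # persists across script reruns within one browser session
    st.write("Reruns so far in this session:", state.run_count)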
app.py ADDED
@@ -0,0 +1,134 @@
+ import json
+ import requests
+ from mtranslate import translate
+ from prompts import PROMPT_LIST
+ import streamlit as st
+ import random
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ import fasttext
+ import SessionState
+
+ LOGO = "huggingwayang.png"
+
+ MODELS = {
+     "GPT-2 Small": "flax-community/gpt2-small-indonesian",
+     "GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
+     "GPT-2 Small Finetuned on Indonesian Journals": "Galuh/id-journal-gpt2"
+ }
+
+ headers = {}
+
+ @st.cache(show_spinner=False, persist=True)
+ def load_gpt(model_type, text):
+     print("Loading model...")
+     model = GPT2LMHeadModel.from_pretrained(MODELS[model_type])
+     tokenizer = GPT2Tokenizer.from_pretrained(MODELS[model_type])
+
+     return model, tokenizer
+
+ def get_image(text: str):
+     url = "https://wikisearch.uncool.ai/get_image/"
+     try:
+         payload = {
+             "text": text,
+             "image_width": 400
+         }
+         data = json.dumps(payload)
+         response = requests.request("POST", url, headers=headers, data=data)
+         print(response.content)
+         image = json.loads(response.content.decode("utf-8"))["url"]
+     except Exception:
+         image = ""
+     return image
+
+ st.set_page_config(page_title="Indonesian GPT-2 Demo")
+
+ st.title("Indonesian GPT-2")
+
+ # ft_model = fasttext.load_model('lid.176.ftz')
+
+ # Sidebar
+ st.sidebar.image(LOGO)
+ st.sidebar.subheader("Configurable parameters")
+
+ max_len = st.sidebar.number_input(
+     "Maximum length",
+     value=100,
+     help="The maximum length of the sequence to be generated."
+ )
+
+ temp = st.sidebar.slider(
+     "Temperature",
+     value=1.0,
+     min_value=0.0,
+     max_value=100.0,
+     help="The value used to modulate the next token probabilities."
+ )
+
+ top_k = st.sidebar.number_input(
+     "Top k",
+     value=50,
+     help="The number of highest probability vocabulary tokens to keep for top-k filtering."
+ )
+
+ top_p = st.sidebar.number_input(
+     "Top p",
+     value=1.0,
+     help="If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation."
+ )
+
+ st.markdown(
+ """
+ This demo uses the [small](https://huggingface.co/flax-community/gpt2-small-indonesian) and
+ [medium](https://huggingface.co/flax-community/gpt2-medium-indonesian) Indonesian GPT-2 models
+ trained on the Indonesian [OSCAR](https://huggingface.co/datasets/oscar), [mC4](https://huggingface.co/datasets/mc4)
+ and [Wikipedia](https://huggingface.co/datasets/wikipedia) datasets. We built them as part of the
+ [Hugging Face JAX/Flax community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
+
+ The demo also handles prompts in other languages, so feel free to try one in your own language. We are also experimenting
+ with sentence-based image search: Wikipedia passages are encoded with DistilBERT, and the encoded input sentence is
+ matched against the encoded passages using Facebook's Faiss.
+ """
+ )
+
+ model_name = st.selectbox('Model', ['GPT-2 Small', 'GPT-2 Medium', 'GPT-2 Small Finetuned on Indonesian Journals'])
+
+ if model_name in ["GPT-2 Small", "GPT-2 Medium"]:
+     prompt_group_name = "GPT-2"
+ elif model_name in ["GPT-2 Small Finetuned on Indonesian Journals"]:
+     prompt_group_name = "Indonesian Journals"
+
+ ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
+ prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
+
+ session_state = SessionState.get(prompt_box=None)
+
+ if prompt == "Custom":
+     prompt_box = "Enter your text here"
+ else:
+     prompt_box = random.choice(PROMPT_LIST[prompt_group_name][prompt])
+
+ session_state.prompt_box = prompt_box
+
+ text = st.text_area("Enter text", session_state.prompt_box)
+
+ if st.button("Run"):
+     with st.spinner(text="Getting results..."):
+
+         st.subheader("Result")
+         model, tokenizer = load_gpt(model_name, text)
+
+         input_ids = tokenizer.encode(text, return_tensors='pt')
+         output = model.generate(input_ids=input_ids,
+                                 max_length=max_len,
+                                 temperature=temp,
+                                 top_k=top_k,
+                                 top_p=top_p,
+                                 repetition_penalty=2.0)
+
+         text = tokenizer.decode(output[0],
+                                 skip_special_tokens=True)
+         st.write(text.replace("\n", " \n"))
+
+         st.text("Translation")
+         translation = translate(text, "en", "id")
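One caveat about the model.generate call above: in Hugging Face transformers, temperature, top_k and top_p only affect the output when sampling is enabled via do_sample=True; with the default greedy decoding they are effectively ignored. A standalone sketch of a sampling-enabled variant, hypothetical and not part of this commit (the prompt string is just an example):

    # Hypothetical sketch: sampling-based generation with the same knobs the
    # sidebar exposes. Not part of app.py as committed.
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    model_id = "flax-community/gpt2-small-indonesian"
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)

    input_ids = tokenizer.encode("Sewindu sudah kita tak berjumpa", return_tensors="pt")
    output = model.generate(
        input_ids=input_ids,
        do_sample=True,          # required for temperature / top_k / top_p to take effect
        max_length=100,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=2.0,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))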
huggingwayang.png ADDED
lid.176.ftz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f3472cfe8738a7b6099e8e999c3cbfae0dcd15696aac7d7738a8039db603e83
+ size 938013
prompts.py ADDED
@@ -0,0 +1,36 @@
+ PROMPT_LIST = {
+     "GPT-2": {
+         "Resep masakan (recipe)": [
+             "Berikut adalah cara memasak sate ayam: ",
+             "Langkah-langkah membuat kastengel: ",
+             "Berikut adalah bahan-bahan membuat nastar: "
+         ],
+         "Puisi (poetry)": [
+             "Aku ingin jadi merpati\nTerbang di langit yang damai\nBernyanyi-nyanyi tentang masa depan\n",
+             "Terdiam aku satu persatu dengan tatapan binar\nSenyawa merasuk dalam sukma membuat lara\nKefanaan membentuk kelemahan"
+         ],
+         "Cerpen (short story)": [
+             "Putri memakai sepatunya dengan malas. Kalau bisa, selama seminggu ini ia bolos sekolah saja. Namun, Mama pasti akan marah. Ulangan tengah semester telah selesai. Minggu ini, di sekolah sedang berlangsung pekan olahraga.",
+             "\"Wah, hari ini cerah sekali ya,\" ucap Budi ketika ia keluar rumah.",
+             "Sewindu sudah kita tak berjumpa, rinduku padamu sudah tak terkira."
+         ],
+         "Sejarah (history)": [
+             "Mohammad Natsir adalah seorang ulama, politisi, dan pejuang kemerdekaan Indonesia.",
+             "Ir. H. Soekarno adalah Presiden pertama Republik Indonesia. Ia adalah seorang tokoh perjuangan yang memainkan peranan penting dalam memerdekakan bangsa Indonesia",
+             "Borobudur adalah sebuah candi Buddha yang terletak di sebelah barat laut Yogyakarta. Monumen ini merupakan model alam semesta dan dibangun sebagai tempat suci untuk memuliakan Buddha"
+         ],
+         "English": [
+             "Deoxyribonucleic acid is a molecule composed of two polynucleotide chains that coil around each other",
+             "Javanese is the largest of the Austronesian languages in number of native speakers"
+         ],
+         "German": [
+             "Eine Meerjungfrau, auch Seejungfrau oder Fischweib, ist ein weibliches Fabelwesen, ein Mischwesen aus Frauen- und Fischkörper",
+             "Der Mond ist der einzige natürliche Satellit der Erde"
+         ]},
+     "Indonesian Journals": {
+         "Biologi (biology)": ["Tujuan penelitian ini untuk menentukan keanekaragaman Arthropoda pada lahan pertanian kacang", "Identifikasi spesies secara molekuler sangat diperlukan dalam mempelajari taksonomi", "Penelitian ini bertujuan untuk menentukan identitas invertebrata laut dari Perairan Papua dengan teknik DNA barcoding"],
+         "Psikologi (psychology)": ["Penelitian ini bertujuan untuk mengetahui perilaku wirausaha remaja yang diprediksi dari motivasi intrinsik", "Tujuan dari penelitian ini adalah untuk mendapatkan data empiris mengenai gambaran peta bakat mahasiswa Fakultas Psikologi Unjani"],
+         "Ekonomi (economics)": ["Faktor kepuasan dan kepercayaan konsumen merupakan dua faktor kunci dalam meningkatkan penetrasi e-commerce. Penelitian yang dilakukan", "Penelitian ini bertujuan untuk menganalisis pola konsumsi pangan di Indonesia", "Model GTAP diimplementasikan untuk melihat dampak yang ditimbulkan pada PDB"],
+         "Teknologi Informasi (IT)": ["pembuatan aplikasi ini menggunakan pengembangan metode Waterfall dan dirancang menggunakan Unified Modeling Language (UML) dengan bahasa pemrograman", "Berdasarkan masalah tersebut, maka penulis termotivasi untuk membangun Pengembangan Sistem Informasi Manajemen"],
+     }
+ }
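The top-level keys of PROMPT_LIST ("GPT-2" and "Indonesian Journals") have to match the prompt_group_name strings that app.py derives from the selected model. A quick, hypothetical self-check sketch (not part of the commit):

    # Consistency check between prompts.py and the group names app.py expects.
    from prompts import PROMPT_LIST

    expected_groups = {"GPT-2", "Indonesian Journals"}
    assert set(PROMPT_LIST) == expected_groups

    # Every category needs at least one prompt for random.choice to work.
    for group, categories in PROMPT_LIST.items():
        for category, prompts in categories.items():
            assert prompts, f"{group}/{category} has no prompts"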
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers
+ streamlit
+ requests==2.24.0
+ requests-toolbelt==0.9.1
+ mtranslate
+ -f https://download.pytorch.org/whl/torch_stable.html
+ torch==1.7.1+cpu; sys_platform == 'linux'
+ torch==1.7.1; sys_platform == 'darwin'
+ fasttext
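With these dependencies in place, the Space can be run locally in the usual Streamlit way: pip install -r requirements.txt (the -f line points pip at the PyTorch wheel index so the CPU-only torch==1.7.1+cpu build resolves on Linux), then streamlit run app.py. The lid.176.ftz file tracked via Git LFS is fastText's language-identification model; at this commit app.py only references it in a commented-out fasttext.load_model call.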