app.py
ADDED
@@ -0,0 +1,105 @@
import copy
import logging
from typing import List

import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM

from rhyme_with_ai.utils import color_new_words, sanitize
from rhyme_with_ai.rhyme import query_rhyme_words
from rhyme_with_ai.rhyme_generator import RhymeGenerator

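# Defaults: the example first line shown in the text box and the number of
# rhyming alternatives requested from the rhyme-word lookup.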
DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10

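# Per-language model checkpoint and iteration budget; ITER_FACTOR scales the
# number of mutation steps with the length of the query (see start_rhyming).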
LANGUAGE = st.sidebar.radio("Language", ["english", "dutch"], 0)
if LANGUAGE == "english":
    MODEL_PATH = "bert-large-cased-whole-word-masking"
    ITER_FACTOR = 5
elif LANGUAGE == "dutch":
    MODEL_PATH = "GroNLP/bert-base-dutch-cased"
    ITER_FACTOR = 10  # Faster model
else:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}); expected 'english' or 'dutch'."
    )

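# Streamlit re-runs this script top to bottom on every interaction; main()
# renders the page, looks up rhyme words for the query and, when any are
# found, starts the iterative rhyme generation.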
def main():
    st.markdown(
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI")
    query = get_query()
    if not query:
        query = DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(query, n_rhymes=N_RHYMES, language=LANGUAGE)
    if rhyme_words_options:
        logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")

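# Read the user's first line from the text box and sanitize it; fall back to
# the default query when the sanitized input is empty.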
def get_query():
    q = sanitize(
        st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    )
    if not q:
        return DEFAULT_QUERY
    return q

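# Generate rhyming lines iteratively: the RhymeGenerator is seeded with the
# query and the candidate rhyme words, then repeatedly mutated. Each pass
# updates the suggestions on screen, highlighting the words that changed.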
def start_rhyming(query, rhyme_words_options):
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        progress_bar.progress(i / (max_iter - 1))
    st.balloons()

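# Cache the (large) masked language model and tokenizer so they are loaded
# once and reused across Streamlit reruns.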
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    return (
        TFAutoModelForMaskedLM.from_pretrained(model_path),
        BertTokenizer.from_pretrained(model_path),
    )

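# Render the current suggestions: words that changed since the previous
# iteration are colored by color_new_words, and only the part after the first
# comma is shown as a list item beneath the query.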
def display_output(status_text, query, current_sentences, previous_sentences):
    print_sentences = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        after_comma = "<li>" + formatted.split(",")[1][:-2] + "</li>"
        print_sentences.append(after_comma)
    status_text.markdown(
        query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
    )

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
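
To try the app outside of Spaces (assuming the rhyme_with_ai package from the linked source repository is installed alongside streamlit, transformers and TensorFlow), it can be started locally with:

streamlit run app.py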