|
import copy |
|
import logging |
|
from typing import List |
|
|
|
import torch |
|
import streamlit as st |
|
from transformers import BertTokenizer, TFAutoModelForMaskedLM |
|
from transformers import CamembertModel, CamembertTokenizer |
|
|
|
from rhyme_with_ai.utils import color_new_words, sanitize |
|
from rhyme_with_ai.rhyme import query_rhyme_words |
|
from rhyme_with_ai.rhyme_generator import RhymeGenerator |
|
|
|
|
|
DEFAULT_QUERY = "Machines will take over the world soon"

N_RHYMES = 10

# Per-language settings: (Hugging Face checkpoint, ITER_FACTOR).
# ITER_FACTOR scales the number of generation rounds, which is
# len(query.split()) * ITER_FACTOR (see start_rhyming).
_LANGUAGE_SETTINGS = {
    "english": ("bert-large-cased-whole-word-masking", 5),
    "dutch": ("GroNLP/bert-base-dutch-cased", 10),
    "french": ("camembert-base", 5),
}

# The radio options are taken from the settings table so the two cannot drift
# apart; dict order is insertion order, so "english" stays the default (index 0).
LANGUAGE = st.sidebar.radio("Language", list(_LANGUAGE_SETTINGS), 0)

try:
    MODEL_PATH, ITER_FACTOR = _LANGUAGE_SETTINGS[LANGUAGE]
except KeyError:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}), expected 'english', 'dutch' or 'french'."
    ) from None
|
|
|
def main():
    """Render the page: credits line, title, query box and rhyme suggestions.

    Looks up rhyme words for the (sanitized) query. If any are found they are
    logged and handed to ``start_rhyming``; otherwise a short notice is shown.
    """
    credits_html = (
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>"
    )
    st.markdown(credits_html, unsafe_allow_html=True)
    st.title("Rhyme with AI")

    first_line = get_query() or DEFAULT_QUERY
    candidates = query_rhyme_words(first_line, n_rhymes=N_RHYMES, language=LANGUAGE)

    # Guard clause: nothing to rhyme with.
    if not candidates:
        st.write("No rhyme words found")
        return

    log = logging.getLogger(__name__)
    log.info("Got rhyme words: %s", candidates)
    start_rhyming(first_line, candidates)
|
|
|
|
|
def get_query():
    """Return the sanitized first line typed by the user.

    The text box is pre-filled with DEFAULT_QUERY. If sanitizing the input
    leaves a falsy value (e.g. an empty string), DEFAULT_QUERY is returned.
    """
    raw_text = st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    return sanitize(raw_text) or DEFAULT_QUERY
|
|
|
|
|
def start_rhyming(query, rhyme_words_options):
    """Generate rhyming continuations of *query* and stream them to the page.

    Args:
        query: The user's first line.
        rhyme_words_options: Candidate rhyme words; only the first N_RHYMES
            are passed to the generator.

    Runs ``len(query.split()) * ITER_FACTOR`` rounds of
    ``RhymeGenerator.mutate()``. After each round it re-renders the
    suggestions and advances the progress bar. It then shows balloons.
    """
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH, LANGUAGE)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    # Blank placeholders serve as the "previous" sentences for the first round.
    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        # Snapshot before mutate(); deep copy kept in case the generator
        # reuses/mutates the list it returns (NOTE(review): confirm).
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        # Fraction of rounds completed. Unlike i / (max_iter - 1), this never
        # divides by zero when max_iter == 1, and it still reaches 1.0 on the
        # final round.
        progress_bar.progress((i + 1) / max_iter)
    st.balloons()
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model(model_path, language):
    """Load a TensorFlow masked-language model and its tokenizer.

    The result is cached by Streamlit (keyed on the arguments), so the
    checkpoint is downloaded/loaded only once per (model_path, language).

    Args:
        model_path: Hugging Face checkpoint name or local path.
        language: UI language. It only selects the tokenizer class.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Every language gets a masked-LM head, the same as the English/Dutch path
    # always used. The previous French branch loaded CamembertModel, the bare
    # PyTorch encoder: it returns hidden states rather than vocabulary logits,
    # and it does not match the TF model used elsewhere.
    # NOTE(review): if a checkpoint ships only PyTorch weights, pass from_pt=True.
    model = TFAutoModelForMaskedLM.from_pretrained(model_path)
    if language == "french":
        # CamemBERT uses a SentencePiece vocabulary that BertTokenizer cannot read.
        tokenizer = CamembertTokenizer.from_pretrained(model_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(model_path)
    return model, tokenizer
|
|
|
def display_output(status_text, query, current_sentences, previous_sentences):
    """Overwrite *status_text* with the query followed by one item per suggestion.

    Args:
        status_text: Streamlit placeholder (created with ``st.empty()``).
        query: The user's first line, shown as the header line.
        current_sentences: Sentences from the latest generator round.
        previous_sentences: Sentences from the preceding round. Each is paired
            with its current sentence and passed to ``color_new_words``.

    Each formatted sentence is assumed to look like
    ``"<query>, <continuation><2 trailing chars>"``; only the continuation is
    shown. NOTE(review): the 2-char trim mirrors the original code and should
    be confirmed against RhymeGenerator's output format.
    """
    import html

    items = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        # Everything after the FIRST comma. With maxsplit=1, commas inside the
        # continuation no longer truncate it (previously the [:-2] then cut
        # into real text). A sentence with no comma no longer raises IndexError.
        continuation = formatted.split(",", 1)[-1][:-2]
        items.append("<li>" + continuation + "</li>")

    # The query is user input rendered with unsafe_allow_html, so escape it.
    # The sentences are not escaped because color_new_words emits HTML.
    status_text.markdown(
        html.escape(query, quote=False) + ",<br>" + "".join(items),
        unsafe_allow_html=True,
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Route INFO-level records (e.g. main()'s "Got rhyme words" line) to the
    # default root handler, then render the app.
    logging.basicConfig(level=logging.INFO)
    main()