rhyme-ai / app.py
import copy
import logging
from typing import List

import torch
import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM
from transformers import CamembertModel, CamembertTokenizer

from rhyme_with_ai.utils import color_new_words, sanitize
from rhyme_with_ai.rhyme import query_rhyme_words
from rhyme_with_ai.rhyme_generator import RhymeGenerator

DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10
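# Pick the masked-language model and iteration budget for the language selected in the sidebar.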
LANGUAGE = st.sidebar.radio("Language", ["english", "dutch", "french"], 0)
if LANGUAGE == "english":
    MODEL_PATH = "bert-large-cased-whole-word-masking"
    ITER_FACTOR = 5
elif LANGUAGE == "dutch":
    MODEL_PATH = "GroNLP/bert-base-dutch-cased"
    ITER_FACTOR = 10  # Faster model
elif LANGUAGE == "french":
    MODEL_PATH = "camembert-base"
    ITER_FACTOR = 5
else:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}): expected 'english', 'dutch' or 'french'."
    )

def main():
    st.markdown(
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI")

    query = get_query()
    if not query:
        query = DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(query, n_rhymes=N_RHYMES, language=LANGUAGE)
    if rhyme_words_options:
        logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")

def get_query():
    q = sanitize(
        st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    )
    if not q:
        return DEFAULT_QUERY
    return q

def start_rhyming(query, rhyme_words_options):
    st.markdown("## My Suggestions:")
    progress_bar = st.progress(0)
    status_text = st.empty()
    # The number of mutation steps scales with the length of the query.
    max_iter = len(query.split()) * ITER_FACTOR
    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH, LANGUAGE)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)
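    # Iteratively mutate the candidate sentences, redrawing the output after every step.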
    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        progress_bar.progress(i / (max_iter - 1))
    st.balloons()

@st.cache(allow_output_mutation=True)
def load_model(model_path, language):
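    # Cached so the model and tokenizer are downloaded and loaded only once per session.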
    if language != "french":
        return (
            TFAutoModelForMaskedLM.from_pretrained(model_path),
            BertTokenizer.from_pretrained(model_path),
        )
    else:
        return (
            CamembertModel.from_pretrained(model_path),
            CamembertTokenizer.from_pretrained(model_path),
        )

def display_output(status_text, query, current_sentences, previous_sentences):
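    # Color the words that changed since the previous iteration, keep the part of each
    # sentence after the first comma, and render the results as an HTML list under the query.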
    print_sentences = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        after_comma = "<li>" + formatted.split(",")[1][:-2] + "</li>"
        print_sentences.append(after_comma)
    status_text.markdown(
        query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
    )

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()