import copy
import logging
from typing import List

import torch
import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM
from transformers import CamembertForMaskedLM, CamembertModel, CamembertTokenizer

from rhyme_with_ai.utils import color_new_words, sanitize
from rhyme_with_ai.rhyme import query_rhyme_words
from rhyme_with_ai.rhyme_generator import RhymeGenerator

logger = logging.getLogger(__name__)

DEFAULT_QUERY = "Machines will take over the world soon"
N_RHYMES = 10

# language -> (Hugging Face model id, generation rounds per word of the query).
# Insertion order is the order shown in the sidebar radio.
_LANGUAGE_CONFIG = {
    "english": ("bert-large-cased-whole-word-masking", 5),
    "dutch": ("GroNLP/bert-base-dutch-cased", 10),  # Faster model
    "french": ("camembert-base", 5),
}

LANGUAGE = st.sidebar.radio("Language", list(_LANGUAGE_CONFIG), 0)
if LANGUAGE not in _LANGUAGE_CONFIG:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}) expected 'english', 'dutch' or 'french'."
    )
MODEL_PATH, ITER_FACTOR = _LANGUAGE_CONFIG[LANGUAGE]


def main():
    """Render the page header, read the user's first line and rhyme on it."""
    st.markdown(
        "Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        "<br> Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).",
        unsafe_allow_html=True,
    )
    st.title("Rhyme with AI")
    query = get_query() or DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(
        query, n_rhymes=N_RHYMES, language=LANGUAGE
    )
    if rhyme_words_options:
        logger.info("Got rhyme words: %s", rhyme_words_options)
        start_rhyming(query, rhyme_words_options)
    else:
        st.write("No rhyme words found")


def get_query():
    """Return the sanitized line typed by the user, or DEFAULT_QUERY if it is empty."""
    query = sanitize(
        st.text_input("Write your first line and press ENTER to rhyme:", DEFAULT_QUERY)
    )
    return query or DEFAULT_QUERY


def start_rhyming(query, rhyme_words_options):
    """Generate candidate second lines for *query* and stream them to the page.

    Args:
        query: the user's first line.
        rhyme_words_options: candidate rhyme words; at most N_RHYMES are used.
    """
    st.markdown("## My Suggestions:")
    progress_bar = st.progress(0)
    status_text = st.empty()
    # Longer first lines get proportionally more mutation rounds.
    max_iter = len(query.split()) * ITER_FACTOR
    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH, LANGUAGE)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        # Keep the previous round so display_output can highlight what changed.
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        # Reaches 1.0 on the last round and can never divide by zero.
        progress_bar.progress((i + 1) / max_iter)
    st.balloons()


@st.cache(allow_output_mutation=True)
def load_model(model_path, language):
    """Load (cached across Streamlit reruns) a masked-LM model and its tokenizer.

    Args:
        model_path: Hugging Face model id.
        language: one of the keys of _LANGUAGE_CONFIG.

    Returns:
        ``(model, tokenizer)``.
    """
    if language == "french":
        # CamembertModel is the bare encoder (hidden states only); predicting
        # masked tokens needs the LM head. from_pretrained is a classmethod, so
        # the tokenizer is loaded directly from the model id.
        return (
            CamembertForMaskedLM.from_pretrained(model_path),
            CamembertTokenizer.from_pretrained(model_path),
        )
    return (
        TFAutoModelForMaskedLM.from_pretrained(model_path),
        BertTokenizer.from_pretrained(model_path),
    )


def _generated_part(formatted, query):
    """Return the generated continuation of *formatted*, without the query prefix.

    Assumes each sentence is laid out as ``<query>,<generated text><2-char ending>``
    (the layout display_output renders as ``query + ","``); commas that belong
    to the query itself are skipped. Falls back to the whole sentence when no
    separating comma is present instead of raising IndexError.
    """
    n_query_commas = query.count(",")
    parts = formatted.split(",", n_query_commas + 1)
    if len(parts) <= n_query_commas + 1:
        return formatted[:-2]
    return parts[-1][:-2]


def display_output(status_text, query, current_sentences, previous_sentences):
    """Render the query followed by one bullet per suggestion, new words highlighted.

    Args:
        status_text: Streamlit placeholder that is overwritten on every call.
        query: the user's first line.
        current_sentences: sentences from the latest mutation round.
        previous_sentences: sentences from the round before, used for highlighting.
    """
    print_sentences = []
    for new, old in zip(current_sentences, previous_sentences):
        formatted = color_new_words(new, old)
        print_sentences.append("<li>" + _generated_part(formatted, query) + "</li>")
    status_text.markdown(
        query + ",<br>" + "".join(print_sentences), unsafe_allow_html=True
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()