# coding=utf-8
# Copyright 2023 The GlotLID Authors.
# Lint as: python3
# This space is built based on AMR-KELEG/ALDi space.
# GlotLID Space

import base64
import json
import os
import re
import string

import altair as alt
import fasttext
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from altair import X, Y, Scale
from GlotScript import get_script_predictor
from huggingface_hub import hf_hub_download

import constants

# Script code used when no writing system can be determined.
_UNKNOWN_SCRIPT = 'Zyyy'

# Scripts whose presence also counts as the combined Japanese script 'Jpan'.
_JAPANESE_SCRIPTS = frozenset(['Kana', 'Hrkt', 'Hani', 'Hira'])

# Characters that are ubiquitous across languages; each is replaced by a space
# before prediction. Built once instead of on every call.
_REPLACEMENT_MAP = {ord(c): " " for c in ':#{|}' + string.digits}
_WHITESPACE_RE = re.compile(r'\s+')

# Raw string so that the BibTeX accents (\c, \") keep their backslashes.
# NOTE(review): the line layout inside the entry is reconstructed; confirm.
_CITATION = r"""
@inproceedings{
  kargaran2023glotlid,
  title={GlotLID: Language Identification for Low-Resource Languages},
  author={Kargaran, Amir Hossein and Imani, Ayyoob and Yvon, Fran{\c{c}}ois and Sch{\"u}tze, Hinrich},
  booktitle={The 2023 Conference on Empirical Methods in Natural Language Processing},
  year={2023},
  url={https://openreview.net/forum?id=dl4e3EBz5j}
}"""


@st.cache_resource
def load_sp():
    """Create the GlotScript script predictor (cached as a resource)."""
    sp = get_script_predictor()
    return sp


sp = load_sp()


def get_script(text):
    """Get the writing systems of given text.

    Args:
        text: The text to be analysed.

    Returns:
        A tuple ``(main_script, all_scripts)``: the main script code
        ('Zyyy' when the predictor reports none) and a list of all script
        codes found (``['Zyyy']`` when the predictor reports no details).
        'Jpan' is added whenever any of Kana/Hrkt/Hani/Hira is present.
        The order of the list is unspecified.
    """
    res = sp(text)
    main_script = res[0] if res[0] else _UNKNOWN_SCRIPT

    details = res[2]['details']
    # Use a set so the fallback is a single code, not the characters of the
    # string 'Zyyy', and so nothing is appended while iterating.
    scripts = set(details) if details else {_UNKNOWN_SCRIPT}
    if scripts & _JAPANESE_SCRIPTS:
        scripts.add('Jpan')

    return main_script, list(scripts)


def preprocess_text(text):
    """Apply preprocessing to the given text.

    Newlines become spaces, the characters ``:#{|}`` and ASCII digits are
    replaced by spaces, runs of whitespace collapse to one space, and the
    result is stripped.

    Args:
        text: The text to be preprocessed.

    Returns:
        The preprocessed text.
    """
    # remove \n
    text = text.replace('\n', ' ')

    # get rid of characters that are ubiquitous
    text = text.translate(_REPLACEMENT_MAP)

    # make multiple space one space, then strip the text
    return _WHITESPACE_RE.sub(' ', text).strip()


@st.cache_data
def language_names(json_path):
    """Load the JSON file at ``json_path`` and return its parsed content."""
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(json_path, 'r', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return data


label2name = language_names("assets/language_names.json")


def get_name(label):
    """Get the name of language from label.

    Args:
        label: A model label; the text before the first underscore is used
            as the lookup key into ``label2name``.

    Returns:
        The mapped name, or the lookup key itself when it is not in the
        table (instead of raising ``KeyError``).
    """
    iso_3 = label.split('_')[0]
    return label2name.get(iso_3, iso_3)


@st.cache_data
def render_svg(svg):
    """Renders the given svg string."""
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # NOTE(review): in the reviewed source this literal was blank and `b64`
    # was unused, so nothing was rendered. The markup below embeds the SVG as
    # a data URI; confirm against the intended layout.
    html = rf'<p align="center"><img src="data:image/svg+xml;base64,{b64}"/></p>'
    c = st.container()
    c.write(html, unsafe_allow_html=True)


@st.cache_data
def render_metadata():
    """Renders the metadata."""
    # NOTE(review): this literal is empty in the reviewed source, so nothing
    # visible is written; the intended markup may be missing. Left unchanged.
    html = r""""""
    c = st.container()
    c.write(html, unsafe_allow_html=True)


@st.cache_data
def citation():
    """Renders the BibTeX citation as a code block."""
    st.code(_CITATION, language="python", line_numbers=False)


@st.cache_data
def convert_df(df):
    """Serialise ``df`` to UTF-8 encoded CSV bytes without the index."""
    # IMPORTANT: Cache the conversion to prevent computation on every rerun
    return df.to_csv(index=None).encode("utf-8")


def _load_fasttext_model(repo_id, filename):
    """Download ``filename`` from the Hub repo ``repo_id`` and load it."""
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return fasttext.load_model(model_path)


@st.cache_resource
def load_GlotLID_v1(model_name, file_name):
    """Load a GlotLID v1 fastText model (cached as a resource)."""
    return _load_fasttext_model(model_name, file_name)


@st.cache_resource
def load_GlotLID_v2(model_name, file_name):
    """Load a GlotLID v2 fastText model (cached as a resource)."""
    return _load_fasttext_model(model_name, file_name)


@st.cache_resource
def load_OpenLID():
    """Load the OpenLID fastText model (cached as a resource)."""
    return _load_fasttext_model('laurievb/OpenLID', 'model.bin')


@st.cache_resource
def load_NLLB():
    """Load the NLLB fastText language-identification model (cached)."""
    return _load_fasttext_model(
        'facebook/fasttext-language-identification', 'model.bin'
    )


model_1 = load_GlotLID_v1(constants.MODEL_NAME, "model_v1.bin")
model_2 = load_GlotLID_v2(constants.MODEL_NAME, "model_v2.bin")
model_3 = load_OpenLID()
model_4 = load_NLLB()


# @st.cache_resource
def plot(label, prob):
    """Draw a single horizontal confidence bar for one prediction.

    Args:
        label: The predicted label; shown in the title with its name.
        prob: The confidence, drawn on an x axis fixed to [0, 1].
    """
    ORANGE_COLOR = "#FF8000"
    BLACK_COLOR = "#31333F"

    fig, ax = plt.subplots(figsize=(8, 1))
    fig.patch.set_facecolor("none")
    ax.set_facecolor("none")

    ax.spines["left"].set_color(BLACK_COLOR)
    ax.spines["bottom"].set_color(BLACK_COLOR)
    ax.tick_params(axis="x", colors=BLACK_COLOR)

    ax.spines[["right", "top"]].set_visible(False)

    ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
    ax.set_xlim(0, 1)
    ax.set_ylim(-1, 1)
    ax.set_title(f"Label: {label}, Language: {get_name(label)}", color=BLACK_COLOR)
    ax.get_yaxis().set_visible(False)
    ax.set_xlabel("Confidence", color=BLACK_COLOR)
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close this one so
    # figures do not accumulate across reruns.
    plt.close(fig)


def compute(sentences, version = 'v2'):
    """Computes the language probabilities and labels for the given sentences.

    Each sentence is preprocessed, classified with the selected model, and,
    for every version except 'v1', checked against the scripts detected in
    the text: when the predicted label's script is not among them, the label
    becomes ``und_<main script>`` with probability 0. Labels whose language
    part is 'zxx' skip that check.

    Args:
        sentences: A list of sentences.
        version: One of 'nllb-218', 'openlid-201', 'v2'; any other value
            selects the GlotLID v1 model.

    Returns:
        A tuple ``(probs, labels)`` of two lists aligned with ``sentences``.
    """
    progress_text = "Computing Language..."

    models_by_version = {
        'nllb-218': model_4,
        'openlid-201': model_3,
        'v2': model_2,
    }
    model_choice = models_by_version.get(version, model_1)
    script_control = version in ('v2', 'openlid-201', 'nllb-218')

    my_bar = st.progress(0, text=progress_text)

    probs = []
    labels = []

    sentences = [preprocess_text(sent) for sent in sentences]
    total = len(sentences)

    for index, sent in enumerate(sentences):
        # Nothing may be left after preprocessing; skip the model then, and
        # also guard against a prediction that comes back without labels.
        output = model_choice.predict(sent) if sent else ((), ())

        if not output[0]:
            output_label = f"und_{_UNKNOWN_SCRIPT}"
            output_prob = 0.0
        else:
            output_label = output[0][0].split('__')[-1]
            # Clamp to [0, 1] before the value is plotted or exported.
            output_prob = float(max(min(output[1][0], 1), 0))
            # partition() tolerates labels without a script suffix.
            language, _, label_script = output_label.partition('_')
            label_script = label_script.split('_')[0]

            # script control
            if script_control and language != 'zxx':
                main_script, all_scripts = get_script(sent)
                if label_script not in all_scripts:
                    output_label = f"und_{main_script}"
                    output_prob = 0.0

        labels.append(output_label)
        probs.append(output_prob)

        my_bar.progress(
            min((index + 1) / total, 1),
            text=progress_text,
        )
    my_bar.empty()
    return probs, labels


# st.markdown("[![Duplicate Space](https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14)](https://huggingface.co/spaces/cis-lmu/glotlid-space?duplicate=true)")

# render_svg(open("assets/glotlid_logo.svg").read())
render_metadata()

st.markdown("**GlotLID** is an open-source language identification model with support for more than **1600 languages**.")

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:

    # choice = st.radio(
    #     "Set granularity level",
    #     ["default", "merge", "individual"],
    #     captions=["enable both macrolanguage and its varieties (default)", "merge macrolanguage and its varieties into one label", "remove macrolanguages - only shows individual langauges"],
    # )

    version = st.radio(
        "Choose model",
        ["nllb-218", "openlid-201", "v1", "v2"],
        captions=["NLLB", "OpenLID", "GlotLID version 1", "GlotLID version 2 (more data and languages)"],
        index = 1,
        key = 'version_tab1',
        horizontal = True
    )

    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )

    # TODO: Check if this is needed!
    clicked = st.button("Submit")

    if sent:
        probs, labels = compute([sent], version=version)
        prob = probs[0]
        label = labels[0]

        print(f"{sent}, {label}: {prob}")

        # Append mode creates the file when it does not exist, so no separate
        # existence check is needed.
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(f"{sent}, {label}: {prob}\n")

        # plot
        plot(label, prob)

with tab2:

    version = st.radio(
        "Choose model",
        ["v1", "v2"],
        captions=["GlotLID version 1", "GlotLID version 2 (more data and languages)"],
        index = 1,
        key = 'version_tab2',
        horizontal = True
    )

    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        try:
            # The separator is not expected to occur in text, so each line is
            # read as one cell. dtype/keep_default_na keep lines such as
            # "NA", "null" or "123" as strings instead of NaN or numbers.
            df = pd.read_csv(
                file,
                sep="¦\t¦",
                header=None,
                engine='python',
                dtype=str,
                keep_default_na=False,
            )
        except pd.errors.EmptyDataError:
            df = None
            st.warning("The uploaded file is empty.")

        if df is not None:
            df.columns = ["Sentence"]
            df.reset_index(drop=True, inplace=True)

            # TODO: Run the model
            df['Prob'], df["Label"] = compute(df["Sentence"].tolist(), version= version)
            df['Language'] = df["Label"].apply(get_name)

            # A horizontal rule
            st.markdown("""---""")

            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("Prob", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)

            col1, col2 = st.columns([4, 1])

            with col1:
                # Display the output
                st.table(
                    df,
                )

            with col2:
                # Add a download button
                csv = convert_df(df)
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=csv,
                    file_name="GlotLID.csv",
                    mime="text/csv",
                )

# citation()