# Hugging Face Spaces status banner ("Spaces: Sleeping") captured when this
# file was scraped — not part of the program.
import fasttext
import gradio as gr
import joblib
import json as js
import omikuji
import os
import re
from collections import defaultdict
from huggingface_hub import snapshot_download
from typing import List, Tuple, Dict
from install_packages import download_model

# Fetch the fastText language-identification model (lid.176).
download_model('https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin', 'lid.176.bin')

# Download the per-language omikuji topic models from the Hugging Face Hub.
# Each snapshot is mirrored into a local directory named after the repo id.
for repo_id in ['kapllan/omikuji-bonsai-parliament-de-spacy', 'kapllan/omikuji-bonsai-parliament-fr-spacy',
                'kapllan/omikuji-bonsai-parliament-it-spacy']:
    os.makedirs(repo_id, exist_ok=True)
    model_dir = snapshot_download(repo_id=repo_id, local_dir=repo_id)

# Language-identification model used by predict_lang().
lang_model = fasttext.load_model('lid.176.bin')

# Mapping from omikuji subject ids (as strings) to human-readable labels.
with open('./id2label.json', 'r') as f:
    id2label = js.load(f)

# Main-topic -> list-of-sub-topics hierarchy used by restructure_topics().
with open('topics_hierarchy.json', 'r') as f:
    topics_hierarchy = js.load(f)
def map_language(language: str) -> str:
    """Translate a supported ISO 639-1 code into its English language name.

    Unknown codes are returned unchanged.
    """
    names = {'de': 'German', 'it': 'Italian', 'fr': 'French'}
    return names.get(language, language)
def find_model(language: str):
    """Load the (vectorizer, omikuji model) pair for a supported language.

    Returns (None, None) when no model exists for *language*.
    """
    if language not in ('de', 'fr', 'it'):
        return None, None
    base = f'./kapllan/omikuji-bonsai-parliament-{language}-spacy'
    vectorizer = joblib.load(f'{base}/vectorizer')
    model = omikuji.Model.load(f'{base}/omikuji-model')
    return vectorizer, model
def predict_lang(text: str) -> str:
    """Detect the language of *text* with the fastText lid.176 model.

    Returns an ISO 639-1 code such as 'de', 'fr' or 'it'.
    """
    # fastText cannot handle newlines in its input, so drop them first.
    cleaned = re.sub(r'\n', '', text)
    labels, _scores = lang_model.predict(cleaned, k=1)  # best single match
    # Labels come back as '__label__xx'; strip the prefix.
    return re.sub(r'__label__', '', labels[0])
def predict_topic(text: str) -> Tuple[List[Tuple[str, float]], str]:
    """Predict topic labels for *text* in its detected language.

    Returns a list of (label, score) pairs — empty when no model exists for
    the detected language or the text vectorizes to an all-zero vector —
    together with the detected language name (e.g. 'German').
    """
    results = []
    language = predict_lang(text)
    vectorizer, model = find_model(language)
    language = map_language(language)
    if vectorizer is not None:
        vector = vectorizer.transform([text])
        for row in vector:
            if row.nnz == 0:  # all-zero vector -> nothing to predict
                continue
            # Sparse row -> (feature index, value) pairs for omikuji.
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            for subj_id, score in model.predict(feature_values, top_k=1000):
                results.append((id2label[str(subj_id)], score))
    return results, language
def get_row_color(type: str):
    """Return the inline CSS background for a table row by its type.

    'Main Topic' rows get dark grey, 'Sub Topic' rows light grey;
    anything else gets None.
    """
    lowered = type.lower()
    if 'main' in lowered:
        return 'background-color: darkgrey;'
    if 'sub' in lowered:
        return 'background-color: lightgrey;'
    return None
def generate_html_table(topics: List[Tuple[str, str, float]]):
    """Render (type, topic, score) triples as an HTML table.

    Main-topic rows are bolded; row background comes from get_row_color().
    """
    parts = ['<table style="width:100%; border: 1px solid black; border-collapse: collapse;">',
             '<tr><th>Type</th><th>Topic</th><th>Score</th></tr>']
    for type, topic, score in topics:
        color = get_row_color(type)
        is_main = 'main' in type.lower()
        if is_main:
            # Bold every cell of a main-topic row.
            topic = f"<strong>{topic}</strong>"
            type = f"<strong>{type}</strong>"
            score = f"<strong>{score}</strong>"
        parts.append(f'<tr style="{color}"><td>{type}</td><td>{topic}</td><td>{score}</td></tr>')
    parts.append('</table>')
    return ''.join(parts)
def restructure_topics(topics: List[Tuple[str, float]]) -> List[Tuple[str, str, float]]:
    """Group flat (topic, score) predictions into main/sub topic rows.

    A predicted main topic is kept only when at least one of its sub-topics
    (per the global ``topics_hierarchy``) was also predicted; each main-topic
    row is immediately followed by its de-duplicated sub-topic rows.
    """
    topics = [(str(name).lower(), score) for name, score in topics]
    # Score lookup built once (first occurrence wins, matching the previous
    # first-match list scan) instead of an O(n^2) scan per topic.
    scores: Dict[str, float] = {}
    for name, score in topics:
        scores.setdefault(name, score)
    # Predicted main topics, in prediction order.
    predicted_mains: Dict[str, List[str]] = {name: [] for name, _ in topics
                                             if name in topics_hierarchy}
    # Attach each predicted sub-topic to every predicted main topic listing it.
    for name, _ in topics:
        for main_topic, sub_topics in topics_hierarchy.items():
            if main_topic in predicted_mains and name != main_topic and name in sub_topics:
                predicted_mains[main_topic].append(name)
    restructured: List[Tuple[str, str, float]] = []
    for main_topic, sub_names in predicted_mains.items():
        if not sub_names:
            continue  # main topics with no predicted sub-topics are dropped
        restructured.append(('Main Topic', main_topic, scores[main_topic]))
        sub_entries: List[Tuple[str, str, float]] = []
        for sub_name in sub_names:
            entry = ('Sub Topic', sub_name, scores[sub_name])
            if entry not in sub_entries:  # de-duplicate repeated sub-topics
                sub_entries.append(entry)
        restructured.extend(sub_entries)
    return restructured
def topic_modeling(text: str, threshold: float) -> Tuple[str, str]:
    """Gradio callback: predict topics for *text* and render them as HTML.

    Topics with a score below *threshold* are dropped; languages without a
    model yield an empty table. Returns (html_table, detected_language).
    """
    sorted_topics, language = predict_topic(text)
    if len(sorted_topics) > 0 and language in ['German', 'French', 'Italian']:
        sorted_topics = [t for t in sorted_topics if t[1] >= threshold]
    else:
        # Unsupported language or no predictions -> empty table.
        sorted_topics = []
    sorted_topics = restructure_topics(sorted_topics)
    return generate_html_table(sorted_topics), language
# Build the Gradio UI: document textbox + threshold slider on the left,
# rendered HTML topic table on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Topic Modeling")
    gr.Markdown("Enter a document and get each topic along with its score.")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(lines=10, placeholder="Enter a document")
            submit_button = gr.Button("Submit")
            threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Score Threshold",
                                         value=0.0)
            language_text = gr.Textbox(lines=1, placeholder="Detected language will be shown here...",
                                       interactive=False, label="Detected Language")
        with gr.Column():
            output_data = gr.HTML()
    submit_button.click(topic_modeling, inputs=[input_text, threshold_slider],
                        outputs=[output_data, language_text])

# Launch the app
iface.launch(share=True)