# thbud67 / app.py
# napatswift
# Update number of results slider in app.py
# 64a9264
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import torch
import pandas as pd
import numpy as np
import re
import gradio as gr
# Hub repository whose tokenizer and token-embedding table are reused to embed text.
model_repo = "napatswift/mt5-fixpdftext"
tokenizer = AutoTokenizer.from_pretrained(model_repo)

# Only the model's input token-embedding layer is used by this app. Fetch it
# through the documented accessor instead of relying on the registration order
# of `model.modules()`, then drop our reference to the full model.
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)
embedding = model.get_input_embeddings()
del model
def get_embedding(text):
    """Embed ``text`` as the mean of its input-token embeddings.

    The text is tokenized with the module-level ``tokenizer`` (including any
    special tokens it adds by default). Each token id of the single sequence is
    looked up in the module-level ``embedding`` layer, and the resulting vectors
    are averaged over the token axis.

    Args:
        text: The string to embed.

    Returns:
        A tensor with one value per embedding dimension (presumably 1-D of the
        model's hidden width — assumes ``embedding`` is an ``nn.Embedding``).
    """
    input_ids = tokenizer(text, return_tensors='pt').input_ids[0]
    # Inference only: without no_grad every returned vector would keep an
    # autograd graph (the embedding weights require grad) alive for nothing.
    with torch.no_grad():
        return embedding(input_ids).mean(dim=0)
# Budget table searched by the app; the path is relative to the working directory.
BUDGET_CSV_PATH = '67_all_ministry.csv'
df = pd.read_csv(BUDGET_CSV_PATH)
def get_name(row):
    """Return the first non-empty string found in the row's ``name_*`` columns.

    Columns are scanned in the row's own order. Non-string values (NaN, None,
    ``pd.NA``, numbers) and empty strings are skipped.

    Args:
        row: A pandas Series representing one DataFrame row.

    Returns:
        The first matching string, or ``None`` if no ``name_*`` column holds one.
    """
    for col, val in row.items():
        # Check the type before truthiness: bool(pd.NA) raises TypeError, and NaN
        # is a truthy float, so isinstance must gate the emptiness test.
        # Non-string column labels cannot be ``name_*`` columns.
        if isinstance(col, str) and col.startswith('name_') and isinstance(val, str) and val:
            return val
    return None
# Distinct budget-item names, one per row via get_name. Rows without any name
# make get_name return None; drop those, because the tokenizer cannot embed None.
budget_items = [
    name
    for name in df.apply(get_name, axis=1).unique().tolist()
    if name is not None
]
# One embedding row per entry of budget_items (same order), computed without autograd.
with torch.no_grad():
    budget_item_embeddings = torch.stack([get_embedding(item) for item in budget_items])
def get_closest_budget_item(text, num_results=5):
    """Find the budget items whose embeddings are closest to ``text``.

    Args:
        text: Free-text query.
        num_results: How many matches to return. Coerced to ``int`` (the UI
            slider may deliver a float) and clamped at 0.

    Returns:
        A DataFrame with columns ``budget_item`` and ``score``, sorted by
        ascending ``score``. ``score`` is the Euclidean (L2) distance between
        the query embedding and the item embedding, so lower means closer.
    """
    k = max(0, int(num_results))
    with torch.no_grad():
        text_embedding = get_embedding(text)
        # L2 distance from the query to every precomputed budget-item embedding.
        scores = torch.norm(budget_item_embeddings - text_embedding, dim=1)
    top_idx = scores.argsort()[:k].tolist()
    score_values = scores.tolist()
    return pd.DataFrame({
        'budget_item': [budget_items[i] for i in top_idx],
        'score': [score_values[i] for i in top_idx],
    })
demo = gr.Interface(
    fn=get_closest_budget_item,
    inputs=[
        'textbox',
        # Slider values snap to minimum + k * step. With minimum=1 and step=5 only
        # 1, 6, ..., 46 were selectable, so the default (5) was off-grid and the
        # maximum (50) unreachable; step=1 makes every count in range selectable.
        gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of results"),
    ],
    outputs='dataframe',
)

if __name__ == "__main__":
    demo.launch()