|
import os
|
|
import dropbox
|
|
import streamlit as st
|
|
import torch
|
|
import pandas as pd
|
|
import time
|
|
from tqdm import tqdm
|
|
from simpletransformers.classification import ClassificationModel
|
|
|
|
|
|
# Heading rendered at the top of the Streamlit page.
_APP_TITLE = "Document Scoring App for Various Categories"
st.title(_APP_TITLE)
|
|
|
|
|
|
model_directories = {
|
|
'finance': 'models/finance_model',
|
|
'accounting': 'models/accounting_model',
|
|
'technology': 'models/technology_model',
|
|
'international': 'models/international_model',
|
|
'operations': 'models/operations_model',
|
|
'marketing': 'models/marketing_model',
|
|
'management': 'models/management_model',
|
|
'legal': 'models/legal_model'
|
|
}
|
|
|
|
|
|
dropbox_model_paths = {
|
|
'international': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/international_model',
|
|
'finance': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/finance_model',
|
|
'accounting': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/accounting_model',
|
|
'technology': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/technology_model',
|
|
'operations': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/operations_model',
|
|
'marketing': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/marketing_model',
|
|
'management': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/management_model',
|
|
'legal': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/legal_model'
|
|
}
|
|
|
|
|
|
dropbox_checkpoint_paths = {
|
|
'international': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/international_model/checkpoint-174-epoch-3',
|
|
'finance': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/finance_model/checkpoint-174-epoch-3',
|
|
'accounting': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/accounting_model/checkpoint-174-epoch-3',
|
|
'technology': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/technology_model/checkpoint-174-epoch-3',
|
|
'operations': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/operations_model/checkpoint-174-epoch-3',
|
|
'marketing': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/marketing_model/checkpoint-174-epoch-3',
|
|
'management': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/management_model/checkpoint-174-epoch-3',
|
|
'legal': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/legal_model/checkpoint-174-epoch-3'
|
|
}
|
|
|
|
|
|
use_cuda = torch.cuda.is_available()
|
|
|
|
|
|
def download_files_from_dropbox(dbx, dropbox_path, local_dir):
    """Recursively copy the contents of a Dropbox folder into ``local_dir``.

    Files are streamed to disk and sub-folders are mirrored by recursing into
    them. Dropbox API errors are not raised: they are shown in the app with
    ``st.error`` and reported through the return value.

    Args:
        dbx: Authenticated ``dropbox.Dropbox`` client.
        dropbox_path: Dropbox folder whose contents are copied.
        local_dir: Local destination directory; created if it does not exist.

    Returns:
        True if the whole tree was copied, False if any Dropbox API call
        failed. Callers that ignore the return value behave as before.
    """
    os.makedirs(local_dir, exist_ok=True)
    succeeded = True
    try:
        page = dbx.files_list_folder(dropbox_path)
        while True:
            for entry in page.entries:
                local_path = os.path.join(local_dir, entry.name)
                if isinstance(entry, dropbox.files.FileMetadata):
                    # Let the SDK stream the file to disk instead of opening
                    # the local file up front and buffering the whole
                    # response body (model weights can be large) in memory.
                    dbx.files_download_to_file(local_path, entry.path_lower)
                elif isinstance(entry, dropbox.files.FolderMetadata):
                    # The nested call reports its own errors; remember the
                    # failure but keep copying the remaining entries.
                    if not download_files_from_dropbox(dbx, entry.path_lower, local_path):
                        succeeded = False
            # files_list_folder is paginated: without following the cursor,
            # folders with more entries than fit in one page are truncated.
            if not page.has_more:
                break
            page = dbx.files_list_folder_continue(page.cursor)
    except dropbox.exceptions.ApiError as err:
        st.error(f"Dropbox API error: {err}")
        succeeded = False
    return succeeded
|
|
|
|
|
|
def download_model(category):
    """Fetch a category's model (and its checkpoint) from Dropbox if missing.

    Files go under ``model_directories[category]``. Nothing is downloaded when
    that directory already contains files, so calling this on every Streamlit
    rerun does not contact Dropbox again.

    NOTE(review): a download that fails part-way leaves a non-empty directory
    that will not be retried; delete it to force a fresh download.

    Args:
        category: Key shared by ``model_directories``, ``dropbox_model_paths``
            and (optionally) ``dropbox_checkpoint_paths``.
    """
    model_path = model_directories[category]
    if os.path.isdir(model_path) and os.listdir(model_path):
        return

    os.makedirs(model_path, exist_ok=True)
    dbx = dropbox.Dropbox(st.secrets["dropbox_api_key"])

    st.write(f"Downloading {category} model...")
    # Recursive: this also pulls every sub-folder of the model folder.
    download_files_from_dropbox(dbx, dropbox_model_paths[category], model_path)

    if category in dropbox_checkpoint_paths:
        remote_checkpoint = dropbox_checkpoint_paths[category]
        # Mirror the remote folder name locally instead of hard-coding it.
        checkpoint_name = remote_checkpoint.rstrip("/").rsplit("/", 1)[-1]
        checkpoint_path = os.path.join(model_path, checkpoint_name)
        # The checkpoint normally lives inside the model folder and was just
        # fetched by the recursive download above; only download it again
        # when that did not produce it.
        if not (os.path.isdir(checkpoint_path) and os.listdir(checkpoint_path)):
            os.makedirs(checkpoint_path, exist_ok=True)
            st.write(f"Downloading checkpoint for {category} model...")
            download_files_from_dropbox(dbx, remote_checkpoint, checkpoint_path)

    st.success(f"{category} model and checkpoints downloaded successfully.")
|
|
|
|
|
|
def load_model(category):
    """Ensure the model for *category* is available locally and load it.

    Args:
        category: Key of ``model_directories``.

    Returns:
        A simpletransformers ``ClassificationModel`` ("bert") loaded from the
        category's local directory, or ``None`` if downloading or loading
        failed (the failure is shown in the app via ``st.error``).
    """
    model_path = model_directories[category]
    try:
        # Download problems (network errors, a missing API secret, ...) are
        # reported the same way as load problems, so one bad category is
        # skipped instead of aborting the scoring of every other category.
        download_model(category)
        return ClassificationModel(
            "bert",
            model_path,
            use_cuda=use_cuda,
            args={"silent": True},
        )
    except Exception as e:
        st.error(f"Failed to load model for {category}: {e}")
        return None
|
|
|
|
|
|
def score_document(model, text_data):
    """Classify a document and return its label plus class-1 probability.

    Args:
        model: Object whose ``predict(texts)`` returns
            ``(predictions, raw_outputs)``; in this app a simpletransformers
            ``ClassificationModel``.
        text_data: A single document string, or a list of texts (only the
            first one's result is returned).

    Returns:
        Tuple ``(prediction, probability)``: the first text's predicted label
        and the softmax probability of class index 1 computed from its raw
        model outputs.
    """
    texts = [text_data] if isinstance(text_data, str) else text_data

    labels, logits = model.predict(texts)

    first_logits = torch.tensor(logits[0])
    class_probs = torch.nn.functional.softmax(first_logits, dim=0)
    positive_prob = class_probs[1].item()

    return labels[0], positive_prob
|
|
|
|
|
|
# Uploaded document; None until the user provides a file.
doc_file = st.file_uploader("Upload a document (.txt)", type=["txt"])

if doc_file is not None:
    # Start the clock here so the timing covers scoring only.
    start_time = time.time()

    raw_bytes = doc_file.read()
    try:
        text_data = raw_bytes.decode("utf-8")
    except UnicodeDecodeError:
        # A non-UTF-8 .txt file would otherwise crash the app with a
        # traceback; score it with undecodable bytes replaced and say so.
        text_data = raw_bytes.decode("utf-8", errors="replace")
        st.warning("The document is not valid UTF-8; undecodable characters were replaced before scoring.")

    # One dict per successfully scored category. The DataFrame is built once
    # at the end rather than grown with pd.concat on every iteration.
    rows = []

    progress_bar = st.progress(0)
    # A single placeholder that is overwritten each iteration, so the timing
    # line is updated in place instead of repeated once per category.
    timing_text = st.empty()
    total_categories = len(model_directories)

    for i, category in enumerate(tqdm(model_directories, desc="Scoring documents")):
        model = load_model(category)

        if model is not None:
            prediction, probability = score_document(model, text_data)
            rows.append({
                "Category": category,
                "Prediction": prediction,
                "Probability": probability,
            })
            # Drop this model before the next one is loaded so two models are
            # not referenced at the same time.
            del model

        progress_bar.progress((i + 1) / total_categories)

        elapsed_time = time.time() - start_time
        estimated_total_time = (elapsed_time / (i + 1)) * total_categories
        timing_text.write(f"Elapsed time: {elapsed_time:.2f}s, Estimated time remaining: {estimated_total_time - elapsed_time:.2f}s")

    result_df = pd.DataFrame(rows, columns=["Category", "Prediction", "Probability"])

    if result_df.empty:
        # Every model failed (errors were already shown above): don't offer a
        # header-only CSV or report success.
        st.error("No category could be scored, so there are no results to download.")
    else:
        csv = result_df.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Download results as CSV",
            data=csv,
            file_name="document_scoring_results.csv",
            mime="text/csv",
        )
        st.success("Document scoring complete!")

st.write("Note: Ensure the uploaded document is in .txt format containing text data.")
|
|
|