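"""Streamlit app for scoring an uploaded .txt document across eight business categories.

For each category (finance, accounting, technology, international, operations,
marketing, management, legal) the app downloads a fine-tuned BERT classifier from
Dropbox, scores the document, and reports the predicted label and the probability
of class '1', with the results downloadable as a CSV file.
"""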
import os
import dropbox
import streamlit as st
import torch
import pandas as pd
import time
from tqdm import tqdm
from simpletransformers.classification import ClassificationModel
# Set up Streamlit app
st.title("Document Scoring App for Various Categories")
# Local directories where each category's model will be stored
model_directories = {
'finance': 'models/finance_model',
'accounting': 'models/accounting_model',
'technology': 'models/technology_model',
'international': 'models/international_model',
'operations': 'models/operations_model',
'marketing': 'models/marketing_model',
'management': 'models/management_model',
'legal': 'models/legal_model'
}
# Dropbox paths to main model directories
dropbox_model_paths = {
'international': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/international_model',
'finance': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/finance_model',
'accounting': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/accounting_model',
'technology': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/technology_model',
'operations': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/operations_model',
'marketing': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/marketing_model',
'management': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/management_model',
'legal': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/legal_model'
}
# Dropbox paths to model checkpoints (all 8 models)
dropbox_checkpoint_paths = {
'international': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/international_model/checkpoint-174-epoch-3',
'finance': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/finance_model/checkpoint-174-epoch-3',
'accounting': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/accounting_model/checkpoint-174-epoch-3',
'technology': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/technology_model/checkpoint-174-epoch-3',
'operations': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/operations_model/checkpoint-174-epoch-3',
'marketing': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/marketing_model/checkpoint-174-epoch-3',
'management': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/management_model/checkpoint-174-epoch-3',
'legal': '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK/legal_model/checkpoint-174-epoch-3'
}
# Check if CUDA is available
use_cuda = torch.cuda.is_available()
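# simpletransformers runs inference on the GPU when CUDA is available, otherwise on the CPU.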
# Function to download files from Dropbox recursively, including checkpoint directories
def download_files_from_dropbox(dbx, dropbox_path, local_dir):
# List all files and subfolders in the Dropbox path
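    # Note: files_list_folder returns a single page of entries; very large folders
    # would need files_list_folder_continue to retrieve the remaining pages.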
try:
for entry in dbx.files_list_folder(dropbox_path).entries:
local_path = os.path.join(local_dir, entry.name)
if isinstance(entry, dropbox.files.FileMetadata):
# It's a file, download it
with open(local_path, "wb") as f:
metadata, res = dbx.files_download(path=entry.path_lower)
f.write(res.content)
elif isinstance(entry, dropbox.files.FolderMetadata):
# It's a folder, create it locally and download its contents
os.makedirs(local_path, exist_ok=True)
download_files_from_dropbox(dbx, entry.path_lower, local_path)
except dropbox.exceptions.ApiError as err:
st.error(f"Dropbox API error: {err}")
# Function to download models and checkpoints from Dropbox
def download_model(category):
model_path = model_directories[category]
if not os.path.exists(model_path):
os.makedirs(model_path, exist_ok=True)
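    # The Dropbox access token is read from Streamlit secrets
    # (e.g. .streamlit/secrets.toml or the hosting platform's secrets settings).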
dbx = dropbox.Dropbox(st.secrets["dropbox_api_key"])
# Download the main model files
st.write(f"Downloading {category} model...")
download_files_from_dropbox(dbx, dropbox_model_paths[category], model_path)
# Download the checkpoint files if available
if category in dropbox_checkpoint_paths:
checkpoint_path = os.path.join(model_path, "checkpoint-174-epoch-3")
os.makedirs(checkpoint_path, exist_ok=True)
st.write(f"Downloading checkpoint for {category} model...")
download_files_from_dropbox(dbx, dropbox_checkpoint_paths[category], checkpoint_path)
st.success(f"{category} model and checkpoints downloaded successfully.")
# Function to load a model, skipping if it can't be loaded
def load_model(category):
model_path = model_directories[category]
# Ensure the model is downloaded
download_model(category)
try:
model = ClassificationModel(
"bert",
model_path,
use_cuda=use_cuda,
args={"silent": True} # Suppress output
)
return model
except Exception as e:
st.error(f"Failed to load model for {category}: {e}")
return None
# Function to score a document and return the prediction and probability for class '1'
def score_document(model, text_data):
if isinstance(text_data, str):
text_data = [text_data]
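    # predict() returns the predicted labels and the raw per-class logits for each input.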
predictions, raw_outputs = model.predict(text_data)
# Get the probability associated with class '1'
probability_class_1 = torch.nn.functional.softmax(torch.tensor(raw_outputs[0]), dim=0)[1].item()
return predictions[0], probability_class_1
# Let the user upload a file
doc_file = st.file_uploader("Upload a document (.txt)", type=["txt"])
# Track the start time
start_time = time.time()
# Make predictions when a file is uploaded
if doc_file is not None:
# Read the content of the uploaded .txt file
text_data = doc_file.read().decode("utf-8")
# Initialize an empty DataFrame for results
result_df = pd.DataFrame(columns=["Category", "Prediction", "Probability"])
# Progress bar
progress_bar = st.progress(0)
total_categories = len(model_directories)
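    # tqdm reports progress in the server console; the st.progress bar above covers the UI.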
for i, category in enumerate(tqdm(model_directories.keys(), desc="Scoring documents")):
# Load the pre-trained model for the current category
model = load_model(category)
# Skip the category if model loading fails
if model is not None:
# Score the document
prediction, probability = score_document(model, text_data)
# Create a DataFrame for the current result
new_row = pd.DataFrame({
"Category": [category],
"Prediction": [prediction],
"Probability": [probability]
})
# Use pd.concat to append the new row to the DataFrame
result_df = pd.concat([result_df, new_row], ignore_index=True)
# Update the progress bar
progress_bar.progress((i + 1) / total_categories)
# Estimate remaining time
elapsed_time = time.time() - start_time
estimated_total_time = (elapsed_time / (i + 1)) * total_categories
st.write(f"Elapsed time: {elapsed_time:.2f}s, Estimated time remaining: {estimated_total_time - elapsed_time:.2f}s")
# Save results to CSV
csv = result_df.to_csv(index=False).encode('utf-8')
st.download_button(
label="Download results as CSV",
data=csv,
file_name="document_scoring_results.csv",
mime="text/csv",
)
# Display completion message
st.success("Document scoring complete!")
st.write("Note: Ensure the uploaded document is in .txt format containing text data.")