"""Streamlit app that scores an uploaded text document with one model per category.

Each category has its own fine-tuned BERT classifier stored in Dropbox. Models
are downloaded on first use into ``models/<category>_model`` and then used to
predict a label and the class-1 probability for the uploaded document.
"""

import os
import shutil
import time

import dropbox
import pandas as pd
import streamlit as st
import torch
from simpletransformers.classification import ClassificationModel
from tqdm import tqdm

# Set up Streamlit app
st.title("Document Scoring App for Various Categories")

# Categories in the order they are scored and reported.
CATEGORIES = (
    'finance',
    'accounting',
    'technology',
    'international',
    'operations',
    'marketing',
    'management',
    'legal',
)

# Dropbox folder that holds one "<category>_model" directory per category.
DROPBOX_BASE_PATH = '/3) Conferences and Publications/2) Current_Projects/VIVEK/COMPLETE_REWORK'

# Name of the checkpoint sub-directory stored inside every model directory.
CHECKPOINT_NAME = 'checkpoint-174-epoch-3'

# Column order of the results table / CSV export.
RESULT_COLUMNS = ["Category", "Prediction", "Probability"]

# Local model directories
model_directories = {
    category: f'models/{category}_model' for category in CATEGORIES
}

# Dropbox paths to main model directories
dropbox_model_paths = {
    category: f'{DROPBOX_BASE_PATH}/{category}_model' for category in CATEGORIES
}

# Dropbox paths to model checkpoints (all 8 models)
dropbox_checkpoint_paths = {
    category: f'{DROPBOX_BASE_PATH}/{category}_model/{CHECKPOINT_NAME}'
    for category in CATEGORIES
}

# Check if CUDA is available
use_cuda = torch.cuda.is_available()


def download_files_from_dropbox(dbx, dropbox_path, local_dir):
    """Recursively download a Dropbox folder into ``local_dir``.

    Args:
        dbx: An authenticated ``dropbox.Dropbox`` client.
        dropbox_path: Dropbox path of the folder to mirror.
        local_dir: Local directory that receives the files; created if missing.

    Returns:
        True when every entry was downloaded, False when a Dropbox API error
        occurred anywhere in the tree (the error is also shown in the app).
    """
    os.makedirs(local_dir, exist_ok=True)
    succeeded = True
    try:
        listing = dbx.files_list_folder(dropbox_path)
        while True:
            for entry in listing.entries:
                local_path = os.path.join(local_dir, entry.name)
                if isinstance(entry, dropbox.files.FileMetadata):
                    # Download first so a failed request does not leave an
                    # empty file behind.
                    _, response = dbx.files_download(path=entry.path_lower)
                    with open(local_path, "wb") as f:
                        f.write(response.content)
                elif isinstance(entry, dropbox.files.FolderMetadata):
                    # A failure in a sub-folder is reported by the nested call;
                    # keep going with the remaining entries.
                    if not download_files_from_dropbox(dbx, entry.path_lower, local_path):
                        succeeded = False
            # files_list_folder returns results in pages; follow the cursor
            # until Dropbox reports that nothing is left.
            if not listing.has_more:
                break
            listing = dbx.files_list_folder_continue(listing.cursor)
    except dropbox.exceptions.ApiError as err:
        st.error(f"Dropbox API error: {err}")
        return False
    return succeeded


def download_model(category):
    """Download the model (and its checkpoint) for ``category`` if not present locally.

    Args:
        category: A key of ``model_directories``.

    Returns:
        True when the local model directory is available, False when the
        download of the main model files failed.
    """
    model_path = model_directories[category]
    if os.path.exists(model_path):
        return True

    dbx = dropbox.Dropbox(st.secrets["dropbox_api_key"])

    # Download the main model files
    st.write(f"Downloading {category} model...")
    if not download_files_from_dropbox(dbx, dropbox_model_paths[category], model_path):
        # Remove the partial download; otherwise the existence check above
        # would skip the download forever and leave a broken model on disk.
        shutil.rmtree(model_path, ignore_errors=True)
        return False

    # Download the checkpoint files if available. The recursive download above
    # normally fetched them already, so only go back to Dropbox when missing.
    checkpoint_ok = True
    if category in dropbox_checkpoint_paths:
        checkpoint_path = os.path.join(model_path, CHECKPOINT_NAME)
        if not (os.path.isdir(checkpoint_path) and os.listdir(checkpoint_path)):
            st.write(f"Downloading checkpoint for {category} model...")
            checkpoint_ok = download_files_from_dropbox(
                dbx, dropbox_checkpoint_paths[category], checkpoint_path
            )

    if checkpoint_ok:
        st.success(f"{category} model and checkpoints downloaded successfully.")
    else:
        # The main model is intact, so a checkpoint failure is not fatal.
        st.warning(f"{category} model downloaded, but its checkpoint could not be fetched.")
    return True


def load_model(category):
    """Return the classifier for ``category``, or None if it cannot be loaded.

    The model is downloaded first when it is not yet on disk. Any failure is
    reported in the app instead of being raised, so callers can skip the category.
    """
    model_path = model_directories[category]

    # Ensure the model is downloaded
    if not download_model(category):
        st.error(f"Failed to load model for {category}: download failed")
        return None

    # NOTE(review): models are reloaded on every Streamlit rerun; consider
    # caching them if the installed Streamlit version supports it.
    try:
        return ClassificationModel(
            "bert",
            model_path,
            use_cuda=use_cuda,
            args={"silent": True},  # Suppress output
        )
    except Exception as e:
        # Boundary handler: a single broken model must not abort the whole run.
        st.error(f"Failed to load model for {category}: {e}")
        return None


def score_document(model, text_data):
    """Score a document and return ``(prediction, probability_of_class_1)``.

    Args:
        model: Object exposing ``predict(list_of_texts)`` that returns
            ``(predictions, raw_outputs)``.
        text_data: A single string, or a list of strings of which only the
            first is scored.

    Returns:
        The predicted label for the first text and the softmax probability
        of class index 1 computed from its raw outputs.
    """
    if isinstance(text_data, str):
        text_data = [text_data]
    predictions, raw_outputs = model.predict(text_data)
    # Get the probability associated with class '1'
    probabilities = torch.nn.functional.softmax(torch.tensor(raw_outputs[0]), dim=0)
    return predictions[0], probabilities[1].item()


def read_uploaded_text(uploaded_file):
    """Decode the uploaded file as UTF-8 text; return None (with a message) if unusable."""
    try:
        # "utf-8-sig" also accepts plain UTF-8 and strips a leading BOM so it
        # does not end up in the text passed to the models.
        text = uploaded_file.read().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text. Please save it as UTF-8 and upload it again.")
        return None
    if not text.strip():
        st.warning("The uploaded document is empty.")
        return None
    return text


def score_all_categories(text_data):
    """Score ``text_data`` with every category model and return the results table."""
    rows = []
    progress_bar = st.progress(0)
    total_categories = len(model_directories)
    # Track the start time
    start_time = time.time()

    for i, category in enumerate(tqdm(model_directories, desc="Scoring documents")):
        # Load the pre-trained model for the current category
        model = load_model(category)

        # Skip the category if model loading fails
        if model is not None:
            prediction, probability = score_document(model, text_data)
            rows.append({
                "Category": category,
                "Prediction": prediction,
                "Probability": probability,
            })

        # Update the progress bar
        progress_bar.progress((i + 1) / total_categories)

        # Estimate remaining time
        elapsed_time = time.time() - start_time
        estimated_total_time = (elapsed_time / (i + 1)) * total_categories
        st.write(f"Elapsed time: {elapsed_time:.2f}s, Estimated time remaining: {estimated_total_time - elapsed_time:.2f}s")

    # Build the table once; the explicit columns keep the header even when
    # no model could be loaded.
    return pd.DataFrame(rows, columns=RESULT_COLUMNS)


# Let the user upload a file
doc_file = st.file_uploader("Upload a document (.txt)", type=["txt"])

# Make predictions when a file is uploaded
if doc_file is not None:
    text_data = read_uploaded_text(doc_file)

    if text_data is not None:
        result_df = score_all_categories(text_data)

        # Save results to CSV
        csv = result_df.to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Download results as CSV",
            data=csv,
            file_name="document_scoring_results.csv",
            mime="text/csv",
        )

        # Display completion message
        st.success("Document scoring complete!")

st.write("Note: Ensure the uploaded document is in .txt format containing text data.")