"""Gradio demo front-end for the TabPFN API client.

Provides a login/registration tab, a prediction tab (CSV upload, cross-validated
performance estimation, and prediction for classification or regression) and an
account-management tab (list/delete datasets, download data, delete account, logout).
"""

import io
import logging
import sys
import tempfile
import textwrap
import time
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from password_strength import PasswordPolicy
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import cross_val_score, cross_validate, train_test_split
from tabpfn_client import TabPFNClassifier, TabPFNRegressor, init
from tabpfn_client.client import ServiceClient
from tabpfn_client.tabpfn_common_utils import utils as common_utils

logger = logging.getLogger(__name__)

# One API client for the whole app: login/register authorize it, and the
# account-management handlers below issue their service calls through it.
client = ServiceClient()
# Token of the current session; None means "not logged in". The estimate and
# predict handlers refuse to run while this is None.
access_token = None

TERMS_OF_SERVICE_URL = "https://www.priorlabs.ai/terms-eu-en"
# Only used for the initial tab visibility when the UI is built.
is_logged_in = False

_INVALID_TASK_MESSAGE = "Invalid task type. Please choose either 'classification' or 'regression'."
_TOO_FEW_CLASSES_MESSAGE = (
    "The dataset must have at least two different categories in the target column for classification."
)


def _login_view():
    """Visibility updates for (login tab, predict tab, account tab) when logged out."""
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)


def _app_view():
    """Visibility updates for (login tab, predict tab, account tab) when logged in."""
    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)


def _report_error(message):
    """Log *message* and show it to the user.

    ``gr.Error`` is an exception class: constructing it without ``raise`` shows
    nothing. ``gr.Warning`` is displayed immediately and lets the calling handler
    still return its fallback outputs.
    """
    logger.error(message)
    gr.Warning(message)


class PromptAgent:
    """Small text helpers used by the registration flow."""

    @staticmethod
    def indent(text: str):
        """Return *text* with every line indented by two spaces."""
        indent_factor = 2
        indent_str = " " * indent_factor
        return textwrap.indent(text, indent_str)

    @staticmethod
    def password_req_to_policy(password_req: list[str]):
        """Build a ``PasswordPolicy`` from rule strings shaped like ``"Name(number)"``.

        Each rule is split at ``"("``; the lower-cased name becomes the keyword passed
        to ``PasswordPolicy.from_names`` and the integer before the trailing ``")"``
        becomes its value.
        """
        requirements = {}
        for req in password_req:
            word_part, number_part = req.split("(")
            requirements[word_part.lower()] = int(number_part[:-1])
        return PasswordPolicy.from_names(**requirements)


def login(email, password):
    """Log in and authorize ``client``; returns visibility updates for the three tabs."""
    global access_token
    access_token, message = client.login(email, password)
    if not access_token:
        gr.Warning(f"Login failed: {message}")
        return _login_view()
    client.authorize(access_token)
    gr.Info("Login successful!")
    return _app_view()


def register(email, password, password_confirm, first_name, last_name, organization, role, use_case,
             contact_via_email, tos_agreed):
    """Validate the registration form, create the account and authorize ``client``.

    Checks run in order: ToS agreement, e-mail validity (server side), the server's
    password policy, then password confirmation. Returns visibility updates for the
    (login, predict, account) tabs.
    """
    global access_token
    if not tos_agreed:
        gr.Warning("Registration failed: You must agree to the Terms of Service.")
        return _login_view()

    is_valid, message = client.validate_email(email)
    if not is_valid:
        gr.Warning(f"Registration failed: Invalid email - {message}")
        return _login_view()

    password_policy = PromptAgent.password_req_to_policy(client.get_password_policy())
    # ``test`` returns the list of failed requirements; empty means the password passes.
    if password_policy.test(password):
        gr.Warning("Registration failed: Password requirements not satisfied. Please check the password policy.")
        return _login_view()

    if password != password_confirm:
        gr.Warning("Registration failed: Passwords do not match.")
        return _login_view()

    validation_link = "tabpfn-2023"
    additional_info = {
        "first_name": first_name,
        "last_name": last_name,
        "company": organization,
        "role": role,
        "use_case": use_case,
        "contact_via_email": contact_via_email,
    }
    is_created, message, access_token = client.register(
        email, password, password_confirm, validation_link, additional_info
    )
    if not is_created:
        gr.Warning(f"Registration failed: {message}")
        return _login_view()
    client.authorize(access_token)
    gr.Info("Registration successful! Please check your email for a verification link.")
    return _app_view()


def get_password_policy():
    """Return the server's password requirements as a Markdown bullet list."""
    policy = client.get_password_policy()
    return "\n".join(f"- {req}" for req in policy)


def list_datasets():
    """Fetch the user's data summary and show train and test sets in one table."""
    try:
        data_summary = client.get_data_summary()
        datasets = data_summary.get('datasets_summary', [])

        formatted_datasets = []
        for dataset in datasets:
            formatted_datasets.append({
                'Dataset Type': 'Train Set',
                'UID': dataset['train_set_uid'],
                'Added On': dataset['datetime_added'],
                'X Filename': dataset['x_train_filename'],
                'Y Filename': dataset['y_train_filename'],
            })
            for test_set in dataset.get('associated_test_sets', []):
                formatted_datasets.append({
                    'Dataset Type': 'Test Set',
                    'UID': test_set['test_set_uid'],
                    'Added On': test_set['datetime_added'],
                    'X Filename': test_set['x_test_filename'],
                    'Y Filename': 'N/A',  # Test sets don't have y_test_filename
                })

        df = pd.DataFrame(formatted_datasets)
        if df.empty:
            return gr.Dataframe(value=[["No datasets found"]], visible=True)
        return gr.Dataframe(value=df, visible=True)
    except Exception as e:
        _report_error(f"Error listing datasets: {e}")
        return gr.Dataframe(value=[["Error retrieving datasets"]], visible=True)


def delete_dataset(dataset_uid, confirm):
    """Delete a dataset by UID (only when *confirm* is set) and refresh the table."""
    if not confirm:
        gr.Warning("Please confirm the deletion by checking the confirmation box.")
        # No-op update: returning None here would wipe the displayed table.
        return gr.update()
    if not dataset_uid or not str(dataset_uid).strip():
        gr.Warning("Please enter the UID of the dataset to delete.")
        return gr.update()
    try:
        deleted_uids = client.delete_dataset(dataset_uid.strip())
        gr.Info(f"Successfully deleted dataset(s): {', '.join(map(str, deleted_uids))}")
        return list_datasets()
    except Exception as e:
        _report_error(f"Error deleting dataset: {e}")
        return gr.update()


def delete_account(confirm_password, confirm):
    """Delete the account (only when *confirm* is set); on success show the login tab."""
    global access_token
    if not confirm:
        gr.Warning("Please confirm the account deletion by checking the confirmation box.")
        return _app_view()
    try:
        client.delete_user_account(confirm_password)
    except Exception as e:
        _report_error(f"Error deleting account: {e}")
        # Keep current tab visibility if there's an error
        return _app_view()
    # The account is gone: forget the session so predictions are gated again.
    access_token = None
    gr.Info("Account deleted successfully.")
    return _login_view()


def download_all_data():
    """Download all of the user's data into a fresh temp dir and expose it as a file."""
    try:
        temp_dir = Path(tempfile.mkdtemp())
        save_path = client.download_all_data(temp_dir)
        if save_path is None:
            _report_error("Error downloading data: the server did not return a file.")
            return None
        gr.Info("All data downloaded successfully.")
        return gr.File(value=str(save_path), visible=True)
    except Exception as e:
        _report_error(f"Error downloading data: {e}")
        return None


def logout():
    """Drop the client's authorization and the local session token."""
    global access_token
    try:
        client.reset_authorization()
    except Exception as e:
        _report_error(f"Error during logout: {e}")
        return _app_view()
    # estimate_performance/predict gate on this token, so it must be cleared too.
    access_token = None
    gr.Info("Logged out successfully.")
    return _login_view()


def _split_features_target(df, feature_cols, target_col):
    """Return ``(X, y)`` from the uploaded frame, or None after warning the user.

    The target column is removed from the features to avoid label leakage.
    """
    if df is None:
        gr.Warning("Please upload a CSV file.")
        return None
    features = [col for col in (feature_cols or []) if col != target_col]
    if not features:
        gr.Warning("Please select at least one feature column other than the target column.")
        return None
    return df[features], df[target_col]


def _encode_labels(y):
    """Encode non-numeric labels as integer codes.

    Returns ``(encoded, categories)``: *categories* is a ``pd.Index`` mapping each code
    back to its original label, or None when *y* was numeric and returned unchanged.
    Callers drop missing labels first, so no ``-1`` (missing) codes are produced.
    """
    if pd.api.types.is_numeric_dtype(y):
        return y, None
    categorical = pd.Categorical(y)
    return categorical.codes, categorical.categories


def _classification_report(scores):
    """Human-readable summary of per-fold accuracy scores."""
    return (
        f"Average Accuracy: {np.mean(scores):.2%}\n\n"
        "Individual Test Results:\n"
        f"{', '.join([f'{s:.2%}' for s in scores])}\n\n"
        "What does this mean?\n"
        "- We tested the model's performance by dividing your data into 5 parts.\n"
        "- For each part, we trained the model on 4 parts and tested it on the remaining part.\n"
        "- We repeated this process 5 times, each time using a different part for testing.\n"
        "- The average accuracy shows how often the model correctly predicted the category.\n"
        "- Some variation in these numbers is normal and expected.\n\n"
        "Note: This is an estimate of how well the model might perform on new, unseen data."
    )


def _regression_report(mse_scores, r2_scores):
    """Human-readable summary of per-fold MSE and R-squared scores."""
    return (
        f"Average Mean Squared Error: {np.mean(mse_scores):.4f}\n"
        f"Average R-squared: {np.mean(r2_scores):.4f}\n\n"
        "Individual Test Results:\n"
        f"MSE: {', '.join([f'{s:.4f}' for s in mse_scores])}\n"
        f"R-squared: {', '.join([f'{s:.4f}' for s in r2_scores])}\n\n"
        "What does this mean?\n"
        "- We tested the model's performance by dividing your data into 5 parts.\n"
        "- For each part, we trained the model on 4 parts and tested it on the remaining part.\n"
        "- We repeated this process 5 times, each time using a different part for testing.\n"
        "- Mean Squared Error (MSE) measures the average squared difference between predictions and actual values. Lower is better.\n"
        "- R-squared measures the proportion of variance in the target variable that is predictable from the feature variables. Higher is better.\n"
        "- Some variation in these numbers is normal and expected.\n\n"
        "Note: This is an estimate of how well the model might perform on new, unseen data."
    )


def estimate_performance(df, feature_cols, target_col, task, progress=gr.Progress()):
    """Estimate model quality with 5-fold cross-validation on the labeled rows.

    Args:
        df: uploaded DataFrame (from the ``gr.State``), or None if nothing was uploaded.
        feature_cols: selected feature column names.
        target_col: column to predict; rows where it is missing are ignored.
        task: ``"classification"`` or ``"regression"``.

    Returns:
        A text report, or None on any failure (the user is notified).
    """
    if access_token is None:
        gr.Warning("Please log in or register first.")
        return None

    try:
        progress(0, desc="Preparing data")
        split = _split_features_target(df, feature_cols, target_col)
        if split is None:
            return None
        X, y = split
        # Remove rows with missing labels
        mask = y.notna()
        X, y = X[mask], y[mask]

        progress(0.1, desc="Initializing model")
        # NOTE(review): tabpfn_client estimators authenticate through the package's own
        # global config (`init` is imported but never called here); confirm that
        # authorizing `client` is enough for them to reuse this session.
        if task == "classification":
            y, _ = _encode_labels(y)
            if len(np.unique(y)) < 2:
                gr.Warning(_TOO_FEW_CLASSES_MESSAGE)
                return None
            model = TabPFNClassifier()
            scoring = {'accuracy': 'accuracy'}
        elif task == "regression":
            model = TabPFNRegressor()
            scoring = {'mse': 'neg_mean_squared_error', 'r2': 'r2'}
        else:
            _report_error(_INVALID_TASK_MESSAGE)
            return None

        progress(0.2, desc="Performing cross-validation")
        # Run the folds in this process: the login state set up by this app lives in
        # this process's memory, and joblib worker processes would not inherit it.
        cv_results = cross_validate(model, X, y, cv=5, scoring=scoring, n_jobs=None, return_train_score=False)

        progress(0.9, desc="Generating report")
        if task == "classification":
            result = _classification_report(cv_results['test_accuracy'])
        else:
            # sklearn reports the negated MSE (higher-is-better convention), so flip it back.
            result = _regression_report(-cv_results['test_mse'], cv_results['test_r2'])

        progress(1.0, desc="Completed")
        gr.Info("Performance estimation completed successfully.")
        return result
    except Exception as e:
        _report_error(
            f"Error during performance estimation: {e}\n\nPlease ensure your data is formatted correctly."
        )
        return None


def predict(df, feature_cols, target_col, task, progress=gr.Progress()):
    """Train on labeled rows and predict the rows whose target is empty.

    Returns ``(download_update, result_df)`` where the CSV contains the unlabeled rows'
    features, the predicted target and (classification only) one probability column
    per class, named ``"{target}_prob_{label}"`` with the user's original labels.
    Returns ``(None, None)`` on any failure (the user is notified).
    """
    if access_token is None:
        gr.Warning("Please log in or register first.")
        return None, None

    try:
        progress(0, desc="Preparing data")
        split = _split_features_target(df, feature_cols, target_col)
        if split is None:
            return None, None
        X, y = split
        # Split data into labeled and unlabeled
        mask = y.notna()
        X_labeled, y_labeled = X[mask], y[mask]
        X_unlabeled = X[~mask]
        if X_unlabeled.empty:
            gr.Warning("There are no rows with an empty target value to predict.")
            return None, None

        progress(0.2, desc="Initializing and training model")
        if task == "classification":
            y_labeled, categories = _encode_labels(y_labeled)
            if len(np.unique(y_labeled)) < 2:
                gr.Warning(_TOO_FEW_CLASSES_MESSAGE)
                return None, None
            model = TabPFNClassifier()
            model.fit(X_labeled, y_labeled)

            progress(0.6, desc="Making predictions")
            predictions = model.predict(X_unlabeled)
            probabilities = model.predict_proba(X_unlabeled)

            progress(0.8, desc="Preparing results")
            class_labels = list(model.classes_)
            if categories is not None:
                # The model was trained on integer codes; report the user's own labels.
                predictions = categories.take(np.asarray(predictions, dtype=int))
                class_labels = list(categories.take(np.asarray(class_labels, dtype=int)))
            result_df = X_unlabeled.copy()
            result_df[target_col] = np.asarray(predictions)
            for i, class_name in enumerate(class_labels):
                result_df[f"{target_col}_prob_{class_name}"] = probabilities[:, i]
        elif task == "regression":
            model = TabPFNRegressor()
            model.fit(X_labeled, y_labeled)

            progress(0.6, desc="Making predictions")
            predictions = model.predict(X_unlabeled)

            progress(0.8, desc="Preparing results")
            result_df = X_unlabeled.copy()
            result_df[target_col] = predictions
        else:
            _report_error(_INVALID_TASK_MESSAGE)
            return None, None

        output_path = Path(tempfile.mkdtemp()) / "predictions.csv"
        result_df.to_csv(output_path, index=False)

        progress(1.0, desc="Completed")
        gr.Info("Predictions completed successfully.")
        return gr.update(visible=True, value=str(output_path)), result_df
    except Exception as e:
        _report_error(
            f"Error during prediction: {e}\n\nPlease ensure your data is formatted correctly and try again."
        )
        return None, None


def update_column_selection(file):
    """Read the uploaded CSV and populate the column pickers and preview.

    The last column is proposed as the target; all others as features. Returns updates
    for (feature_cols, target_col, preview_data, estimate_button, predict_button, state).
    """
    hidden = (None, None, None, gr.update(visible=False), gr.update(visible=False), None)
    try:
        if not file:
            gr.Warning("Please upload a CSV file.")
            return hidden
        # Gradio 4 passes a file path string; older versions pass an object with ``.name``.
        path = file if isinstance(file, (str, Path)) else file.name
        df = pd.read_csv(path)
        columns = df.columns.tolist()

        # Mark the last column as the target column in the preview only.
        df_with_target = df.copy()
        df_with_target.columns = [
            f"{col} (Target)" if i == len(columns) - 1 else col for i, col in enumerate(columns)
        ]
        return (
            gr.update(visible=True, choices=columns[:-1], value=columns[:-1]),
            gr.update(visible=True, choices=columns, value=columns[-1]),
            gr.update(visible=True, value=df_with_target.head()),
            gr.update(visible=True),
            gr.update(visible=True),
            df,
        )
    except Exception as e:
        _report_error(f"Error reading CSV file: {e}")
        return hidden


def create_login_tab():
    """Build the Login/Register tab; returns its components in a fixed order."""
    with gr.Tab("Login/Register", visible=not is_logged_in) as login_tab:
        with gr.Row():
            with gr.Column():
                login_email = gr.Textbox(label="Email", info="Enter your registered email address")
                login_password = gr.Textbox(label="Password", type="password", info="Enter your password")
                login_button = gr.Button("Login", variant="primary")

        with gr.Accordion("New User? Complete Registration", open=False) as register_accordion:
            register_email = gr.Textbox(label="Email*", info="Enter a valid email address")
            register_password = gr.Textbox(label="Password*", type="password", info="Enter a strong password")
            register_password_confirm = gr.Textbox(label="Confirm Password*", type="password", info="Re-enter your password")
            register_first_name = gr.Textbox(label="First Name*", info="Enter your first name")
            register_last_name = gr.Textbox(label="Last Name*", info="Enter your last name")
            register_organization = gr.Textbox(label="Where do you work? (Optional)", info="Enter your organization name")
            register_role = gr.Textbox(label="What is your role? (Optional)", info="Enter your job title or role")
            register_use_case = gr.Textbox(label="What do you want to use TabPFN for? (Optional)", info="Briefly describe your intended use case")
            register_contact_via_email = gr.Checkbox(label="Can we reach out to you via email to support you?", info="Check this box if you're open to receiving support emails")
            gr.Markdown(f"Please refer to our terms and conditions at: {TERMS_OF_SERVICE_URL}")
            register_tos_agreed = gr.Checkbox(label="I have read and agree to the Terms of Service*", info="You must agree to the Terms of Service to register")
            register_submit = gr.Button("Submit Registration", variant="primary")

        with gr.Accordion("Password Policy", open=False):
            # Fetched from the server once, when the UI is built.
            gr.Markdown(get_password_policy())

    return (login_tab, login_email, login_password, login_button, register_accordion, register_email,
            register_password, register_password_confirm, register_first_name, register_last_name,
            register_organization, register_role, register_use_case, register_contact_via_email,
            register_tos_agreed, register_submit)


def create_account_management_tab():
    """Build the Account Management tab; returns its components in a fixed order."""
    with gr.Tab("Account Management", visible=is_logged_in) as account_tab:
        gr.Markdown("## Account Management")
        with gr.Row():
            list_datasets_button = gr.Button("List Datasets")
            download_all_data_button = gr.Button("Download All Data")
            logout_button = gr.Button("Logout")

        with gr.Accordion("API Access Token", open=False):
            token_display = gr.Textbox(
                label="Your API Token",
                info="Use this token for API access",
                interactive=False,
                show_copy_button=True,
                visible=False,
            )
            show_token_button = gr.Button("Show/Hide Token")

        datasets_table = gr.Dataframe(label="Your Datasets", visible=False)
        delete_dataset_input = gr.Textbox(label="Dataset UID to delete", visible=False)
        delete_dataset_confirm = gr.Checkbox(label="I confirm that I want to delete this dataset", visible=False)
        delete_dataset_button = gr.Button("Delete Dataset", visible=False)
        download_all_data_file = gr.File(label="Download All Data", visible=False)

        with gr.Accordion("Delete Account and All Data", open=False) as delete_account_accordion:
            delete_account_password = gr.Textbox(label="Confirm Password", type="password")
            delete_account_confirm = gr.Checkbox(label="I confirm that I want to delete my account")
            delete_account_button = gr.Button("Delete Account", variant="stop")

    return (account_tab, list_datasets_button, download_all_data_button, logout_button, datasets_table,
            delete_dataset_input, delete_dataset_confirm, delete_dataset_button, download_all_data_file,
            delete_account_accordion, delete_account_password, delete_account_confirm, delete_account_button,
            token_display, show_token_button)


def toggle_token_display(token_visible):
    """Show the client's access token when *token_visible* is true, otherwise hide it."""
    if not token_visible:
        return gr.update(visible=False)
    try:
        return gr.update(value=client.access_token, visible=True)
    except Exception as e:
        gr.Warning(f"Unable to retrieve token: {str(e)}")
        return gr.update(visible=False)


def create_predict_tab():
    """Build the Predict tab; returns its components in a fixed order."""
    with gr.Tab("Predict", visible=is_logged_in) as predict_tab:
        gr.Markdown(textwrap.dedent("""\
            ## Preparing Your Data

            Before uploading, please ensure your CSV file is formatted correctly:

            1. The file should have a header row with column names.
            2. Each subsequent row should represent one data point.
            3. The last column will be treated as the target column (what you want to predict).
            4. All other columns will be treated as feature columns (used for making predictions).
            5. Some rows can have empty values in the target column. These are the ones the model will try to predict.

            Example CSV format:
            ```
            feature1,feature2,feature3,target
            value1,value2,value3,category1
            value4,value5,value6,category2
            value7,value8,value9,
            ```

            In this example, 'feature1', 'feature2', and 'feature3' are feature columns, and 'target' is the target column. The last row has an empty target value, indicating it's a row for prediction.

            Note: Make sure your target column contains only numeric values for regression or categorical labels for classification.

            ## How to Use This Demo

            **Estimate Performance**: This option helps you understand how well the model might work with your data.
            - It uses your labeled data (where you already know the correct values) to estimate the model's accuracy or error.
            - This gives you an idea of how well the model might predict values for new data.

            **Predict**: This option helps you predict values for new data.
            - It uses your labeled data to learn patterns.
            - Then, it predicts values for your unlabeled data (where you don't know the correct values).
            - The result is a downloadable file with predictions for your unlabeled data.
            """))

        file_input = gr.File(label="Upload your CSV file", file_types=[".csv"])

        # Hidden until a CSV is uploaded successfully (see create_interface).
        with gr.Row(visible=False) as column_selection:
            with gr.Column():
                feature_cols = gr.CheckboxGroup(label="Columns to use for prediction", info="Select the columns you want to use as features for prediction")
            with gr.Column():
                target_col = gr.Dropdown(label="Column to predict", info="Select the column you want to predict")

        preview_data = gr.Dataframe(label="Data Preview", visible=False)
        task = gr.Radio(["classification", "regression"], label="Task", info="Select the type of prediction task")

        with gr.Row():
            estimate_button = gr.Button("Estimate Performance", visible=False)
            predict_button = gr.Button("Predict", visible=False)

        performance_output = gr.Textbox(label="Performance Estimation Results", visible=False)
        prediction_download = gr.File(label="Download Predictions", visible=False)
        prediction_table_output = gr.Dataframe(label="Preview of Prediction Results", visible=False)

    return (predict_tab, file_input, column_selection, feature_cols, target_col, preview_data, task,
            estimate_button, predict_button, performance_output, prediction_download, prediction_table_output)


def create_interface():
    """Build the full Gradio app and wire all event handlers; returns the ``gr.Blocks``."""
    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# TabPFN-V2 Demo")

        state = gr.State(None)  # the uploaded DataFrame, or None
        token_visible_state = gr.State(False)

        (login_tab, login_email, login_password, login_button, register_accordion, register_email,
         register_password, register_password_confirm, register_first_name, register_last_name,
         register_organization, register_role, register_use_case, register_contact_via_email,
         register_tos_agreed, register_submit) = create_login_tab()
        (predict_tab, file_input, column_selection, feature_cols, target_col, preview_data, task,
         estimate_button, predict_button, performance_output, prediction_download,
         prediction_table_output) = create_predict_tab()
        (account_tab, list_datasets_button, download_all_data_button, logout_button, datasets_table,
         delete_dataset_input, delete_dataset_confirm, delete_dataset_button, download_all_data_file,
         delete_account_accordion, delete_account_password, delete_account_confirm, delete_account_button,
         token_display, show_token_button) = create_account_management_tab()

        auth_tabs = [login_tab, predict_tab, account_tab]

        show_token_button.click(
            lambda visible: not visible, inputs=[token_visible_state], outputs=[token_visible_state]
        ).then(
            toggle_token_display, inputs=[token_visible_state], outputs=[token_display]
        )

        login_button.click(login, inputs=[login_email, login_password], outputs=auth_tabs)

        register_submit.click(
            register,
            inputs=[register_email, register_password, register_password_confirm, register_first_name,
                    register_last_name, register_organization, register_role, register_use_case,
                    register_contact_via_email, register_tos_agreed],
            outputs=auth_tabs,
        ).then(
            # Accordions are closed via ``open``; they have no ``value``.
            lambda: gr.update(open=False), outputs=[register_accordion]
        )

        file_input.upload(
            update_column_selection,
            inputs=[file_input],
            outputs=[feature_cols, target_col, preview_data, estimate_button, predict_button, state],
        ).then(
            # The column pickers live inside a Row created hidden; updating the children
            # alone never reveals them, so toggle the Row based on the upload result.
            lambda df: gr.update(visible=df is not None), inputs=[state], outputs=[column_selection]
        )

        estimate_button.click(
            lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)),
            outputs=[performance_output, prediction_download, prediction_table_output],
        ).then(
            estimate_performance,
            inputs=[state, feature_cols, target_col, task],
            outputs=[performance_output],
            show_progress=True,
        )

        predict_button.click(
            lambda: (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)),
            outputs=[performance_output, prediction_download, prediction_table_output],
        ).then(
            predict,
            inputs=[state, feature_cols, target_col, task],
            outputs=[prediction_download, prediction_table_output],
            show_progress=True,
        )

        list_datasets_button.click(list_datasets, outputs=[datasets_table]).then(
            lambda: (gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)),
            outputs=[delete_dataset_input, delete_dataset_confirm, delete_dataset_button],
        )

        delete_dataset_button.click(
            delete_dataset, inputs=[delete_dataset_input, delete_dataset_confirm], outputs=[datasets_table]
        ).then(
            lambda: (gr.update(value=""), gr.update(value=False)),
            outputs=[delete_dataset_input, delete_dataset_confirm],
        )

        download_all_data_button.click(download_all_data, outputs=[download_all_data_file])

        delete_account_button.click(
            delete_account, inputs=[delete_account_password, delete_account_confirm], outputs=auth_tabs
        ).then(
            lambda: (gr.update(value=""), gr.update(value=False), gr.update(open=False)),
            outputs=[delete_account_password, delete_account_confirm, delete_account_accordion],
        )

        logout_button.click(logout, outputs=auth_tabs)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()