Spaces:
Running
Running
import io
import logging
import shutil
import sys
import tempfile
import textwrap
import time
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from password_strength import PasswordPolicy
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split, cross_val_score
from tabpfn_client import TabPFNClassifier
from tabpfn_client import TabPFNRegressor
from tabpfn_client import init
from tabpfn_client.client import ServiceClient
from tabpfn_client.tabpfn_common_utils import utils as common_utils
# API client used by the login/registration/token handlers below.
# NOTE(review): `client` and `access_token` are module globals, so every browser
# session served by this process shares one login. Isolating users would need
# per-session state (e.g. gr.State) — TODO confirm whether that is intended.
client = ServiceClient()
# Token from the most recent successful login/register; None means "not logged in"
# (checked by the estimate/predict handlers before doing any work).
access_token = None
TERMS_OF_SERVICE_URL = "https://www.priorlabs.ai/terms-eu-en"
# Initial tab visibility; only read when the tabs are built and never reassigned.
is_logged_in = False
class PromptAgent:
    """Namespace of small text and password-policy helpers (no instance state)."""

    @staticmethod
    def indent(text: str) -> str:
        """Indent every line of ``text`` by two spaces."""
        indent_factor = 2
        indent_str = " " * indent_factor
        return textwrap.indent(text, indent_str)

    @staticmethod
    def _parse_password_req(req: str) -> tuple[str, int]:
        """Split a requirement such as ``"Length(8)"`` into ``("length", 8)``.

        Surrounding whitespace around the name and the number is tolerated.

        Raises:
            ValueError: if ``req`` has no ``(`` or the count is not an integer.
        """
        name, sep, rest = req.partition("(")
        if not sep:
            raise ValueError(f"Malformed password requirement: {req!r}")
        # Drop the closing parenthesis (and any whitespace) after the count.
        number = int(rest.strip().rstrip(")").strip())
        return name.strip().lower(), number

    @staticmethod
    def password_req_to_policy(password_req: list[str]):
        """Build a ``PasswordPolicy`` from server requirement strings like ``"Length(8)"``.

        Each name is lower-cased and passed as a keyword to
        ``PasswordPolicy.from_names`` with its integer count.
        """
        requirements = dict(PromptAgent._parse_password_req(req) for req in password_req)
        return PasswordPolicy.from_names(**requirements)
def login(email, password):
    """Log in with email/password and authorize the shared client on success.

    Returns:
        Visibility updates for (login tab, predict tab, account tab): the login
        tab is hidden and the other two shown only when a token was issued.
    """
    global access_token
    access_token, message = client.login(email, password)
    logged_in = bool(access_token)
    if logged_in:
        client.authorize(access_token)
        gr.Info("Login successful!")
    else:
        gr.Warning(f"Login failed: {message}")
    return (
        gr.update(visible=not logged_in),
        gr.update(visible=logged_in),
        gr.update(visible=logged_in),
    )
def register(email, password, password_confirm, first_name, last_name, organization, role, use_case, contact_via_email, tos_agreed):
    """Validate the registration form, create the account, and authorize the client.

    Checks run in this order: Terms of Service accepted, required names present,
    email accepted by the server, password satisfies the server's policy, and the
    two password fields match. Only then is ``client.register`` called.

    Returns:
        Visibility updates for (login tab, predict tab, account tab).
    """
    global access_token

    def _reject(reason):
        # Show why registration failed and keep the user on the login/register tab.
        gr.Warning(f"Registration failed: {reason}")
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    if not tos_agreed:
        return _reject("You must agree to the Terms of Service.")
    # First and last name are labelled as required ("*") in the form.
    if not (first_name or "").strip() or not (last_name or "").strip():
        return _reject("First name and last name are required.")
    is_valid, message = client.validate_email(email)
    if not is_valid:
        return _reject(f"Invalid email - {message}")
    password_policy = PromptAgent.password_req_to_policy(client.get_password_policy())
    # PasswordPolicy.test returns the list of failed requirements (empty == OK).
    if password_policy.test(password):
        return _reject("Password requirements not satisfied. Please check the password policy.")
    if password != password_confirm:
        return _reject("Passwords do not match.")
    # Passed through as client.register's validation_link argument; its meaning is
    # defined by the service, not visible here.
    validation_link = "tabpfn-2023"
    additional_info = {
        "first_name": first_name,
        "last_name": last_name,
        "company": organization,
        "role": role,
        "use_case": use_case,
        "contact_via_email": contact_via_email,
    }
    is_created, message, access_token = client.register(email, password, password_confirm, validation_link, additional_info)
    if not is_created:
        return _reject(message)
    client.authorize(access_token)
    gr.Info("Registration successful! Please check your email for a verification link.")
    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
def get_password_policy():
    """Return the server's password requirements as a Markdown bullet list."""
    requirements = client.get_password_policy()
    bullets = (f"- {requirement}" for requirement in requirements)
    return "\n".join(bullets)
def _format_datasets(data_summary):
    """Flatten a data-summary dict into one row per train set and per associated test set."""
    rows = []
    for dataset in (data_summary or {}).get("datasets_summary", []):
        rows.append({
            "Dataset Type": "Train Set",
            "UID": dataset["train_set_uid"],
            "Added On": dataset["datetime_added"],
            "X Filename": dataset["x_train_filename"],
            "Y Filename": dataset["y_train_filename"],
        })
        for test_set in dataset.get("associated_test_sets", []):
            rows.append({
                "Dataset Type": "Test Set",
                "UID": test_set["test_set_uid"],
                "Added On": test_set["datetime_added"],
                "X Filename": test_set["x_test_filename"],
                "Y Filename": "N/A",  # Test sets don't have y_test_filename
            })
    return rows


def list_datasets():
    """Fetch the user's stored datasets and return them as a visible ``gr.Dataframe``.

    Uses the module-level ``client``, which is the object the login handlers
    authorize (the previous ``config`` reference was never imported).

    Returns:
        A Dataframe of train/test sets, or a single-cell placeholder when there are
        no datasets or the request failed.
    """
    try:
        data_summary = client.get_data_summary()
        df = pd.DataFrame(_format_datasets(data_summary))
        if df.empty:
            return gr.Dataframe(value=[["No datasets found"]], visible=True)
        return gr.Dataframe(value=df, visible=True)
    except Exception as e:
        # gr.Error only displays when raised; warn so the placeholder below still renders.
        gr.Warning(f"Error listing datasets: {str(e)}")
        return gr.Dataframe(value=[["Error retrieving datasets"]], visible=True)
def delete_dataset(dataset_uid, confirm):
    """Delete a stored dataset by UID after explicit confirmation.

    Args:
        dataset_uid: UID typed by the user; surrounding whitespace is ignored.
        confirm: State of the confirmation checkbox.

    Returns:
        The refreshed table from ``list_datasets()`` on success; otherwise
        ``gr.update()`` so the table currently displayed is left unchanged.
    """
    if not confirm:
        gr.Warning("Please confirm the deletion by checking the confirmation box.")
        return gr.update()
    dataset_uid = (dataset_uid or "").strip()
    if not dataset_uid:
        gr.Warning("Please enter the UID of the dataset to delete.")
        return gr.update()
    try:
        deleted_uids = client.delete_dataset(dataset_uid)
    except Exception as e:
        # gr.Error only displays when raised; warn instead and keep the current table.
        gr.Warning(f"Error deleting dataset: {str(e)}")
        return gr.update()
    # map(str, ...) so non-string UIDs from the service don't break the message.
    gr.Info(f"Successfully deleted dataset(s): {', '.join(map(str, deleted_uids or []))}")
    return list_datasets()
def delete_account(confirm_password, confirm):
    """Delete the user's account (password re-entry plus checkbox confirmation).

    On success, the stored token is dropped and the client's authorization reset,
    so the estimate/predict login guards apply again.

    Returns:
        Visibility updates for (login tab, predict tab, account tab): back to the
        login tab on success, unchanged layout otherwise.
    """
    global access_token
    stay = (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True))
    if not confirm:
        gr.Warning("Please confirm the account deletion by checking the confirmation box.")
        return stay
    try:
        client.delete_user_account(confirm_password)
    except Exception as e:
        # gr.Error only displays when raised; warn so the layout below is still returned.
        gr.Warning(f"Error deleting account: {str(e)}")
        # Keep current tab visibility if there's an error
        return stay
    # The deleted account's token is useless now; forget it everywhere.
    access_token = None
    client.reset_authorization()
    gr.Info("Account deleted successfully.")
    # Return updates to make login tab visible and others invisible
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
def download_all_data():
    """Download all of the user's stored data into a fresh temp dir and offer it as a file.

    Returns:
        A visible ``gr.File`` pointing at the saved download, or ``None`` on failure.
    """
    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        save_path = client.download_all_data(temp_dir)
    except Exception as e:
        # The download failed; remove the temp dir so it is not leaked.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        # gr.Error only displays when raised; warn instead.
        gr.Warning(f"Error downloading data: {str(e)}")
        return None
    gr.Info("All data downloaded successfully.")
    # Return the file with updated visibility
    return gr.File(value=str(save_path), visible=True)
def logout():
    """Reset the client's authorization, forget the token, and return to the login tab.

    Returns:
        Visibility updates for (login tab, predict tab, account tab).
    """
    global access_token
    try:
        client.reset_authorization()
    except Exception as e:
        # gr.Error only displays when raised; warn and keep the current layout.
        gr.Warning(f"Error during logout: {str(e)}")
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
    # Without this the estimate/predict guards would still treat the session as logged in.
    access_token = None
    gr.Info("Logged out successfully.")
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
from sklearn.model_selection import cross_validate | |
def _format_cv_report(task, cv_results):
    """Render ``cross_validate`` test scores as the plain-text report shown to the user."""
    if task == "classification":
        scores = cv_results['test_accuracy']
        return (
            f"Average Accuracy: {np.mean(scores):.2%}\n\n"
            "Individual Test Results:\n"
            f"{', '.join([f'{s:.2%}' for s in scores])}\n\n"
            "What does this mean?\n"
            "- We tested the model's performance by dividing your data into 5 parts.\n"
            "- For each part, we trained the model on 4 parts and tested it on the remaining part.\n"
            "- We repeated this process 5 times, each time using a different part for testing.\n"
            "- The average accuracy shows how often the model correctly predicted the category.\n"
            "- Some variation in these numbers is normal and expected.\n\n"
            "Note: This is an estimate of how well the model might perform on new, unseen data."
        )
    mse_scores = -cv_results['test_mse']  # Negative because sklearn returns negative MSE
    r2_scores = cv_results['test_r2']
    return (
        f"Average Mean Squared Error: {np.mean(mse_scores):.4f}\n"
        f"Average R-squared: {np.mean(r2_scores):.4f}\n\n"
        "Individual Test Results:\n"
        f"MSE: {', '.join([f'{s:.4f}' for s in mse_scores])}\n"
        f"R-squared: {', '.join([f'{s:.4f}' for s in r2_scores])}\n\n"
        "What does this mean?\n"
        "- We tested the model's performance by dividing your data into 5 parts.\n"
        "- For each part, we trained the model on 4 parts and tested it on the remaining part.\n"
        "- We repeated this process 5 times, each time using a different part for testing.\n"
        "- Mean Squared Error (MSE) measures the average squared difference between predictions and actual values. Lower is better.\n"
        "- R-squared measures the proportion of variance in the target variable that is predictable from the feature variables. Higher is better.\n"
        "- Some variation in these numbers is normal and expected.\n\n"
        "Note: This is an estimate of how well the model might perform on new, unseen data."
    )


def estimate_performance(df, feature_cols, target_col, task, progress=gr.Progress()):
    """Estimate TabPFN performance on the labeled rows with 5-fold cross-validation.

    Args:
        df: DataFrame loaded from the uploaded CSV (``None`` if nothing was uploaded).
        feature_cols: Columns used as model inputs; the target column is always excluded.
        target_col: Column to predict. Rows where it is missing are dropped first.
        task: ``"classification"`` or ``"regression"``.
        progress: Gradio progress tracker supplied by the event handler.

    Returns:
        A plain-text report, or ``None`` if the inputs are invalid or evaluation fails.
    """
    global access_token
    if access_token is None:
        gr.Warning("Please log in or register first.")
        return None
    if df is None:
        gr.Warning("Please upload a CSV file first.")
        return None
    if task not in ("classification", "regression"):
        # gr.Error is an exception class and shows nothing unless raised; warn instead.
        gr.Warning("Invalid task type. Please choose either 'classification' or 'regression'.")
        return None
    # Never feed the target back in as a feature (that would leak the label).
    feature_cols = [col for col in (feature_cols or []) if col != target_col]
    if not feature_cols:
        gr.Warning("Please select at least one feature column other than the target column.")
        return None
    try:
        progress(0, desc="Preparing data")
        X = df[feature_cols]
        y = df[target_col]
        # Remove rows with missing labels
        mask = y.notna()
        X = X[mask]
        y = y[mask]
        progress(0.1, desc="Initializing model")
        if task == "classification":
            # Encode non-numeric labels (object, string, category dtypes) as integer codes.
            if not pd.api.types.is_numeric_dtype(y):
                y = pd.Categorical(y).codes
            if len(np.unique(y)) < 2:
                gr.Warning("The dataset must have at least two different categories in the target column for classification.")
                return None
            model = TabPFNClassifier()
            scoring = {'accuracy': 'accuracy'}
        else:
            model = TabPFNRegressor()
            scoring = {'mse': 'neg_mean_squared_error', 'r2': 'r2'}
        progress(0.2, desc="Performing cross-validation")
        cv_results = cross_validate(model, X, y, cv=5, scoring=scoring, n_jobs=-1, return_train_score=False)
        progress(0.9, desc="Generating report")
        result = _format_cv_report(task, cv_results)
        progress(1.0, desc="Completed")
        gr.Info("Performance estimation completed successfully.")
        return result
    except Exception as e:
        error_message = f"Error during performance estimation: {str(e)}\n\nPlease ensure your data is formatted correctly."
        gr.Warning(error_message)
        return None
def predict(df, feature_cols, target_col, task, progress=gr.Progress()):
    """Fit TabPFN on the labeled rows and predict the rows whose target is empty.

    Args:
        df: DataFrame loaded from the uploaded CSV (``None`` if nothing was uploaded).
        feature_cols: Columns used as model inputs; the target column is always excluded.
        target_col: Column to predict. Rows where it is missing are the ones predicted.
        task: ``"classification"`` or ``"regression"``.
        progress: Gradio progress tracker supplied by the event handler.

    Returns:
        ``(file_update, result_df)``: an update pointing the download component at a
        CSV of the predictions, plus the same predictions as a DataFrame. For
        classification, ``<target>_prob_<label>`` columns hold class probabilities
        and predicted values use the original labels. ``(None, None)`` when the inputs
        are invalid or prediction fails.
    """
    global access_token
    if access_token is None:
        gr.Warning("Please log in or register first.")
        return None, None
    if df is None:
        gr.Warning("Please upload a CSV file first.")
        return None, None
    if task not in ("classification", "regression"):
        # gr.Error is an exception class and shows nothing unless raised; warn instead.
        gr.Warning("Invalid task type. Please choose either 'classification' or 'regression'.")
        return None, None
    # Never feed the target back in as a feature (that would leak the label).
    feature_cols = [col for col in (feature_cols or []) if col != target_col]
    if not feature_cols:
        gr.Warning("Please select at least one feature column other than the target column.")
        return None, None
    try:
        progress(0, desc="Preparing data")
        X = df[feature_cols]
        y = df[target_col]
        # Split data into labeled (training) and unlabeled (to be predicted) rows
        mask = y.notna()
        X_labeled = X[mask]
        y_labeled = y[mask]
        X_unlabeled = X[~mask]
        if X_unlabeled.empty:
            gr.Warning("No rows with an empty target value were found, so there is nothing to predict.")
            return None, None
        progress(0.2, desc="Initializing and training model")
        result_df = X_unlabeled.copy()
        if task == "classification":
            class_labels = None
            # Encode non-numeric labels as integer codes, remembering the original labels.
            if not pd.api.types.is_numeric_dtype(y_labeled):
                categorical = pd.Categorical(y_labeled)
                class_labels = np.asarray(categorical.categories)
                y_labeled = categorical.codes
            if len(np.unique(y_labeled)) < 2:
                gr.Warning("The dataset must have at least two different categories in the target column for classification.")
                return None, None
            model = TabPFNClassifier()
            model.fit(X_labeled, y_labeled)
            progress(0.6, desc="Making predictions")
            predictions = model.predict(X_unlabeled)
            probabilities = model.predict_proba(X_unlabeled)
            progress(0.8, desc="Preparing results")
            class_names = list(model.classes_)
            if class_labels is not None:
                # The model only saw integer codes; report the user's original labels.
                predictions = class_labels[np.asarray(predictions, dtype=int)]
                class_names = [class_labels[int(code)] for code in class_names]
            result_df[target_col] = predictions
            for i, class_name in enumerate(class_names):
                result_df[f"{target_col}_prob_{class_name}"] = probabilities[:, i]
        else:
            model = TabPFNRegressor()
            model.fit(X_labeled, y_labeled)
            progress(0.6, desc="Making predictions")
            predictions = model.predict(X_unlabeled)
            progress(0.8, desc="Preparing results")
            result_df[target_col] = predictions
        # Save predictions to a new CSV file
        output_path = Path(tempfile.mkdtemp()) / "predictions.csv"
        result_df.to_csv(output_path, index=False)
        progress(1.0, desc="Completed")
        gr.Info("Predictions completed successfully.")
        return gr.update(visible=True, value=str(output_path)), result_df
    except Exception as e:
        error_message = f"Error during prediction: {str(e)}\n\nPlease ensure your data is formatted correctly and try again."
        gr.Warning(error_message)
        return None, None
def update_column_selection(file):
    """Load the uploaded CSV and populate the column pickers and data preview.

    By default every column except the last is selected as a feature and the last
    column is proposed as the target (and marked "(Target)" in the preview).

    Args:
        file: Value from ``gr.File``, either a filepath string or a tempfile-like
            object whose ``.name`` holds the path; both forms are accepted.

    Returns:
        Updates for (feature columns, target column, preview, estimate button,
        predict button), followed by the loaded DataFrame for the shared state.
    """
    hidden = (None, None, None, gr.update(visible=False), gr.update(visible=False), None)
    if not file:
        gr.Warning("Please upload a CSV file.")
        return hidden
    try:
        csv_path = getattr(file, "name", file)
        df = pd.read_csv(csv_path)
        columns = df.columns.tolist()
        # Mark the last column as the target column
        df_with_target = df.copy()
        df_with_target.columns = [f"{col} (Target)" if i == len(columns) - 1 else col for i, col in enumerate(columns)]
        return (
            gr.update(visible=True, choices=columns[:-1], value=columns[:-1]),
            gr.update(visible=True, choices=columns, value=columns[-1]),
            gr.update(visible=True, value=df_with_target.head()),
            gr.update(visible=True),
            gr.update(visible=True),
            df,
        )
    except Exception as e:
        # gr.Error only displays when raised; warn so the reset updates are still applied.
        gr.Warning(f"Error reading CSV file: {str(e)}")
        return hidden
def create_login_tab():
    """Build the "Login/Register" tab: a login form plus a collapsible registration form.

    Components are laid out in the order they are created inside the ``with`` blocks.

    Returns:
        Tuple of (login_tab, login_email, login_password, login_button,
        register_accordion, register_email, register_password,
        register_password_confirm, register_first_name, register_last_name,
        register_organization, register_role, register_use_case,
        register_contact_via_email, register_tos_agreed, register_submit).
    """
    # Initial visibility comes from the module-level is_logged_in flag.
    with gr.Tab("Login/Register", visible=not is_logged_in) as login_tab:
        with gr.Row():
            with gr.Column():
                login_email = gr.Textbox(label="Email", info="Enter your registered email address")
                login_password = gr.Textbox(label="Password", type="password", info="Enter your password")
                login_button = gr.Button("Login", variant="primary")
                with gr.Accordion("New User? Complete Registration", open=False) as register_accordion:
                    register_email = gr.Textbox(label="Email*", info="Enter a valid email address")
                    register_password = gr.Textbox(label="Password*", type="password", info="Enter a strong password")
                    register_password_confirm = gr.Textbox(label="Confirm Password*", type="password", info="Re-enter your password")
                    register_first_name = gr.Textbox(label="First Name*", info="Enter your first name")
                    register_last_name = gr.Textbox(label="Last Name*", info="Enter your last name")
                    register_organization = gr.Textbox(label="Where do you work? (Optional)", info="Enter your organization name")
                    register_role = gr.Textbox(label="What is your role? (Optional)", info="Enter your job title or role")
                    register_use_case = gr.Textbox(label="What do you want to use TabPFN for? (Optional)", info="Briefly describe your intended use case")
                    register_contact_via_email = gr.Checkbox(label="Can we reach out to you via email to support you?", info="Check this box if you're open to receiving support emails")
                    gr.Markdown(f"Please refer to our terms and conditions at: {TERMS_OF_SERVICE_URL}")
                    register_tos_agreed = gr.Checkbox(label="I have read and agree to the Terms of Service*", info="You must agree to the Terms of Service to register")
                    register_submit = gr.Button("Submit Registration", variant="primary")
                with gr.Accordion("Password Policy", open=False):
                    # Fetched from the server once, while the UI is being built.
                    policy_output = gr.Markdown(get_password_policy())
    return login_tab, login_email, login_password, login_button, register_accordion, register_email, register_password, register_password_confirm, register_first_name, register_last_name, register_organization, register_role, register_use_case, register_contact_via_email, register_tos_agreed, register_submit
def create_account_management_tab():
    """Build the "Account Management" tab (datasets, data export, token, logout, deletion).

    The dataset-deletion controls and the download file are created hidden; event
    handlers wired elsewhere decide when to show them.

    Returns:
        Tuple of (account_tab, list_datasets_button, download_all_data_button,
        logout_button, datasets_table, delete_dataset_input, delete_dataset_confirm,
        delete_dataset_button, download_all_data_file, delete_account_accordion,
        delete_account_password, delete_account_confirm, delete_account_button,
        token_display, show_token_button).
    """
    with gr.Tab("Account Management", visible=is_logged_in) as account_tab:
        gr.Markdown("## Account Management")
        with gr.Row():
            list_datasets_button = gr.Button("List Datasets")
            download_all_data_button = gr.Button("Download All Data")
            logout_button = gr.Button("Logout")
        # Add Token Display Section
        # (token_accordion is bound here but not returned to the caller.)
        with gr.Accordion("API Access Token", open=False) as token_accordion:
            token_display = gr.Textbox(
                label="Your API Token",
                info="Use this token for API access",
                interactive=False,
                show_copy_button=True,
                visible=False
            )
            show_token_button = gr.Button("Show/Hide Token")
        datasets_table = gr.Dataframe(label="Your Datasets", visible=False)
        delete_dataset_input = gr.Textbox(label="Dataset UID to delete", visible=False)
        delete_dataset_confirm = gr.Checkbox(label="I confirm that I want to delete this dataset", visible=False)
        delete_dataset_button = gr.Button("Delete Dataset", visible=False)
        download_all_data_file = gr.File(label="Download All Data", visible=False)
        with gr.Accordion("Delete Account and All Data", open=False) as delete_account_accordion:
            delete_account_password = gr.Textbox(label="Confirm Password", type="password")
            delete_account_confirm = gr.Checkbox(label="I confirm that I want to delete my account")
            delete_account_button = gr.Button("Delete Account", variant="stop")
    return account_tab, list_datasets_button, download_all_data_button, logout_button, datasets_table, delete_dataset_input, delete_dataset_confirm, delete_dataset_button, download_all_data_file, delete_account_accordion, delete_account_password, delete_account_confirm, delete_account_button, token_display, show_token_button
def toggle_token_display(token_visible):
    """Return an update that reveals the client's access token, or hides the token box.

    If reading the token fails, a warning is shown and the box stays hidden.
    """
    if not token_visible:
        return gr.update(visible=False)
    try:
        token = client.access_token
        return gr.update(value=token, visible=True)
    except Exception as err:
        gr.Warning(f"Unable to retrieve token: {str(err)}")
        return gr.update(visible=False)
def create_predict_tab():
    """Build the "Predict" tab: usage instructions, CSV upload, column pickers, and outputs.

    Returns:
        Tuple of (predict_tab, file_input, column_selection, feature_cols, target_col,
        preview_data, task, estimate_button, predict_button, performance_output,
        prediction_download, prediction_table_output).
    """
    with gr.Tab("Predict", visible=is_logged_in) as predict_tab:
        gr.Markdown("""
## Preparing Your Data
Before uploading, please ensure your CSV file is formatted correctly:
1. The file should have a header row with column names.
2. Each subsequent row should represent one data point.
3. The last column will be treated as the target column (what you want to predict).
4. All other columns will be treated as feature columns (used for making predictions).
5. Some rows can have empty values in the target column. These are the ones the model will try to predict.
Example CSV format:
```
feature1,feature2,feature3,target
value1,value2,value3,category1
value4,value5,value6,category2
value7,value8,value9,
```
In this example, 'feature1', 'feature2', and 'feature3' are feature columns, and 'target' is the target column.
The last row has an empty target value, indicating it's a row for prediction.
Note: Make sure your target column contains only numeric values for regression or categorical labels for classification.
## How to Use This Demo
**Estimate Performance**: This option helps you understand how well the model might work with your data.
- It uses your labeled data (where you already know the correct values) to estimate the model's accuracy or error.
- This gives you an idea of how well the model might predict values for new data.
**Predict**: This option helps you predict values for new data.
- It uses your labeled data to learn patterns.
- Then, it predicts values for your unlabeled data (where you don't know the correct values).
- The result is a downloadable file with predictions for your unlabeled data.
""")
        file_input = gr.File(label="Upload your CSV file", file_types=[".csv"])
        # Created hidden; nothing in this function makes the row visible, so the caller
        # must reveal it (its children stay hidden while the row is hidden).
        with gr.Row(visible=False) as column_selection:
            with gr.Column():
                feature_cols = gr.CheckboxGroup(label="Columns to use for prediction", info="Select the columns you want to use as features for prediction")
            with gr.Column():
                target_col = gr.Dropdown(label="Column to predict", info="Select the column you want to predict")
        preview_data = gr.Dataframe(label="Data Preview", visible=False)
        # No default selection: the handlers receive None until the user picks a task.
        task = gr.Radio(["classification", "regression"], label="Task", info="Select the type of prediction task")
        with gr.Row():
            estimate_button = gr.Button("Estimate Performance", visible=False)
            predict_button = gr.Button("Predict", visible=False)
        performance_output = gr.Textbox(label="Performance Estimation Results", visible=False)
        prediction_download = gr.File(label="Download Predictions", visible=False)
        prediction_table_output = gr.Dataframe(label="Preview of Prediction Results", visible=False)
    return predict_tab, file_input, column_selection, feature_cols, target_col, preview_data, task, estimate_button, predict_button, performance_output, prediction_download, prediction_table_output
def create_interface():
    """Assemble the Gradio Blocks app: build the three tabs and wire every event handler.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# TabPFN-V2 Demo")
        # DataFrame loaded by update_column_selection; read by estimate/predict.
        state = gr.State(None)
        # Whether the API token box is currently shown; flipped by show_token_button.
        token_visible_state = gr.State(False)
        login_tab, login_email, login_password, login_button, register_accordion, register_email, register_password, register_password_confirm, register_first_name, register_last_name, register_organization, register_role, register_use_case, register_contact_via_email, register_tos_agreed, register_submit = create_login_tab()
        predict_tab, file_input, column_selection, feature_cols, target_col, preview_data, task, estimate_button, predict_button, performance_output, prediction_download, prediction_table_output = create_predict_tab()
        account_tab, list_datasets_button, download_all_data_button, logout_button, datasets_table, delete_dataset_input, delete_dataset_confirm, delete_dataset_button, download_all_data_file, delete_account_accordion, delete_account_password, delete_account_confirm, delete_account_button, token_display, show_token_button = create_account_management_tab()

        # Event handlers
        show_token_button.click(
            lambda x: not x,
            inputs=[token_visible_state],
            outputs=[token_visible_state]
        ).then(
            toggle_token_display,
            inputs=[token_visible_state],
            outputs=[token_display]
        )
        login_button.click(
            login,
            inputs=[login_email, login_password],
            outputs=[login_tab, predict_tab, account_tab]
        )
        register_submit.click(
            register,
            inputs=[register_email, register_password, register_password_confirm,
                    register_first_name, register_last_name, register_organization,
                    register_role, register_use_case, register_contact_via_email, register_tos_agreed],
            outputs=[login_tab, predict_tab, account_tab]
        ).then(
            # Accordions are collapsed through ``open``; they have no ``value`` to set.
            lambda: gr.update(open=False),
            outputs=[register_accordion]
        )
        file_input.upload(
            update_column_selection,
            inputs=[file_input],
            outputs=[feature_cols, target_col, preview_data, estimate_button, predict_button, state]
        ).then(
            # The feature/target pickers sit inside the column_selection Row, which is
            # created hidden; reveal the Row once a CSV has been loaded into state.
            lambda df: gr.update(visible=df is not None),
            inputs=[state],
            outputs=[column_selection]
        )
        estimate_button.click(
            lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)),
            outputs=[performance_output, prediction_download, prediction_table_output]
        ).then(
            estimate_performance,
            inputs=[state, feature_cols, target_col, task],
            outputs=[performance_output],
            show_progress=True,
        )
        predict_button.click(
            lambda: (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)),
            outputs=[performance_output, prediction_download, prediction_table_output]
        ).then(
            predict,
            inputs=[state, feature_cols, target_col, task],
            outputs=[prediction_download, prediction_table_output],
            show_progress=True,
        )
        list_datasets_button.click(
            list_datasets,
            outputs=[datasets_table]
        ).then(
            lambda: (gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)),
            outputs=[delete_dataset_input, delete_dataset_confirm, delete_dataset_button]
        )
        delete_dataset_button.click(
            delete_dataset,
            inputs=[delete_dataset_input, delete_dataset_confirm],
            outputs=[datasets_table]
        ).then(
            lambda: (gr.update(value=""), gr.update(value=False)),
            outputs=[delete_dataset_input, delete_dataset_confirm]
        )
        download_all_data_button.click(
            download_all_data,
            outputs=[download_all_data_file]
        )
        delete_account_button.click(
            delete_account,
            inputs=[delete_account_password, delete_account_confirm],
            outputs=[login_tab, predict_tab, account_tab]
        ).then(
            # Clear the password and checkbox, then collapse the accordion (via ``open``).
            lambda: (gr.update(value=""), gr.update(value=False), gr.update(open=False)),
            outputs=[delete_account_password, delete_account_confirm, delete_account_accordion]
        )
        logout_button.click(
            logout,
            outputs=[login_tab, predict_tab, account_tab]
        )
    return demo
if __name__ == "__main__":
    # Build the app and start the Gradio server when run as a script.
    (demo := create_interface()).launch()