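# NOTE(review): the next three lines look like repository-page residue captured
# with the file (avatar caption, commit message, short commit id); kept as comments.
# noahho's picture
# Add token retrieval
# fffebed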
import io
import logging
import sys
import tempfile
import textwrap
import time
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from password_strength import PasswordPolicy
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split, cross_val_score
from tabpfn_client import TabPFNClassifier
from tabpfn_client import TabPFNRegressor
from tabpfn_client import init
from tabpfn_client.client import ServiceClient
from tabpfn_client.tabpfn_common_utils import utils as common_utils
# Service client used by the login, registration and token-display handlers.
# NOTE(review): these are module-level globals, so every browser session served
# by this process shares them — confirm that is acceptable for deployment.
client = ServiceClient()
# Assigned by login() and register(); None until one of them has run.
access_token = None
# Link shown in the registration form.
TERMS_OF_SERVICE_URL = "https://www.priorlabs.ai/terms-eu-en"
# Read once when the tabs are built to pick their initial visibility;
# nothing in this file reassigns it.
is_logged_in = False
class PromptAgent:
    """Text-indentation and password-policy helpers."""

    @staticmethod
    def indent(text: str):
        """Return ``text`` with two spaces prepended to each non-blank line."""
        return textwrap.indent(text, " " * 2)

    @staticmethod
    def password_req_to_policy(password_req: list[str]):
        """Build a ``PasswordPolicy`` from ``"Name(N)"`` requirement strings.

        Each entry is split on ``"("`` into exactly two parts: the lower-cased
        first part becomes the keyword name, and the second part minus its
        final character is parsed as the integer value.

        Raises:
            ValueError: if an entry does not split into exactly two parts or
                the number cannot be parsed.
        """
        split_reqs = (entry.split("(") for entry in password_req)
        # Build the keyword mapping first so malformed entries fail before
        # the policy class is touched.
        policy_kwargs = {
            name.lower(): int(count[:-1])
            for name, count in split_reqs
        }
        return PasswordPolicy.from_names(**policy_kwargs)
def login(email, password):
    """Log in with the service client and switch tab visibility.

    Stores the returned token in the module-level ``access_token``. Returns
    three ``gr.update`` objects, wired (in order) to the login, predict and
    account tabs.
    """
    global access_token
    access_token, message = client.login(email, password)

    # Guard clause: a falsy token means the service rejected the credentials.
    if not access_token:
        gr.Warning(f"Login failed: {message}")
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    client.authorize(access_token)
    gr.Info("Login successful!")
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
    )
def register(email, password, password_confirm, first_name, last_name, organization, role, use_case, contact_via_email, tos_agreed):
    """Validate the registration form and create the account.

    Checks run in this order: terms-of-service agreement, email validation,
    password policy, password confirmation. The first failing check warns the
    user and keeps only the login tab visible. On success the new token is
    authorized and the predict/account tabs are shown instead.

    Returns three ``gr.update`` objects for the login, predict and account
    tabs (in that order).
    """
    global access_token

    def _reject(warning_text):
        # Shared failure path: warn and keep only the login tab visible.
        gr.Warning(warning_text)
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    if not tos_agreed:
        return _reject("Registration failed: You must agree to the Terms of Service.")

    email_ok, email_feedback = client.validate_email(email)
    if not email_ok:
        return _reject(f"Registration failed: Invalid email - {email_feedback}")

    policy = PromptAgent.password_req_to_policy(client.get_password_policy())
    violations = policy.test(password)
    if len(violations) != 0:
        return _reject("Registration failed: Password requirements not satisfied. Please check the password policy.")

    if password != password_confirm:
        return _reject("Registration failed: Passwords do not match.")

    profile = {
        "first_name": first_name,
        "last_name": last_name,
        "company": organization,
        "role": role,
        "use_case": use_case,
        "contact_via_email": contact_via_email,
    }
    # The module-level token is overwritten with whatever the service returns,
    # whether or not the account was created.
    is_created, message, access_token = client.register(
        email, password, password_confirm, "tabpfn-2023", profile
    )
    if not is_created:
        return _reject(f"Registration failed: {message}")

    client.authorize(access_token)
    gr.Info("Registration successful! Please check your email for a verification link.")
    return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
def get_password_policy():
    """Return the service's password requirements, one ``- item`` per line."""
    requirements = client.get_password_policy()
    return "\n".join(f"- {requirement}" for requirement in requirements)
def list_datasets():
    """Fetch the stored-dataset summary and return it as a visible table.

    One row is produced per train set, followed by one row per associated
    test set. If the summary holds no datasets, or the lookup fails, a
    single-cell placeholder table is returned instead.

    Returns:
        A visible ``gr.Dataframe`` holding the rows or a placeholder message.
    """
    try:
        # The module-level ``client`` is the service client that login()
        # authorizes; the previously referenced ``config`` name is never
        # imported in this file, so every call failed with NameError.
        # NOTE(review): assumes ``client`` exposes ``get_data_summary`` like
        # the service client reached through ``config`` did — confirm.
        data_summary = client.get_data_summary()

        rows = []
        for dataset in data_summary.get('datasets_summary', []):
            rows.append({
                'Dataset Type': 'Train Set',
                'UID': dataset['train_set_uid'],
                'Added On': dataset['datetime_added'],
                'X Filename': dataset['x_train_filename'],
                'Y Filename': dataset['y_train_filename'],
            })
            rows.extend(
                {
                    'Dataset Type': 'Test Set',
                    'UID': test_set['test_set_uid'],
                    'Added On': test_set['datetime_added'],
                    'X Filename': test_set['x_test_filename'],
                    'Y Filename': 'N/A',  # Test sets don't have y_test_filename
                }
                for test_set in dataset.get('associated_test_sets', [])
            )

        if not rows:
            return gr.Dataframe(value=[["No datasets found"]], visible=True)
        return gr.Dataframe(value=pd.DataFrame(rows), visible=True)
    except Exception as e:
        # gr.Error is an exception class and shows nothing unless raised;
        # gr.Warning displays the message while the placeholder table is
        # still returned.
        gr.Warning(f"Error listing datasets: {e}")
        return gr.Dataframe(value=[["Error retrieving datasets"]], visible=True)
def delete_dataset(dataset_uid, confirm):
    """Delete one stored dataset and return the refreshed dataset table.

    Args:
        dataset_uid: UID typed by the user; surrounding whitespace is ignored.
        confirm: Value of the confirmation checkbox.

    Returns:
        The result of ``list_datasets()`` after a successful deletion,
        otherwise a no-op ``gr.update()`` so the table already on screen is
        left as it is.
    """
    if not confirm:
        gr.Warning("Please confirm the deletion by checking the confirmation box.")
        return gr.update()

    uid = (dataset_uid or "").strip()
    if not uid:
        gr.Warning("Please enter the UID of the dataset to delete.")
        return gr.update()

    try:
        # Use the module-level service client; ``config`` is never imported
        # in this file, so the previous lookup always raised NameError.
        # NOTE(review): assumes ``client.delete_dataset`` returns the deleted
        # UIDs as strings, as the joined message expects — confirm.
        deleted_uids = client.delete_dataset(uid)
        gr.Info(f"Successfully deleted dataset(s): {', '.join(deleted_uids)}")
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message
        # and lets the handler return normally.
        gr.Warning(f"Error deleting dataset: {e}")
        return gr.update()
    return list_datasets()
def delete_account(confirm_password, confirm):
    """Delete the user's account after an explicit confirmation.

    Args:
        confirm_password: Password forwarded to the service for confirmation.
        confirm: Value of the confirmation checkbox.

    Returns:
        Three ``gr.update`` objects for the login, predict and account tabs
        (in that order). On success only the login tab is visible; otherwise
        the logged-in layout is kept.
    """
    global access_token

    def _stay_logged_in():
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)

    if not confirm:
        gr.Warning("Please confirm the account deletion by checking the confirmation box.")
        return _stay_logged_in()

    try:
        # Use the module-level service client; ``config`` is never imported
        # in this file, so the previous lookup always raised NameError.
        # NOTE(review): assumes ``client.delete_user_account`` exists with
        # this signature — confirm against the client library.
        client.delete_user_account(confirm_password)
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message
        # while the current tab layout is returned unchanged.
        gr.Warning(f"Error deleting account: {e}")
        return _stay_logged_in()

    # Drop the cached token so predict()/estimate_performance() ask for a
    # fresh login instead of using credentials of a deleted account.
    access_token = None
    gr.Info("Account deleted successfully.")
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
def download_all_data():
    """Download all of the user's data into a fresh temporary directory.

    Returns:
        A visible ``gr.File`` pointing at the path reported by the service
        client, or ``None`` if the download failed.
    """
    try:
        # The directory is intentionally not removed here: the returned file
        # component still needs the path to serve the download.
        temp_dir = tempfile.mkdtemp()
        # Use the module-level service client; ``config`` is never imported
        # in this file, so the previous lookup always raised NameError.
        # NOTE(review): assumes ``client.download_all_data`` returns the saved
        # file path — confirm against the client library.
        save_path = client.download_all_data(temp_dir)
        gr.Info("All data downloaded successfully.")
        return gr.File(value=str(save_path), visible=True)
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message.
        gr.Warning(f"Error downloading data: {e}")
        return None
def logout():
    """Reset the client's authorization and return to the login layout.

    Returns:
        Three ``gr.update`` objects for the login, predict and account tabs
        (in that order). If the reset fails, the logged-in layout is kept.
    """
    global access_token
    try:
        # Use the module-level service client; ``config`` is never imported
        # in this file, so the previous lookup always raised NameError.
        # NOTE(review): assumes ``client.reset_authorization`` exists —
        # confirm against the client library.
        client.reset_authorization()
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message.
        gr.Warning(f"Error during logout: {e}")
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)

    # Clear the cached token as well; otherwise the ``access_token is None``
    # gate in predict()/estimate_performance() stays open after logout.
    access_token = None
    gr.Info("Logged out successfully.")
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
from sklearn.model_selection import cross_validate
def _classification_report(scores):
    """Format per-fold accuracies into the user-facing report text."""
    return (
        f"Average Accuracy: {np.mean(scores):.2%}\n\n"
        "Individual Test Results:\n"
        f"{', '.join([f'{s:.2%}' for s in scores])}\n\n"
        "What does this mean?\n"
        "- We tested the model's performance by dividing your data into 5 parts.\n"
        "- For each part, we trained the model on 4 parts and tested it on the remaining part.\n"
        "- We repeated this process 5 times, each time using a different part for testing.\n"
        "- The average accuracy shows how often the model correctly predicted the category.\n"
        "- Some variation in these numbers is normal and expected.\n\n"
        "Note: This is an estimate of how well the model might perform on new, unseen data."
    )


def _regression_report(mse_scores, r2_scores):
    """Format per-fold MSE and R-squared values into the user-facing report text."""
    return (
        f"Average Mean Squared Error: {np.mean(mse_scores):.4f}\n"
        f"Average R-squared: {np.mean(r2_scores):.4f}\n\n"
        "Individual Test Results:\n"
        f"MSE: {', '.join([f'{s:.4f}' for s in mse_scores])}\n"
        f"R-squared: {', '.join([f'{s:.4f}' for s in r2_scores])}\n\n"
        "What does this mean?\n"
        "- We tested the model's performance by dividing your data into 5 parts.\n"
        "- For each part, we trained the model on 4 parts and tested it on the remaining part.\n"
        "- We repeated this process 5 times, each time using a different part for testing.\n"
        "- Mean Squared Error (MSE) measures the average squared difference between predictions and actual values. Lower is better.\n"
        "- R-squared measures the proportion of variance in the target variable that is predictable from the feature variables. Higher is better.\n"
        "- Some variation in these numbers is normal and expected.\n\n"
        "Note: This is an estimate of how well the model might perform on new, unseen data."
    )


def estimate_performance(df, feature_cols, target_col, task, progress=gr.Progress()):
    """Run 5-fold cross-validation on the labeled rows and describe the result.

    Args:
        df: DataFrame loaded from the uploaded CSV.
        feature_cols: Column names used as model inputs.
        target_col: Column name to predict; rows where it is null are dropped.
        task: ``"classification"`` or ``"regression"``.
        progress: Gradio progress tracker.

    Returns:
        The report text, or ``None`` when the run could not be performed
        (a message is shown to the user in that case).
    """
    if access_token is None:
        gr.Warning("Please log in or register first.")
        return None
    if df is None:
        gr.Warning("Please upload a CSV file.")
        return None

    try:
        progress(0, desc="Preparing data")
        X = df[feature_cols]
        y = df[target_col]

        # Only rows with a known label can be used for scoring.
        labeled = ~y.isnull()
        X = X[labeled]
        y = y[labeled]

        progress(0.1, desc="Initializing model")
        if task == "classification":
            if y.dtype == 'object':
                y = pd.Categorical(y).codes
            if len(np.unique(y)) < 2:
                gr.Warning("The dataset must have at least two different categories in the target column for classification.")
                return None
            model = TabPFNClassifier()
            scoring = {'accuracy': 'accuracy'}
        elif task == "regression":
            model = TabPFNRegressor()
            scoring = {'mse': 'neg_mean_squared_error', 'r2': 'r2'}
        else:
            # gr.Error only displays when raised; gr.Warning shows the message.
            gr.Warning("Invalid task type. Please choose either 'classification' or 'regression'.")
            return None

        progress(0.2, desc="Performing cross-validation")
        cv_results = cross_validate(model, X, y, cv=5, scoring=scoring, n_jobs=-1, return_train_score=False)

        progress(0.9, desc="Generating report")
        if task == "classification":
            result = _classification_report(cv_results['test_accuracy'])
        else:
            # The scorer reports negated MSE, so flip the sign back.
            result = _regression_report(-cv_results['test_mse'], cv_results['test_r2'])

        progress(1.0, desc="Completed")
        gr.Info("Performance estimation completed successfully.")
        return result
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message
        # while the handler still returns normally.
        gr.Warning(f"Error during performance estimation: {e}\n\nPlease ensure your data is formatted correctly.")
        return None
def predict(df, feature_cols, target_col, task, progress=gr.Progress()):
    """Train on the labeled rows and predict the rows with an empty target.

    Args:
        df: DataFrame loaded from the uploaded CSV.
        feature_cols: Column names used as model inputs.
        target_col: Column to predict; rows where it is null are predicted.
        task: ``"classification"`` or ``"regression"``.
        progress: Gradio progress tracker.

    Returns:
        ``(file_update, result_df)``: an update that shows the CSV download and
        the prediction table, or ``(None, None)`` when nothing was produced
        (a message is shown to the user in that case).
    """
    if access_token is None:
        gr.Warning("Please log in or register first.")
        return None, None
    if df is None:
        gr.Warning("Please upload a CSV file.")
        return None, None
    if task not in ("classification", "regression"):
        # gr.Error only displays when raised; gr.Warning shows the message.
        gr.Warning("Invalid task type. Please choose either 'classification' or 'regression'.")
        return None, None

    try:
        progress(0, desc="Preparing data")
        X = df[feature_cols]
        y = df[target_col]

        # Rows with a target are training data; rows without one are predicted.
        labeled = ~y.isnull()
        X_labeled = X[labeled]
        y_labeled = y[labeled]
        X_unlabeled = X[~labeled]

        if X_unlabeled.empty:
            gr.Warning("No rows with an empty target value were found, so there is nothing to predict.")
            return None, None

        progress(0.2, desc="Initializing and training model")
        if task == "classification":
            # String labels are encoded as integer codes for training; the
            # category names are kept so the output can show the original
            # labels instead of the codes.
            label_names = None
            if y_labeled.dtype == 'object':
                encoded = pd.Categorical(y_labeled)
                label_names = encoded.categories.to_numpy()
                y_labeled = encoded.codes

            if len(np.unique(y_labeled)) < 2:
                gr.Warning("The dataset must have at least two different categories in the target column for classification.")
                return None, None

            model = TabPFNClassifier()
            model.fit(X_labeled, y_labeled)

            progress(0.6, desc="Making predictions")
            predictions = model.predict(X_unlabeled)
            probabilities = model.predict_proba(X_unlabeled)

            progress(0.8, desc="Preparing results")
            class_labels = list(model.classes_)
            if label_names is not None:
                # NOTE(review): assumes predictions and classes_ are the
                # integer codes the model was fitted on — confirm.
                predictions = label_names[np.asarray(predictions, dtype=int)]
                class_labels = [label_names[int(code)] for code in class_labels]

            result_df = X_unlabeled.copy()
            result_df[target_col] = predictions
            for i, class_name in enumerate(class_labels):
                result_df[f"{target_col}_prob_{class_name}"] = probabilities[:, i]
        else:  # regression
            model = TabPFNRegressor()
            model.fit(X_labeled, y_labeled)

            progress(0.6, desc="Making predictions")
            predictions = model.predict(X_unlabeled)

            progress(0.8, desc="Preparing results")
            result_df = X_unlabeled.copy()
            result_df[target_col] = predictions

        # The directory is kept because the download component serves the
        # file from this path.
        output_path = Path(tempfile.mkdtemp()) / "predictions.csv"
        result_df.to_csv(output_path, index=False)

        progress(1.0, desc="Completed")
        gr.Info("Predictions completed successfully.")
        return gr.update(visible=True, value=str(output_path)), result_df
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message
        # while the handler still returns normally.
        gr.Warning(f"Error during prediction: {e}\n\nPlease ensure your data is formatted correctly and try again.")
        return None, None
def update_column_selection(file):
    """Load the uploaded CSV and populate the column pickers and preview.

    All columns except the last are preselected as features and the last
    column is preselected as the target. The preview shows the first rows
    with the last column's header suffixed by ``" (Target)"``.

    Args:
        file: The uploaded file, either an object with a ``name`` path
            attribute or a plain path string.

    Returns:
        Six values for, in order: feature picker, target picker, data
        preview, estimate button, predict button and the DataFrame state.
        If there is no file or it cannot be read, the buttons are hidden
        and the other values are ``None``.
    """
    failure = (None, None, None, gr.update(visible=False), gr.update(visible=False), None)
    try:
        if not file:
            gr.Warning("Please upload a CSV file.")
            return failure

        # Accept both an upload object exposing ``.name`` and a bare path.
        csv_path = getattr(file, "name", file)
        df = pd.read_csv(csv_path)
        columns = df.columns.tolist()

        # Mark the last column as the target column in the preview only.
        preview = df.copy()
        preview.columns = [*columns[:-1], f"{columns[-1]} (Target)"]

        return (
            gr.update(visible=True, choices=columns[:-1], value=columns[:-1]),
            gr.update(visible=True, choices=columns, value=columns[-1]),
            gr.update(visible=True, value=preview.head()),
            gr.update(visible=True),
            gr.update(visible=True),
            df,
        )
    except Exception as e:
        # gr.Error only displays when raised; gr.Warning shows the message
        # while the UI reset below is still applied.
        gr.Warning(f"Error reading CSV file: {e}")
        return failure
def create_login_tab():
    """Build the "Login/Register" tab and return its components.

    The tab starts visible when ``is_logged_in`` is false. The password
    policy text is fetched while the tab is being built.

    Returns:
        A 16-tuple: the tab, the three login widgets, the registration
        accordion, the ten registration inputs and the submit button.
    """
    with gr.Tab("Login/Register", visible=not is_logged_in) as login_tab:
        with gr.Row():
            with gr.Column():
                login_email = gr.Textbox(
                    label="Email",
                    info="Enter your registered email address",
                )
                login_password = gr.Textbox(
                    label="Password",
                    type="password",
                    info="Enter your password",
                )
                login_button = gr.Button("Login", variant="primary")

        with gr.Accordion("New User? Complete Registration", open=False) as register_accordion:
            register_email = gr.Textbox(
                label="Email*",
                info="Enter a valid email address",
            )
            register_password = gr.Textbox(
                label="Password*",
                type="password",
                info="Enter a strong password",
            )
            register_password_confirm = gr.Textbox(
                label="Confirm Password*",
                type="password",
                info="Re-enter your password",
            )
            register_first_name = gr.Textbox(
                label="First Name*",
                info="Enter your first name",
            )
            register_last_name = gr.Textbox(
                label="Last Name*",
                info="Enter your last name",
            )
            register_organization = gr.Textbox(
                label="Where do you work? (Optional)",
                info="Enter your organization name",
            )
            register_role = gr.Textbox(
                label="What is your role? (Optional)",
                info="Enter your job title or role",
            )
            register_use_case = gr.Textbox(
                label="What do you want to use TabPFN for? (Optional)",
                info="Briefly describe your intended use case",
            )
            register_contact_via_email = gr.Checkbox(
                label="Can we reach out to you via email to support you?",
                info="Check this box if you're open to receiving support emails",
            )
            gr.Markdown(f"Please refer to our terms and conditions at: {TERMS_OF_SERVICE_URL}")
            register_tos_agreed = gr.Checkbox(
                label="I have read and agree to the Terms of Service*",
                info="You must agree to the Terms of Service to register",
            )
            register_submit = gr.Button("Submit Registration", variant="primary")

            with gr.Accordion("Password Policy", open=False):
                gr.Markdown(get_password_policy())

    return (
        login_tab,
        login_email,
        login_password,
        login_button,
        register_accordion,
        register_email,
        register_password,
        register_password_confirm,
        register_first_name,
        register_last_name,
        register_organization,
        register_role,
        register_use_case,
        register_contact_via_email,
        register_tos_agreed,
        register_submit,
    )
def create_account_management_tab():
    """Build the "Account Management" tab and return its components.

    The tab starts visible only when ``is_logged_in`` is true. The dataset
    table, dataset-deletion controls, download file and token textbox are
    created hidden.

    Returns:
        A 15-tuple: the tab, the three action buttons, the dataset table,
        the three dataset-deletion controls, the download file, the
        delete-account accordion with its three controls, and finally the
        token textbox and its show/hide button.
    """
    with gr.Tab("Account Management", visible=is_logged_in) as account_tab:
        gr.Markdown("## Account Management")

        with gr.Row():
            list_datasets_button = gr.Button("List Datasets")
            download_all_data_button = gr.Button("Download All Data")
            logout_button = gr.Button("Logout")

        # Token display section
        with gr.Accordion("API Access Token", open=False):
            token_display = gr.Textbox(
                label="Your API Token",
                info="Use this token for API access",
                interactive=False,
                show_copy_button=True,
                visible=False,
            )
            show_token_button = gr.Button("Show/Hide Token")

        datasets_table = gr.Dataframe(label="Your Datasets", visible=False)
        delete_dataset_input = gr.Textbox(label="Dataset UID to delete", visible=False)
        delete_dataset_confirm = gr.Checkbox(
            label="I confirm that I want to delete this dataset",
            visible=False,
        )
        delete_dataset_button = gr.Button("Delete Dataset", visible=False)
        download_all_data_file = gr.File(label="Download All Data", visible=False)

        with gr.Accordion("Delete Account and All Data", open=False) as delete_account_accordion:
            delete_account_password = gr.Textbox(label="Confirm Password", type="password")
            delete_account_confirm = gr.Checkbox(label="I confirm that I want to delete my account")
            delete_account_button = gr.Button("Delete Account", variant="stop")

    return (
        account_tab,
        list_datasets_button,
        download_all_data_button,
        logout_button,
        datasets_table,
        delete_dataset_input,
        delete_dataset_confirm,
        delete_dataset_button,
        download_all_data_file,
        delete_account_accordion,
        delete_account_password,
        delete_account_confirm,
        delete_account_button,
        token_display,
        show_token_button,
    )
def toggle_token_display(token_visible):
    """Return an update that shows the client's token or hides the textbox.

    When ``token_visible`` is truthy the textbox is filled with
    ``client.access_token`` and shown; if reading the token raises, a warning
    is displayed and the textbox is hidden. Otherwise the textbox is hidden.
    """
    # Guard clause: hiding needs no token lookup.
    if not token_visible:
        return gr.update(visible=False)

    try:
        return gr.update(value=client.access_token, visible=True)
    except Exception as e:
        gr.Warning(f"Unable to retrieve token: {str(e)}")
        return gr.update(visible=False)
def create_predict_tab():
    """Build the "Predict" tab and return its components.

    The tab starts visible only when ``is_logged_in`` is true. The column
    pickers, preview, action buttons and outputs are created hidden.

    Returns:
        A 12-tuple: the tab, file input, column-selection row, feature
        picker, target picker, data preview, task selector, estimate button,
        predict button, performance textbox, prediction download and
        prediction table.
    """
    with gr.Tab("Predict", visible=is_logged_in) as predict_tab:
        gr.Markdown("""
## Preparing Your Data
Before uploading, please ensure your CSV file is formatted correctly:
1. The file should have a header row with column names.
2. Each subsequent row should represent one data point.
3. The last column will be treated as the target column (what you want to predict).
4. All other columns will be treated as feature columns (used for making predictions).
5. Some rows can have empty values in the target column. These are the ones the model will try to predict.
Example CSV format:
```
feature1,feature2,feature3,target
value1,value2,value3,category1
value4,value5,value6,category2
value7,value8,value9,
```
In this example, 'feature1', 'feature2', and 'feature3' are feature columns, and 'target' is the target column.
The last row has an empty target value, indicating it's a row for prediction.
Note: Make sure your target column contains only numeric values for regression or categorical labels for classification.
## How to Use This Demo
**Estimate Performance**: This option helps you understand how well the model might work with your data.
- It uses your labeled data (where you already know the correct values) to estimate the model's accuracy or error.
- This gives you an idea of how well the model might predict values for new data.
**Predict**: This option helps you predict values for new data.
- It uses your labeled data to learn patterns.
- Then, it predicts values for your unlabeled data (where you don't know the correct values).
- The result is a downloadable file with predictions for your unlabeled data.
""")
        file_input = gr.File(label="Upload your CSV file", file_types=[".csv"])

        # The upload handler reveals the two pickers themselves but never
        # updates this row, so a hidden row would keep them invisible forever.
        # The row therefore stays visible and the pickers start hidden.
        with gr.Row() as column_selection:
            with gr.Column():
                feature_cols = gr.CheckboxGroup(
                    label="Columns to use for prediction",
                    info="Select the columns you want to use as features for prediction",
                    visible=False,
                )
            with gr.Column():
                target_col = gr.Dropdown(
                    label="Column to predict",
                    info="Select the column you want to predict",
                    visible=False,
                )

        preview_data = gr.Dataframe(label="Data Preview", visible=False)
        task = gr.Radio(
            ["classification", "regression"],
            label="Task",
            info="Select the type of prediction task",
        )

        with gr.Row():
            estimate_button = gr.Button("Estimate Performance", visible=False)
            predict_button = gr.Button("Predict", visible=False)

        performance_output = gr.Textbox(label="Performance Estimation Results", visible=False)
        prediction_download = gr.File(label="Download Predictions", visible=False)
        prediction_table_output = gr.Dataframe(label="Preview of Prediction Results", visible=False)

    return (
        predict_tab,
        file_input,
        column_selection,
        feature_cols,
        target_col,
        preview_data,
        task,
        estimate_button,
        predict_button,
        performance_output,
        prediction_download,
        prediction_table_output,
    )
def create_interface():
    """Assemble the three tabs and wire every event handler.

    Returns:
        The ``gr.Blocks`` application, ready to be launched.
    """
    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        gr.Markdown("# TabPFN-V2 Demo")

        # Holds the DataFrame parsed from the uploaded CSV.
        state = gr.State(None)
        # Tracks whether the API token textbox is currently shown.
        token_visible_state = gr.State(False)

        (
            login_tab, login_email, login_password, login_button,
            register_accordion, register_email, register_password,
            register_password_confirm, register_first_name, register_last_name,
            register_organization, register_role, register_use_case,
            register_contact_via_email, register_tos_agreed, register_submit,
        ) = create_login_tab()
        (
            predict_tab, file_input, column_selection, feature_cols, target_col,
            preview_data, task, estimate_button, predict_button,
            performance_output, prediction_download, prediction_table_output,
        ) = create_predict_tab()
        (
            account_tab, list_datasets_button, download_all_data_button,
            logout_button, datasets_table, delete_dataset_input,
            delete_dataset_confirm, delete_dataset_button, download_all_data_file,
            delete_account_accordion, delete_account_password,
            delete_account_confirm, delete_account_button, token_display,
            show_token_button,
        ) = create_account_management_tab()

        # All auth-related handlers drive the same three tabs, in this order.
        tabs = [login_tab, predict_tab, account_tab]

        # --- Token display: flip the flag, then show or hide the textbox ---
        show_token_button.click(
            lambda x: not x,
            inputs=[token_visible_state],
            outputs=[token_visible_state],
        ).then(
            toggle_token_display,
            inputs=[token_visible_state],
            outputs=[token_display],
        )

        # --- Login / registration ---
        login_button.click(
            login,
            inputs=[login_email, login_password],
            outputs=tabs,
        )

        register_submit.click(
            register,
            inputs=[
                register_email, register_password, register_password_confirm,
                register_first_name, register_last_name, register_organization,
                register_role, register_use_case, register_contact_via_email,
                register_tos_agreed,
            ],
            outputs=tabs,
        ).then(
            # An accordion is collapsed through ``open``; it has no ``value``,
            # so the previous ``value=False`` update never closed it.
            lambda: gr.update(open=False),
            outputs=[register_accordion],
        )

        # --- Prediction workflow ---
        file_input.upload(
            update_column_selection,
            inputs=[file_input],
            outputs=[feature_cols, target_col, preview_data, estimate_button, predict_button, state],
        )

        estimate_button.click(
            lambda: (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)),
            outputs=[performance_output, prediction_download, prediction_table_output],
        ).then(
            estimate_performance,
            inputs=[state, feature_cols, target_col, task],
            outputs=[performance_output],
            show_progress=True,
        )

        predict_button.click(
            lambda: (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)),
            outputs=[performance_output, prediction_download, prediction_table_output],
        ).then(
            predict,
            inputs=[state, feature_cols, target_col, task],
            outputs=[prediction_download, prediction_table_output],
            show_progress=True,
        )

        # --- Account management ---
        list_datasets_button.click(
            list_datasets,
            outputs=[datasets_table],
        ).then(
            lambda: (gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)),
            outputs=[delete_dataset_input, delete_dataset_confirm, delete_dataset_button],
        )

        delete_dataset_button.click(
            delete_dataset,
            inputs=[delete_dataset_input, delete_dataset_confirm],
            outputs=[datasets_table],
        ).then(
            # Clear the UID box and the confirmation checkbox after each attempt.
            lambda: (gr.update(value=""), gr.update(value=False)),
            outputs=[delete_dataset_input, delete_dataset_confirm],
        )

        download_all_data_button.click(
            download_all_data,
            outputs=[download_all_data_file],
        )

        delete_account_button.click(
            delete_account,
            inputs=[delete_account_password, delete_account_confirm],
            outputs=tabs,
        ).then(
            # Reset the form; the accordion is collapsed through ``open``
            # because it has no ``value`` to set.
            lambda: (gr.update(value=""), gr.update(value=False), gr.update(open=False)),
            outputs=[delete_account_password, delete_account_confirm, delete_account_accordion],
        )

        logout_button.click(
            logout,
            outputs=tabs,
        )

    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    # Keep the app bound to a module-level name rather than chaining the calls.
    demo = create_interface()
    demo.launch()