"""Gradio app that classifies breast cancer as Benign or Malignant.

Two models are offered:
* a fine-tuned Vision Transformer (PyTorch) that classifies an uploaded image;
* a Keras neural network that classifies the 30 numeric tumour features,
  entered by hand, filled from an example, or read from a text/CSV file.
"""
import os

import torch
import torchvision.transforms as transforms
import torchvision.models as models
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler
import joblib

# Keep TensorFlow (used only for the small feature-based NN) off the GPU so it
# does not compete with PyTorch for CUDA memory. Setting
# CUDA_VISIBLE_DEVICES="-1" would hide the GPU from PyTorch as well, forcing
# the ViT onto the CPU even when a GPU is present.
try:
    tf.config.set_visible_devices([], "GPU")
except RuntimeError as err:
    # Raised if TensorFlow has already initialised its devices.
    print(f"Could not hide GPUs from TensorFlow: {err}")

# PyTorch (ViT) runs on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------------------------------------------------------
# Vision Transformer (PyTorch)
# ---------------------------------------------------------------------------
# The fine-tuned checkpoint supplies every weight, so the ImageNet weights are
# not downloaded.
vit_model = models.vit_b_16(weights=None)
# Binary classification head; presumably the same layout the checkpoint was
# saved with ("heads.weight" / "heads.bias") — load_state_dict is strict.
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)

vit_model_path = "vit_bc.pth"
vit_weights_loaded = os.path.exists(vit_model_path)
if vit_weights_loaded:
    vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
else:
    # Without the checkpoint the head is random, so predictions are refused.
    print(f"Warning: ViT weights '{vit_model_path}' not found; ViT classification is disabled.")
vit_model.to(device)
vit_model.eval()

# Preprocessing for the ViT: resize to its 224x224 input, normalise with the
# ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class labels, indexed by model output position.
# NOTE(review): sklearn's load_breast_cancer encodes 0=malignant, 1=benign;
# confirm this order matches the labels both models were trained with.
class_names = ["Benign", "Malignant"]

# ---------------------------------------------------------------------------
# Feature-based neural network (TensorFlow/Keras)
# ---------------------------------------------------------------------------
nn_model_path = "my_NN_BC_model.keras"
nn_model = None
if os.path.exists(nn_model_path):
    try:
        nn_model = tf.keras.models.load_model(nn_model_path)
    except Exception as e:  # keep the app usable (ViT) even if this model is broken
        print(f"Error loading NN model: {e}")
else:
    print(f"Warning: NN model '{nn_model_path}' not found; NN classification is disabled.")

# Scaler used to standardise features before the NN.
# NOTE: joblib.load unpickles the file — only ever load a trusted artefact.
scaler_path = "nn_bc_scaler.pkl"
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
if scaler is None:
    print(f"Warning: scaler '{scaler_path}' not found; NN features will be used unscaled.")

# Feature names, in the column order the NN expects.
feature_names = [
    "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
    "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry",
    "Mean Fractal Dimension",
    "SE Radius", "SE Texture", "SE Perimeter", "SE Area", "SE Smoothness",
    "SE Compactness", "SE Concavity", "SE Concave Points", "SE Symmetry",
    "SE Fractal Dimension",
    "Worst Radius", "Worst Texture", "Worst Perimeter", "Worst Area", "Worst Smoothness",
    "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry",
    "Worst Fractal Dimension",
]

# Example inputs (30 values each, in feature_names order)
benign_example = [
    9.504, 12.44, 60.34, 273.9, 0.1024, 0.06492, 0.02956, 0.02076, 0.1815, 0.06905,
    0.2773, 0.9768, 1.909, 15.7, 0.009606, 0.01432, 0.01985, 0.01421, 0.02027, 0.002968,
    10.23, 15.66, 65.13, 314.9, 0.1324, 0.1148, 0.08867, 0.06227, 0.245, 0.07773,
]
malignant_example = [
    11.42, 20.38, 77.58, 386.1, 0.1425, 0.2839, 0.2414, 0.1052, 0.2597, 0.09744,
    0.4956, 1.156, 3.445, 27.23, 0.00911, 0.07458, 0.05661, 0.01867, 0.05963, 0.009208,
    14.91, 26.5, 98.87, 567.7, 0.2098, 0.8663, 0.6869, 0.2575, 0.6638, 0.173,
]


def _classify_image(image):
    """Run the ViT on a PIL image and return the predicted class name or an error message."""
    if image is None:
        return "❌ Please upload an image for ViT classification."
    if not vit_weights_loaded:
        return f"❌ ViT weights ('{vit_model_path}') are not available."
    input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = vit_model(input_tensor)
    return class_names[int(torch.argmax(logits, dim=1).item())]


def _classify_features(features):
    """Run the NN on the 30 feature values and return the predicted class name or an error message."""
    if len(features) != len(feature_names) or any(f is None for f in features):
        return "❌ Please enter all 30 numerical features."
    if nn_model is None:
        return f"❌ Neural network model ('{nn_model_path}') is not available."

    input_data = np.asarray(features, dtype=np.float64).reshape(1, -1)
    if scaler is not None:
        input_data = scaler.transform(input_data)

    scores = np.asarray(nn_model.predict(input_data, verbose=0)).reshape(-1)
    if scores.size == 1:
        # Single sigmoid unit: argmax would always pick 0. Treat the score as
        # the probability of class index 1 — TODO confirm against training.
        predicted_class = int(scores[0] >= 0.5)
    else:
        predicted_class = int(np.argmax(scores))
    return class_names[predicted_class]


def classify(model_choice, image=None, *features):
    """Classify a sample as Benign or Malignant.

    Args:
        model_choice: "ViT" to classify ``image``, or "Neural Network" to
            classify ``features``.
        image: PIL image from the Gradio image input (ViT only).
        *features: The 30 feature values in ``feature_names`` order (NN only).

    Returns:
        The predicted class name, or a message starting with "❌" explaining
        why no prediction was made.
    """
    if model_choice == "ViT":
        return _classify_image(image)
    if model_choice == "Neural Network":
        return _classify_features(features)
    return "❌ Unknown model selection."


def _read_upload_text(upload):
    """Return the text of a Gradio file upload.

    Accepts raw bytes (what ``gr.File(type="binary")`` delivers), a filesystem
    path, a temp-file wrapper exposing ``.name``, or any object with ``.read()``.
    ``utf-8-sig`` strips a BOM, which spreadsheet exports often prepend.
    """
    if isinstance(upload, (bytes, bytearray)):
        return bytes(upload).decode("utf-8-sig")
    if isinstance(upload, (str, os.PathLike)):
        with open(upload, encoding="utf-8-sig") as fh:
            return fh.read()
    path = getattr(upload, "name", None)
    if path is not None:
        with open(path, encoding="utf-8-sig") as fh:
            return fh.read()
    data = upload.read()
    return data.decode("utf-8-sig") if isinstance(data, bytes) else data


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Breast Cancer Classification Model")
    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")

    with gr.Row():
        model_selector = gr.Radio(["ViT", "Neural Network"], label="🔬 Choose Model", value="ViT")
        image_input = gr.Image(type="pil", label="📷 Upload Image (for ViT)", visible=True)

    # All NN inputs share one column so a single update shows/hides them.
    # Hidden initially because "ViT" is the default selection.
    with gr.Column(visible=False) as nn_inputs:
        file_input = gr.File(label="📂 Upload Feature File (for NN)", type="binary")

        # Feature inputs laid out in rows of 3 columns.
        feature_inputs = []
        for start in range(0, len(feature_names), 3):
            with gr.Row():
                for label in feature_names[start:start + 3]:
                    feature_inputs.append(gr.Number(label=label))

        with gr.Row():
            example_btn_1 = gr.Button("🔴 Malignant Example")
            example_btn_2 = gr.Button("🔵 Benign Example")

    classify_button = gr.Button("🚀 Classify")
    output_text = gr.Textbox(label="🔍 Model Prediction", interactive=False)

    def fill_example(example):
        """Return updates that copy the 30 values of *example* into the feature fields."""
        return dict(zip(feature_inputs, example))

    def extract_features_from_file(file):
        """Fill the feature fields from an uploaded text/CSV file.

        The file must hold exactly 30 numbers separated by commas and/or
        whitespace, in ``feature_names`` order. Returns a dict of component
        updates; messages go to the prediction box because Number fields
        cannot display text.
        """
        if file is None:
            # Fired when the upload is cleared; there is nothing to fill.
            return {output_text: gr.update()}
        try:
            content = _read_upload_text(file).strip()
            values = [float(tok) for tok in content.replace(",", " ").split()]
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # Unreadable / non-UTF-8 file, or a non-numeric token.
            return {output_text: f"❌ Error processing file: {e}"}
        if len(values) != len(feature_inputs):
            return {output_text: "❌ The file must contain exactly 30 numerical values."}
        updates = dict(zip(feature_inputs, values))
        updates[output_text] = "✅ Loaded 30 feature values from file."
        return updates

    def toggle_inputs(choice):
        """Show the image input for the ViT and the feature inputs for the NN."""
        return (
            gr.update(visible=choice == "ViT"),
            gr.update(visible=choice == "Neural Network"),
        )

    model_selector.change(toggle_inputs, model_selector, [image_input, nn_inputs])
    file_input.change(extract_features_from_file, file_input, [*feature_inputs, output_text])
    # Each button fills the example that matches its label.
    example_btn_1.click(lambda: fill_example(malignant_example), None, feature_inputs)
    example_btn_2.click(lambda: fill_example(benign_example), None, feature_inputs)
    classify_button.click(classify, [model_selector, image_input, *feature_inputs], output_text)

if __name__ == "__main__":
    demo.launch()