import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, CvtForImageClassification

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessor published alongside the microsoft/cvt-13 checkpoint.
extractor = AutoFeatureExtractor.from_pretrained("microsoft/cvt-13")

# Class labels, in the same order as the model's output logits.
class_names = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor",
]


def load_model():
    """Load the fine-tuned CvT classifier from ``models/cvt_model.pth``.

    The checkpoint may be either a pickled ``CvtForImageClassification``
    instance or a plain ``state_dict``. For a ``state_dict`` the weights are
    loaded into a ``microsoft/cvt-13`` model whose classification head has
    been resized to ``len(class_names)`` outputs.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model_dir = "models"
    model_file_pytorch = "cvt_model.pth"

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. On newer torch releases the default ``weights_only=True``
    # rejects a pickled full model, so the first branch below may need
    # ``weights_only=False``; confirm which checkpoint format is shipped.
    checkpoint = torch.load(
        os.path.join(model_dir, model_file_pytorch), map_location=device
    )

    if isinstance(checkpoint, CvtForImageClassification):
        # The checkpoint already is a complete model.
        model_pytorch = checkpoint
    else:
        # The stock microsoft/cvt-13 head is the ImageNet-1k classifier; a
        # state_dict fine-tuned on class_names has a smaller head, so build the
        # model with a matching head or load_state_dict fails on shape mismatch.
        model_pytorch = CvtForImageClassification.from_pretrained(
            "microsoft/cvt-13",
            num_labels=len(class_names),
            ignore_mismatched_sizes=True,
        )
        model_pytorch.load_state_dict(checkpoint)

    model_pytorch.to(device)
    model_pytorch.eval()
    return model_pytorch


# Load the model once at startup so every request reuses it.
model_pytorch = load_model()


def predict_image(image):
    """Classify an uploaded brain scan.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user has not uploaded anything.

    Returns:
        The predicted label, one of ``class_names``.

    Raises:
        gr.Error: If no image was provided (Gradio shows it to the user).
    """
    if image is None:
        raise gr.Error("Por favor, sube una imagen antes de clasificar.")

    # The preprocessor expects 3-channel input; uploads may be grayscale/RGBA.
    inputs = extractor(images=image.convert("RGB"), return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model_pytorch(**inputs).logits

    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_index = probabilities.argmax(dim=-1).item()
    return class_names[predicted_index]


def clear_inputs():
    """Reset the image, the model dropdown and the prediction textbox."""
    return None, None, None


# Dark theme for the Gradio interface.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="indigo",
).set(
    background_fill_primary="#121212",  # dark background
    background_fill_secondary="#1e1e1e",
    block_background_fill="#1e1e1e",  # almost black
    block_border_color="#333",
    # Full 6-digit white; the previous '#fffff' was an invalid CSS color.
    block_label_text_color="#ffffff",
    block_label_text_color_dark="#ffffff",
    block_title_text_color_dark="#ffffff",
    button_primary_background_fill="#4f46e5",  # violet
    button_primary_background_fill_hover="#2563eb",  # light blue
    button_secondary_background_fill="#4f46e5",
    button_secondary_background_fill_hover="#2563eb",
    input_background_fill="#333",  # dark grey
    input_border_color="#444",  # intermediate grey
    block_label_background_fill="#4f46e5",
    block_label_background_fill_dark="#4f46e5",
    slider_color="#2563eb",
    slider_color_dark="#2563eb",
    button_primary_text_color="#ffffff",
    button_secondary_text_color="#ffffff",
    button_secondary_background_fill_hover_dark="#4f46e5",
    button_cancel_background_fill_hover="#444",
    button_cancel_background_fill_hover_dark="#444",
)

# Custom styling: background image and white text across components.
CUSTOM_CSS = """
body, gradio-app {
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
    background-size: cover;
    color: white;
}
.gradio-container {
    background-color: transparent;
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
    background-size: cover !important;
    color: white;
}
.gradio-container .gr-dropdown-container select::after {
    content: '▼';
    color: white;
    padding-left: 5px;
}
.gradio-container .gr-dropdown-container select:focus {
    outline: none;
    border-color: #4f46e5;
}
.gradio-container select { color: white; }
input, select, span, button, svg, .secondary-wrap { color: white; }
h1 {
    color: white;
    font-size: 4em;
    margin: 20px auto;
}
.gradio-container h1 {
    font-size: 5em;
    color: white;
    text-align: center;
    text-shadow: 2px 2px 0px #8A2BE2, 4px 4px 0px #00000033;
    text-transform: uppercase;
    margin: 18px auto;
}
.gradio-container input { color: white; }
.gradio-container .output { color: white; }
.required-dropdown li { color: white; }
.button-style { background-color: #4f46e5; color: white; }
.button-style:hover { background-color: #2563eb; color: white; }
.gradio-container .contain textarea { color: white; font-weight: 600; font-size: 1.5rem; }
.contain textarea { color: white; font-weight: 600; font-size: 1.5rem; }
textarea { color: white; font-weight: 600; font-size: 1.5rem; background-color: black; }
textarea .scroll-hide { color: white; }
.scroll-hide svelte-1f354aw { color: white; }
"""

with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
    gr.Markdown("# Brain Tumor Classification 🧠")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Sube la imagen")
            # NOTE(review): the selected model is never passed to predict_image;
            # the single checkpoint from load_model() is always used.
            model_input = gr.Dropdown(
                choices=["model_1", "model_2"],
                label="Selecciona un modelo",
                elem_classes=["required-dropdown"],
            )
            classify_btn = gr.Button("Clasificar", elem_classes=["button-style"])
            clear_btn = gr.Button("Limpiar")
        with gr.Column():
            prediction_output = gr.Textbox(label="Predicción")

    classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
    clear_btn.click(
        clear_inputs,
        inputs=[],
        outputs=[image_input, model_input, prediction_output],
    )

demo.launch()