# Committed by: noequalindi
# Commit 88282d2: add gradio app and model
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoImageProcessor, CvtForImageClassification
# Compute device chosen once at import time: CUDA when PyTorch can see a GPU,
# otherwise the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
# Hugging Face preprocessor for the microsoft/cvt-13 checkpoint.
# AutoImageProcessor is the current API for vision models; AutoFeatureExtractor
# resolves to a deprecated ConvNext feature-extractor class for this checkpoint
# (it only adds a deprecation warning on top of the same image processor).
extractor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
# Human-readable labels, listed in the same order as the model's output logits:
# class_names[i] is the label for logit index i.
class_names = [
    "glioma_tumor",
    "meningioma_tumor",
    "no_tumor",
    "pituitary_tumor",
]
# Load the classifier once; called at import time below.
def load_model(model_dir="models", model_file="cvt_model.pth"):
    """Load the fine-tuned CvT classifier and prepare it for inference.

    The checkpoint may be either a complete pickled
    ``CvtForImageClassification`` or just its ``state_dict``; both are handled.

    Args:
        model_dir: Directory holding the checkpoint (resolved against the CWD).
        model_file: Checkpoint file name inside ``model_dir``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    checkpoint_path = os.path.join(model_dir, model_file)
    # weights_only=False is needed to unpickle a full model object (PyTorch
    # >= 2.6 defaults to weights_only=True, which rejects it). Unpickling can
    # execute arbitrary code: only point this at trusted, locally shipped files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if isinstance(checkpoint, CvtForImageClassification):
        # The checkpoint already is a complete model.
        model = checkpoint
    else:
        # The pretrained microsoft/cvt-13 head has 1000 ImageNet outputs, but
        # predict_image indexes class_names with the argmax, so the fine-tuned
        # head must have exactly len(class_names) outputs. Rebuild the head at
        # that size so load_state_dict does not fail with a shape mismatch.
        model = CvtForImageClassification.from_pretrained(
            "microsoft/cvt-13",
            num_labels=len(class_names),
            ignore_mismatched_sizes=True,
        )
        model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    return model
# Single shared model instance, loaded when the app module is imported and
# reused by every prediction request.
model_pytorch = load_model()
# Click handler for the "Clasificar" button.
def predict_image(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The uploaded image. The Gradio UI passes a PIL image, or None
            when nothing has been uploaded.

    Returns:
        The entry of ``class_names`` with the highest model score.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    # Clicking "Clasificar" before uploading sends None, which would otherwise
    # crash inside the preprocessor with an unhelpful traceback.
    if image is None:
        raise gr.Error("Por favor, sube una imagen antes de clasificar.")
    # CvT takes 3-channel input; normalize grayscale/RGBA/palette PIL images so
    # uploads in other modes do not break preprocessing.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model_pytorch(**inputs).logits
    # Softmax is monotonic, so the argmax of the logits equals the argmax of
    # the probabilities; no need to normalize just to pick the top class.
    predicted_index = logits.argmax(dim=-1).item()
    return class_names[predicted_index]
# Click handler for the "Limpiar" button.
def clear_inputs():
    """Reset the UI: one None per output component.

    The outputs are the image input, the model dropdown and the prediction textbox.
    """
    return (None,) * 3
# Dark Gradio theme with indigo/blue accents, passed to gr.Blocks below.
# Hex colors must have 3 or 6 digits: a 5-digit value such as '#fffff' is
# invalid CSS and the browser silently drops it, so white is spelled '#ffffff'.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="indigo",
).set(
    background_fill_primary='#121212',  # near-black page background
    background_fill_secondary='#1e1e1e',
    block_background_fill='#1e1e1e',  # near-black panels
    block_border_color='#333',
    block_label_text_color='#ffffff',
    block_label_text_color_dark='#ffffff',
    block_title_text_color_dark='#ffffff',
    button_primary_background_fill='#4f46e5',  # indigo
    button_primary_background_fill_hover='#2563eb',  # blue
    button_secondary_background_fill='#4f46e5',
    button_secondary_background_fill_hover='#2563eb',
    input_background_fill='#333',  # dark grey
    input_border_color='#444',  # medium-dark grey
    block_label_background_fill='#4f46e5',
    block_label_background_fill_dark='#4f46e5',
    slider_color='#2563eb',
    slider_color_dark='#2563eb',
    button_primary_text_color='#ffffff',
    button_secondary_text_color='#ffffff',
    button_secondary_background_fill_hover_dark='#4f46e5',
    button_cancel_background_fill_hover='#444',
    button_cancel_background_fill_hover_dark='#444',
)
# Custom CSS layered on top of the theme: background image, white text, and a
# large centered title.
CUSTOM_CSS = """
body, gradio-app {
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1');
    background-size: cover;
    color: white;
}
.gradio-container {
    background-color: transparent;
    background-image: url('https://b2928487.smushcdn.com/2928487/wp-content/uploads/2022/04/Brain-inspiredAI-2048x1365.jpeg?lossy=1&strip=1&webp=1') !important;
    background-size: cover !important;
    color: white;
}
.gradio-container .gr-dropdown-container select::after {
    content: '▼';
    color: white;
    padding-left: 5px;
}
.gradio-container .gr-dropdown-container select:focus {
    outline: none;
    border-color: #4f46e5;
}
.gradio-container select {
    color: white;
}
input, select, span, button, svg, .secondary-wrap {
    color: white;
}
h1 {
    color: white;
    font-size: 4em;
    margin: 20px auto;
}
.gradio-container h1 {
    font-size: 5em;
    color: white;
    text-align: center;
    text-shadow: 2px 2px 0px #8A2BE2,
                 4px 4px 0px #00000033;
    text-transform: uppercase;
    margin: 18px auto;
}
.gradio-container input {
    color: white;
}
.gradio-container .output {
    color: white;
}
.required-dropdown li {
    color: white;
}
.button-style {
    background-color: #4f46e5;
    color: white;
}
.button-style:hover {
    background-color: #2563eb;
    color: white;
}
.gradio-container .contain textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
}
.contain textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
}
textarea {
    color: white;
    font-weight: 600;
    font-size: 1.5rem;
    background-color: black;
}
"""

with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
    gr.Markdown("# Brain Tumor Classification 🧠")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Sube la imagen")
            # NOTE(review): the selected value is not passed to predict_image,
            # so this dropdown currently has no effect on the prediction.
            model_input = gr.Dropdown(choices=["model_1", "model_2"], label="Selecciona un modelo", elem_classes=['required-dropdown'])
            classify_btn = gr.Button("Clasificar", elem_classes=['button-style'])
            clear_btn = gr.Button("Limpiar")
        with gr.Column():
            prediction_output = gr.Textbox(label="Predicción")
    classify_btn.click(predict_image, inputs=[image_input], outputs=prediction_output)
    clear_btn.click(clear_inputs, inputs=[], outputs=[image_input, model_input, prediction_output])

# Launch only when run as a script (e.g. `python app.py` on Spaces), so that
# importing the module (tests, `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()