from fasthtml.common import *
from fastai.vision.all import *
import asyncio
import os
import re
import time
import uuid
from datetime import datetime
from pathlib import Path
import urllib.request
from io import BytesIO

# Directory where uploaded (and example) images are written and served from.
UPLOAD_DIR = 'uploads'
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Characters allowed in a stored upload name; anything else becomes "_".
# Path separators are not in the set, so a client-supplied name can never
# escape UPLOAD_DIR.
_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]")

# Only these extensions are served back by /image/{filename}. Restricting to
# raster image types avoids serving user-supplied HTML/SVG from our origin.
_SERVABLE_IMAGE_EXTS = frozenset(
    {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tif", ".tiff"}
)

# Labels reported when no model could be loaded (demo/test mode).
_DEMO_LABELS = ('class1', 'class2', 'class3')

# Page stylesheet.
_CSS = """
:root {
    --primary-color: #3498db;
    --secondary-color: #2c3e50;
    --background-color: #f9f9f9;
    --error-color: #e74c3c;
    --shadow-color: rgba(0, 0, 0, 0.1);
    --border-color: #ddd;
}
body {
    font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
    line-height: 1.6;
    color: #333;
    max-width: 800px;
    margin: 0 auto;
    padding: 20px;
    background-color: #fff;
}
h1 { color: var(--secondary-color); margin-bottom: 1rem; font-weight: 600; }
h2 { color: var(--primary-color); margin-top: 1.5rem; font-weight: 500; }
.instructions { margin-bottom: 20px; }
.upload-controls {
    display: flex;
    gap: 10px;
    margin-bottom: 30px;
    align-items: center;
    flex-wrap: wrap;
}
button {
    background-color: var(--primary-color);
    color: white;
    border: none;
    padding: 10px 15px;
    border-radius: 4px;
    cursor: pointer;
    transition: background-color 0.3s;
    font-weight: 500;
}
button:hover { background-color: #2980b9; }
input[type="file"] {
    padding: 10px;
    border: 1px solid var(--border-color);
    border-radius: 4px;
    flex-grow: 1;
}
#upload-form {
    margin-bottom: 40px;
    padding: 20px;
    border-radius: 8px;
    background-color: var(--background-color);
    box-shadow: 0 2px 10px var(--shadow-color);
}
.result-container { margin-top: 20px; }
.prediction-results {
    margin-top: 20px;
    padding: 20px;
    border: 1px solid var(--border-color);
    border-radius: 8px;
    background-color: var(--background-color);
    box-shadow: 0 2px 8px var(--shadow-color);
}
.result-image {
    max-width: 100%;
    height: auto;
    border-radius: 8px;
    box-shadow: 0 2px 5px var(--shadow-color);
    margin-bottom: 20px;
    display: block;
}
.prediction-list { margin-top: 15px; }
.prediction-item {
    padding: 12px 15px;
    margin-bottom: 10px;
    background-color: white;
    border-radius: 6px;
    box-shadow: 0 1px 3px var(--shadow-color);
}
.label-text {
    margin-bottom: 8px;
    font-weight: 500;
    display: flex;
    justify-content: space-between;
}
.examples-section {
    margin-top: 30px;
    padding-top: 20px;
    border-top: 1px solid var(--border-color);
}
.htmx-indicator {
    display: none;
    padding: 15px;
    background-color: #e8f4fc;
    border-radius: 6px;
    text-align: center;
    margin: 15px 0;
    box-shadow: 0 1px 3px var(--shadow-color);
}
/* htmx puts the htmx-request class on the hx-indicator target itself, so the
   compound selector is the one that actually reveals #loading. */
.htmx-request .htmx-indicator,
.htmx-request.htmx-indicator { display: block; }
.progress-bar {
    height: 10px;
    background-color: #f0f0f0;
    border-radius: 5px;
    margin: 5px 0;
    overflow: hidden;
}
.progress-fill {
    height: 100%;
    background-color: var(--primary-color);
    width: 0;
    transition: width 0.5s ease;
}
.error-message {
    color: var(--error-color);
    padding: 15px;
    border: 1px solid var(--error-color);
    border-radius: 5px;
    background-color: #fde9e7;
}
a { color: var(--primary-color); text-decoration: none; font-weight: 500; }
a:hover { text-decoration: underline; }
/* Responsive styling */
@media (max-width: 600px) {
    .upload-controls { flex-direction: column; align-items: stretch; }
    button { width: 100%; }
}
.model-info {
    font-size: 0.9rem;
    color: #666;
    margin-top: 40px;
    padding-top: 20px;
    border-top: 1px solid var(--border-color);
}
"""


def load_model():
    """Load the fastai learner from ``levit.pkl``.

    Returns:
        ``(learn, labels)`` where ``labels`` is ``learn.dls.vocab``. If the
        file is missing or loading raises, returns ``(None, demo_labels)`` so
        the app can still start in demo mode.
    """
    model_path = 'levit.pkl'
    try:
        if not os.path.exists(model_path):
            print("Model not found. This is just for testing purposes.")
            # In a real deployment, you'd want to handle this more gracefully
            return None, list(_DEMO_LABELS)
        # NOTE(review): load_learner deserializes a pickled Learner — only
        # ever point it at a trusted model artifact.
        learn = load_learner(model_path)
        labels = learn.dls.vocab
        print(f"Model loaded successfully with labels: {labels}")
        return learn, labels
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fallback for testing
        return None, list(_DEMO_LABELS)


# Load the model at startup
learn, labels = load_model()

# Create a FastHTML app
app, rt = fast_app()


def predict(img_bytes):
    """Classify raw image bytes.

    Returns:
        A ``{label: probability}`` dict. With no model loaded, the scores are
        random values normalized to sum to 1 (demo mode). On any failure the
        error is printed and ``{"Error": 1.0}`` is returned.
    """
    try:
        if learn is None:
            import random
            mock_results = {label: random.random() for label in labels}
            total = sum(mock_results.values())
            ranked = sorted(mock_results.items(), key=lambda x: x[1], reverse=True)
            return {k: v / total for k, v in ranked}

        img = PILImage.create(BytesIO(img_bytes))
        img = img.resize((512, 512))
        _pred, _pred_idx, probs = learn.predict(img)
        return {label: float(p) for label, p in zip(labels, probs)}
    except Exception as e:
        print(f"Prediction error: {e}")
        return {"Error": 1.0}


def _error_box(title, message):
    """Return the styled error fragment shared by all routes."""
    return Div(H2(title), P(message), cls="error-message")


def _safe_upload_name(original_name):
    """Build a unique filename, without any directory part, for an upload.

    The client-supplied name is reduced to its final path component. Any
    character outside ``[A-Za-z0-9._-]`` becomes ``_``, and leading dots are
    stripped, so ``..`` and hidden names cannot be produced. The result is
    prefixed with a timestamp and a random token, so two uploads with the same
    name in the same second cannot overwrite each other.
    """
    base = os.path.basename((original_name or "").replace("\\", "/"))
    base = _UNSAFE_FILENAME_CHARS.sub("_", base).lstrip(".") or "upload"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{timestamp}_{uuid.uuid4().hex[:8]}_{base}"


def _prediction_item(label, prob):
    """Render one label row: name, truncated percentage and a progress bar."""
    percentage = int(prob * 100)
    return Div(
        Div(Span(f"{label}"), Span(f"{percentage}%"), cls="label-text"),
        Div(
            Div(cls="progress-fill", style=f"width: {percentage}%;"),
            cls="progress-bar",
        ),
        cls="prediction-item",
    )


def _render_prediction_results(results, img_src, img_alt, note=None, top_k=3):
    """Render the ``top_k`` highest-scoring predictions under the image.

    Args:
        results: ``{label: probability}`` mapping, e.g. from ``predict``.
        img_src: URL of the classified image.
        img_alt: Alt text for the image.
        note: Optional italic footnote shown under the predictions.
        top_k: Number of predictions to show.
    """
    top = sorted(results.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    children = [
        H2("Prediction Results:"),
        Img(src=img_src, cls="result-image", alt=img_alt),
        Div(*(_prediction_item(label, prob) for label, prob in top),
            cls="prediction-list"),
    ]
    if note:
        children.append(P(note, style="font-style: italic; color: #666;"))
    return Div(*children, cls="prediction-results")


# Main page route
@rt("/")
def get():
    """Render the upload page: form, loading indicator, results slot, example link, model info."""
    upload_form = Form(
        # Titled() already renders the page H1, so only the instructions go here.
        Div(
            P("Upload an image to classify it using a pre-trained model."),
            cls="instructions",
        ),
        Div(
            Input(type="file", name="image", accept="image/*", required=True),
            Button("Classify", type="submit"),
            cls="upload-controls",
        ),
        hx_post="/predict",
        hx_target="#result",
        hx_swap="innerHTML",
        hx_encoding="multipart/form-data",
        # The form issues the request, so the indicator is declared here
        # (on the <input> it was never used for the form's request).
        hx_indicator="#loading",
        id="upload-form",
    )

    loading = Div(P("Processing your image..."), id="loading", cls="htmx-indicator")

    # Container the prediction fragments are swapped into.
    result_container = Div(id="result", cls="result-container")

    examples = Div(
        H2("Or try an example:"),
        A("Example Image", href="#", hx_get="/predict_example",
          hx_target="#result", hx_indicator="#loading"),
        cls="examples-section",
    )

    model_info = Div(
        P(f"Model: {'Model loaded successfully' if learn is not None else 'Demo mode - no model loaded'}"),
        # str() each label so non-string vocab entries don't break the join.
        P(f"Classes: {', '.join(map(str, labels))}"),
        cls="model-info",
    )

    return Titled("FastAI Image Classifier", upload_form, loading,
                  result_container, examples, model_info, Style(_CSS))


# Prediction route for uploaded images
@rt("/predict")
async def post(image: UploadFile):
    """Save the uploaded image under a sanitized name, classify it, and return the results fragment."""
    try:
        image_bytes = await image.read()
        if not image_bytes:
            return _error_box("Error", "The uploaded file is empty.")

        safe_filename = _safe_upload_name(image.filename)
        with open(os.path.join(UPLOAD_DIR, safe_filename), "wb") as f:
            f.write(image_bytes)

        # Short pause so the loading indicator is visible. asyncio.sleep does
        # not block the event loop, unlike time.sleep in an async handler.
        await asyncio.sleep(0.5)

        results = predict(image_bytes)
        return _render_prediction_results(
            results, f"/image/{safe_filename}", "Uploaded image")
    except Exception as e:
        return _error_box("Error", f"An error occurred during prediction: {str(e)}")


# Route to serve saved images
@rt("/image/{filename}")
def get(filename: str):
    """Serve a previously saved image from UPLOAD_DIR.

    Only plain basenames with a raster-image extension that name an existing
    regular file are served; anything else gets the "Image not found." fragment.
    """
    name = os.path.basename(filename)
    file_path = os.path.join(UPLOAD_DIR, name)
    servable = (
        name == filename
        and os.path.splitext(name)[1].lower() in _SERVABLE_IMAGE_EXTS
        and os.path.isfile(file_path)
    )
    if servable:
        return FileResponse(file_path)
    return _error_box("Error", "Image not found.")


# Route for example image
@rt("/predict_example")
def get():
    """Classify the bundled ``image.jpg`` and return the results fragment."""
    example_path = "image.jpg"
    example_name = "example.jpg"
    try:
        if not os.path.exists(example_path):
            return _error_box(
                "Example Not Found",
                "The example image 'image.jpg' was not found. Please try uploading your own image.",
            )

        with open(example_path, "rb") as f:
            image_bytes = f.read()
        # Copy into the uploads dir so /image/{filename} can serve it.
        with open(os.path.join(UPLOAD_DIR, example_name), "wb") as f:
            f.write(image_bytes)

        # Add a small delay to make the loading indicator visible
        time.sleep(0.5)

        results = predict(image_bytes)
        return _render_prediction_results(
            results,
            f"/image/{example_name}",
            "Example image",
            note="This is a demonstration using the provided example image.",
        )
    except Exception as e:
        return _error_box("Error", f"An error occurred with the example: {str(e)}")


# Health check endpoint (useful for Docker/Kubernetes)
@rt("/health")
def get():
    """Report liveness and whether a real model (not demo mode) is loaded."""
    return {"status": "ok", "model_loaded": learn is not None}


# Run the app
if __name__ == "__main__":
    # Use environment variables if available (common in Docker)
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 8000))
    print(f"Starting FastHTML server on {host}:{port}")
    # NOTE(review): fasthtml's serve() takes `app` as the *name* of the app
    # variable (default "app") and builds uvicorn's "module:app" import string
    # from it, so the object must not be passed. Confirm against the installed
    # fasthtml version.
    serve(host=host, port=port)