from fasthtml.common import *
from fastai.vision.all import *
import os
import time
from pathlib import Path
import urllib.request
from io import BytesIO
# Make sure the upload directory exists before any route writes into it
Path('uploads').mkdir(parents=True, exist_ok=True)
# Function to load model - with fallback for testing
def load_model(model_path='levit.pkl'):
    """Load the exported learner, falling back to placeholder labels.

    Args:
        model_path: Path of the exported learner file. Defaults to
            ``'levit.pkl'``, the path this function always used before the
            parameter existed.

    Returns:
        A ``(learn, labels)`` tuple. On success ``learn`` is the object
        returned by ``load_learner`` and ``labels`` is ``learn.dls.vocab``.
        If the file is missing or loading raises, ``learn`` is ``None`` and
        ``labels`` is ``['class1', 'class2', 'class3']``.
    """
    # Built per call so callers never share (and mutate) one fallback list.
    fallback = (None, ['class1', 'class2', 'class3'])

    if not os.path.exists(model_path):
        print("Model not found. This is just for testing purposes.")
        # In a real deployment, you'd want to handle this more gracefully
        return fallback

    try:
        # NOTE(security): exported learners are pickle-based, so only load
        # model files from a trusted source.
        learn = load_learner(model_path)
        labels = learn.dls.vocab
    except Exception as e:
        # Deliberate best-effort: any load failure drops to the test fallback.
        print(f"Error loading model: {e}")
        return fallback

    print(f"Model loaded successfully with labels: {labels}")
    return learn, labels
# Load the model once at import time. `learn` is None (with placeholder
# `labels`) when load_model() falls back; predict() checks for that case.
learn, labels = load_model()
# Create the FastHTML app and `rt`, the decorator used to register routes below
app, rt = fast_app()
# Define the prediction function
def predict(img_bytes):
    """Return a ``{label: probability}`` dict for the given image bytes.

    Args:
        img_bytes: Raw bytes of the image to classify.

    Returns:
        With no model loaded (``learn is None``): random scores for every
        entry of ``labels``, normalised to sum to 1 and ordered from highest
        to lowest. With a model: one float probability per label. If
        anything raises, ``{"Error": 1.0}``.
    """
    try:
        # If no model is loaded, return mock predictions for testing
        if learn is None:
            import random
            mock_scores = {label: random.random() for label in labels}
            total = sum(mock_scores.values())
            ranked = sorted(mock_scores.items(), key=lambda item: item[1], reverse=True)
            # Divide by the total so the mock values sum to 1
            return {label: score / total for label, score in ranked}

        # Real prediction with the model
        img = PILImage.create(BytesIO(img_bytes))
        img = img.resize((512, 512))
        # learn.predict returns three values; only the probabilities are used
        _, _, probs = learn.predict(img)
        return {label: float(prob) for label, prob in zip(labels, probs)}
    except Exception as e:
        # Best-effort by design: callers render whatever dict comes back
        print(f"Prediction error: {e}")
        return {"Error": 1.0}
# Main page route
@rt("/")
def get():
    """Render the landing page: upload form, loading indicator, result
    container, example link, model summary and the page stylesheet.

    Returns:
        The ``Titled`` page built from those components.
    """
    # Create a form for image upload. hx_indicator lives on the form because
    # the form is the element that issues the request; htmx only looks for
    # the attribute on the requesting element and its ancestors.
    upload_form = Form(
        Div(
            H1("FastAI Image Classifier"),
            P("Upload an image to classify it using a pre-trained model."),
            cls="instructions"
        ),
        Div(
            Input(type="file", name="image", accept="image/*", required=True),
            Button("Classify", type="submit"),
            cls="upload-controls"
        ),
        hx_post="/predict",
        hx_target="#result",
        hx_swap="innerHTML",
        hx_encoding="multipart/form-data",
        hx_indicator="#loading",
        id="upload-form"
    )
    # Add loading indicator
    loading = Div(
        P("Processing your image..."),
        id="loading",
        cls="htmx-indicator"
    )
    # Container for results
    result_container = Div(id="result", cls="result-container")
    # Example section
    examples = Div(
        H2("Or try an example:"),
        A("Example Image", href="#",
          hx_get="/predict_example",
          hx_target="#result",
          hx_indicator="#loading"),
        cls="examples-section"
    )
    # CSS styles. The indicator is hidden with `display: none`, so it needs
    # the `.htmx-request.htmx-indicator` selector as well: with hx-indicator,
    # htmx puts the `htmx-request` class on the indicator element itself.
    css = """
:root {
--primary-color: #3498db;
--secondary-color: #2c3e50;
--background-color: #f9f9f9;
--error-color: #e74c3c;
--shadow-color: rgba(0, 0, 0, 0.1);
--border-color: #ddd;
}
body {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
line-height: 1.6;
color: #333;
max-width: 800px;
margin: 0 auto;
padding: 20px;
background-color: #fff;
}
h1 {
color: var(--secondary-color);
margin-bottom: 1rem;
font-weight: 600;
}
h2 {
color: var(--primary-color);
margin-top: 1.5rem;
font-weight: 500;
}
.instructions {
margin-bottom: 20px;
}
.upload-controls {
display: flex;
gap: 10px;
margin-bottom: 30px;
align-items: center;
flex-wrap: wrap;
}
button {
background-color: var(--primary-color);
color: white;
border: none;
padding: 10px 15px;
border-radius: 4px;
cursor: pointer;
transition: background-color 0.3s;
font-weight: 500;
}
button:hover {
background-color: #2980b9;
}
input[type="file"] {
padding: 10px;
border: 1px solid var(--border-color);
border-radius: 4px;
flex-grow: 1;
}
#upload-form {
margin-bottom: 40px;
padding: 20px;
border-radius: 8px;
background-color: var(--background-color);
box-shadow: 0 2px 10px var(--shadow-color);
}
.result-container {
margin-top: 20px;
}
.prediction-results {
margin-top: 20px;
padding: 20px;
border: 1px solid var(--border-color);
border-radius: 8px;
background-color: var(--background-color);
box-shadow: 0 2px 8px var(--shadow-color);
}
.result-image {
max-width: 100%;
height: auto;
border-radius: 8px;
box-shadow: 0 2px 5px var(--shadow-color);
margin-bottom: 20px;
display: block;
}
.prediction-list {
margin-top: 15px;
}
.prediction-item {
padding: 12px 15px;
margin-bottom: 10px;
background-color: white;
border-radius: 6px;
box-shadow: 0 1px 3px var(--shadow-color);
}
.label-text {
margin-bottom: 8px;
font-weight: 500;
display: flex;
justify-content: space-between;
}
.examples-section {
margin-top: 30px;
padding-top: 20px;
border-top: 1px solid var(--border-color);
}
.htmx-indicator {
display: none;
padding: 15px;
background-color: #e8f4fc;
border-radius: 6px;
text-align: center;
margin: 15px 0;
box-shadow: 0 1px 3px var(--shadow-color);
}
.htmx-request .htmx-indicator,
.htmx-request.htmx-indicator {
display: block;
}
.progress-bar {
height: 10px;
background-color: #f0f0f0;
border-radius: 5px;
margin: 5px 0;
overflow: hidden;
}
.progress-fill {
height: 100%;
background-color: var(--primary-color);
width: 0;
transition: width 0.5s ease;
}
.error-message {
color: var(--error-color);
padding: 15px;
border: 1px solid var(--error-color);
border-radius: 5px;
background-color: #fde9e7;
}
a {
color: var(--primary-color);
text-decoration: none;
font-weight: 500;
}
a:hover {
text-decoration: underline;
}
/* Responsive styling */
@media (max-width: 600px) {
.upload-controls {
flex-direction: column;
align-items: stretch;
}
button {
width: 100%;
}
}
.model-info {
font-size: 0.9rem;
color: #666;
margin-top: 40px;
padding-top: 20px;
border-top: 1px solid var(--border-color);
}
"""
    # Model information. Labels are passed through str() so a vocabulary
    # with non-string entries cannot make str.join raise.
    model_info = Div(
        P(f"Model: {'Model loaded successfully' if learn is not None else 'Demo mode - no model loaded'}"),
        P(f"Classes: {', '.join(map(str, labels))}"),
        cls="model-info"
    )
    return Titled("FastAI Image Classifier",
                  upload_form,
                  loading,
                  result_container,
                  examples,
                  model_info,
                  Style(css))
# Prediction route for uploaded images
@rt("/predict")
async def post(image: UploadFile):
    """Classify an uploaded image and return the top predictions as HTML.

    The upload is written under ``uploads/`` so the result can display it
    through the ``/image/{filename}`` route, then its bytes are passed to
    ``predict``.

    Args:
        image: The uploaded file from the form field named ``image``.

    Returns:
        A ``prediction-results`` ``Div`` holding the image and up to three
        labelled progress bars, or an ``error-message`` ``Div`` if anything
        raises.
    """
    import asyncio
    from datetime import datetime

    try:
        # Read the uploaded image
        image_bytes = await image.read()

        # The file name is client-controlled. Keep only its last path
        # component (treating "\" as a separator too) and replace every
        # character outside [alphanumeric . _ -] with "_", so the name can
        # neither escape ``uploads/`` nor break the image URL.
        client_name = (image.filename or "").replace("\\", "/")
        base_name = os.path.basename(client_name)
        base_name = "".join(
            ch if ch.isalnum() or ch in "._-" else "_" for ch in base_name
        ) or "upload"

        # Microsecond resolution: a per-second stamp lets two same-named
        # uploads in the same second overwrite each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        safe_filename = f"{timestamp}_{base_name}"

        # Save the image temporarily
        img_path = f"uploads/{safe_filename}"
        with open(img_path, "wb") as f:
            f.write(image_bytes)

        # Small delay to make the loading indicator visible. Awaited rather
        # than time.sleep() so the event loop keeps serving other requests.
        await asyncio.sleep(0.5)

        # Make a prediction and keep the three most probable labels
        results = predict(image_bytes)
        top_results = sorted(
            results.items(), key=lambda item: item[1], reverse=True
        )[:3]

        # Create prediction items with progress bars
        prediction_items = []
        for label, prob in top_results:
            percentage = int(prob * 100)
            prediction_items.append(
                Div(
                    Div(
                        Span(f"{label}"),
                        Span(f"{percentage}%"),
                        cls="label-text"
                    ),
                    Div(
                        Div(cls="progress-fill", style=f"width: {percentage}%;"),
                        cls="progress-bar"
                    ),
                    cls="prediction-item"
                )
            )

        # Create result HTML
        return Div(
            H2("Prediction Results:"),
            Img(src=f"/image/{safe_filename}", cls="result-image", alt="Uploaded image"),
            Div(*prediction_items, cls="prediction-list"),
            cls="prediction-results"
        )
    except Exception as e:
        return Div(
            H2("Error"),
            P(f"An error occurred during prediction: {str(e)}"),
            cls="error-message"
        )
# Route to serve saved images
@rt("/image/{filename}")
def get(filename: str):
    """Serve a saved image from the ``uploads`` directory.

    Args:
        filename: Name of a file inside ``uploads``, taken from the URL.

    Returns:
        A ``FileResponse`` for the file, or an ``error-message`` ``Div``
        when the name is not a plain file name or no such regular file
        exists.
    """
    # Accept only a bare file name: anything with a directory component, or
    # "." / "..", could otherwise address something outside ``uploads``.
    is_plain_name = (
        filename not in ("", ".", "..")
        and os.path.basename(filename) == filename
    )
    file_path = f"uploads/{filename}"
    # isfile (not exists) so a directory is never handed to FileResponse
    if is_plain_name and os.path.isfile(file_path):
        return FileResponse(file_path)
    return Div(
        H2("Error"),
        P("Image not found."),
        cls="error-message"
    )
# Route for example image
@rt("/predict_example")
def get():
    """Classify the bundled ``image.jpg`` and return the result as HTML.

    Returns:
        A ``prediction-results`` ``Div`` with the example image and up to
        three labelled progress bars; an ``error-message`` ``Div`` when
        ``image.jpg`` does not exist or when anything raises.
    """
    def _bar(label, prob):
        # One result row: label and percentage text above a filled bar
        percentage = int(prob * 100)
        caption = Div(
            Span(f"{label}"),
            Span(f"{percentage}%"),
            cls="label-text"
        )
        track = Div(
            Div(cls="progress-fill", style=f"width: {percentage}%;"),
            cls="progress-bar"
        )
        return Div(caption, track, cls="prediction-item")

    try:
        source_path = "image.jpg"

        # Bail out early when the example image is not there
        if not os.path.exists(source_path):
            return Div(
                H2("Example Not Found"),
                P("The example image 'image.jpg' was not found. Please try uploading your own image."),
                cls="error-message"
            )

        with open(source_path, "rb") as src:
            image_bytes = src.read()

        # Copy the example into uploads so /image/example.jpg can serve it
        example_name = "example.jpg"
        with open(f"uploads/{example_name}", "wb") as dst:
            dst.write(image_bytes)

        # Add a small delay to make the loading indicator visible
        time.sleep(0.5)

        # Predict, then keep the three highest-probability labels
        scores = predict(image_bytes)
        ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        prediction_items = [_bar(label, prob) for label, prob in ranked[:3]]

        return Div(
            H2("Prediction Results:"),
            Img(src=f"/image/{example_name}", cls="result-image", alt="Example image"),
            Div(*prediction_items, cls="prediction-list"),
            P("This is a demonstration using the provided example image.", style="font-style: italic; color: #666;"),
            cls="prediction-results"
        )
    except Exception as e:
        return Div(
            H2("Error"),
            P(f"An error occurred with the example: {str(e)}"),
            cls="error-message"
        )
# Health check endpoint (useful for Docker/Kubernetes)
@rt("/health")
def get():
    """Report that the app is up and whether the module-level ``learn`` is set."""
    model_loaded = learn is not None
    return {"status": "ok", "model_loaded": model_loaded}
# Run the app
if __name__ == "__main__":
    # Use environment variables if available (common in Docker)
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 8000))
    print(f"Starting FastHTML server on {host}:{port}")
    # serve() takes the *name* of the app variable as a string (it builds the
    # "<module>:<app>" import string for uvicorn), not the app object itself.
    serve(app="app", host=host, port=port)