bambu-a1-mini / app.py
sgbaird's picture
update livestream url
84fd316 verified
raw
history blame
17.8 kB
import gradio as gr
import paho.mqtt.client as mqtt
import json
import time
import threading
import os
import base64
from PIL import Image
import io
import numpy as np
import logging
import sys
import random
import cv2
from PIL import ImageDraw
import requests
# Send INFO-and-above log records both to stdout and to app.log in the
# current working directory.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler('app.log')],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger used throughout this file.
logger = logging.getLogger("bambu-analysis")
# MQTT connection settings. Each value is taken from the lowercase environment
# variable (host/port/username/password) when that is set and non-empty,
# otherwise from the MQTT_* variable, otherwise from the placeholder default.
HOST = os.environ.get("host") or os.environ.get("MQTT_HOST", "default_host")
PORT = int(os.environ.get("port") or os.environ.get("MQTT_PORT", 1883))
USERNAME = os.environ.get("username") or os.environ.get("MQTT_USERNAME", "default_user")
PASSWORD = os.environ.get("password") or os.environ.get("MQTT_PASSWORD", "default_pass")
# Printer serial used when no serial is supplied (get_data default,
# send_print_parameters).
DEFAULT_SERIAL = os.environ.get("DEFAULT_SERIAL", "default_serial")
# Report the *effective* settings, i.e. after the lowercase overrides have been
# applied (printing earlier could show a host/port/user that is not used).
# The password is deliberately never printed or logged.
print(f"Connecting to MQTT at {HOST}:{PORT} with user {USERNAME}")
logger.info(f"MQTT Configuration: HOST={HOST}, PORT={PORT}, USERNAME={USERNAME}")
# Snapshot of the most recent printer response. on_message() writes it and
# get_data() / health_check() read it; "N/A" marks a field with no data yet.
latest_data = dict(
    bed_temperature="N/A",
    nozzle_temperature="N/A",
    status="N/A",
    update_time="Waiting for data...",
)
# Shared paho-mqtt client (set by create_client()) and the response topic
# most recently subscribed to by get_data().
client = None
response_topic = None  # Will be set dynamically
def create_client(host, port, username, password):
    """Create, connect and start the shared MQTT client.

    Builds a paho-mqtt client with username/password authentication and TLS,
    installs this module's ``on_connect`` / ``on_message`` callbacks, connects
    to ``host:port`` and starts paho's background network loop.

    The module-level ``client`` is replaced only after ``connect`` succeeds.
    A failed attempt therefore leaves ``client`` untouched (``None`` on first
    use), so callers that check ``client is None`` (e.g. ``get_data``) will try
    again instead of reusing a client that never connected.

    Raises:
        Whatever ``Client.connect`` raises (for example ``OSError`` when the
        broker is unreachable).
    """
    global client
    # paho-mqtt 2.x requires an explicit callback API version; VERSION1 keeps
    # the (client, userdata, flags, rc) callback signatures used in this file.
    # paho-mqtt 1.x has no CallbackAPIVersion and takes no such argument.
    callback_api = getattr(mqtt, "CallbackAPIVersion", None)
    if callback_api is not None:
        new_client = mqtt.Client(callback_api_version=callback_api.VERSION1)
    else:
        new_client = mqtt.Client()
    new_client.username_pw_set(username, password)
    new_client.tls_set(tls_version=mqtt.ssl.PROTOCOL_TLS)
    new_client.on_connect = on_connect
    new_client.on_message = on_message
    new_client.connect(host, port)
    new_client.loop_start()
    client = new_client
def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback.

    Logs the CONNACK result code. 0 means the broker accepted the connection;
    any other value means it was refused. On success, re-subscribes to the
    response topic last requested by ``get_data`` (if any), so that responses
    keep arriving after paho reconnects automatically.
    """
    logger.info(f"Connected with result code {rc}")
    if rc != 0:
        logger.warning(f"MQTT connection refused by broker (result code {rc})")
        return
    if response_topic:
        logger.info(f"Re-subscribing to {response_topic}")
        client.subscribe(response_topic)
def on_message(client, userdata, message):
    """paho-mqtt message callback: record a printer status response.

    The payload is parsed as JSON and must be an object. Its
    ``bed_temperature``, ``nozzle_temperature`` and ``status`` values are
    copied into the module-level ``latest_data``, using ``"N/A"`` when a key is
    absent. ``update_time`` is set to the local receive time. Payloads that
    are not valid JSON objects are logged and ignored, leaving ``latest_data``
    unchanged.
    """
    logger.info("Received message")
    try:
        data = json.loads(message.payload)
    except (ValueError, TypeError) as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        logger.error(f"Error parsing MQTT message: {e}")
        return
    if not isinstance(data, dict):
        logger.error(
            f"Error parsing MQTT message: expected a JSON object, got {type(data).__name__}"
        )
        return
    latest_data["nozzle_temperature"] = data.get("nozzle_temperature", "N/A")
    latest_data["status"] = data.get("status", "N/A")
    latest_data["update_time"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    # bed_temperature is written last on purpose. get_data() (on another
    # thread) treats a change of this field as "response arrived" and then
    # reads the other fields, so those must already hold the new values.
    latest_data["bed_temperature"] = data.get("bed_temperature", "N/A")
def get_data(serial=DEFAULT_SERIAL, timeout=10):
    """Request a status report from the printer over MQTT and wait for it.

    Creates the shared MQTT client if none exists yet. It then subscribes to
    ``bambu_a1_mini/response/<serial>`` and publishes a request to
    ``bambu_a1_mini/request/<serial>``. Finally it polls ``latest_data``
    (filled in by ``on_message``) until ``bed_temperature`` is no longer
    ``"N/A"`` or *timeout* seconds have elapsed.

    Args:
        serial: Printer serial used to build the request/response topics.
        timeout: Maximum number of seconds to wait for a response.

    Returns:
        ``(status, bed_temperature, nozzle_temperature, update_time)`` read
        from ``latest_data``. On timeout ``bed_temperature`` is ``"N/A"``, and
        the other fields still hold the previous response (or placeholders).

    Raises:
        Whatever ``create_client`` raises when a new client cannot connect.
    """
    global client, response_topic
    if client is None:
        create_client(HOST, PORT, USERNAME, PASSWORD)
    request_topic = f"bambu_a1_mini/request/{serial}"
    response_topic = f"bambu_a1_mini/response/{serial}"
    logger.info(f"Subscribing to {response_topic}")
    client.subscribe(response_topic)
    # Reset the "response received" sentinel *before* publishing. Resetting it
    # after publishing races with on_message on paho's network thread. A fast
    # reply could be overwritten with "N/A", and this call would then wait out
    # the whole timeout and report no data.
    latest_data["bed_temperature"] = "N/A"
    logger.info(f"Publishing request to {request_topic}")
    client.publish(request_topic, json.dumps("HI"))
    # Poll against a monotonic deadline at a short interval so the reply is
    # returned promptly instead of up to a full second late.
    deadline = time.monotonic() + timeout
    while latest_data["bed_temperature"] == "N/A" and time.monotonic() < deadline:
        time.sleep(0.1)
    if latest_data["bed_temperature"] == "N/A":
        logger.warning(f"No response on {response_topic} within {timeout} s")
    return (
        latest_data["status"],
        latest_data["bed_temperature"],
        latest_data["nozzle_temperature"],
        latest_data["update_time"]
    )
def send_print_parameters(nozzle_temp, bed_temp, print_speed, fan_speed):
    """Publish a ``set_parameters`` command for the default printer.

    Sends ``{"command": "set_parameters", "parameters": {...}}`` as JSON to
    ``bambu_a1_mini/request/<DEFAULT_SERIAL>`` through the shared MQTT client.

    Returns:
        A human-readable status string, shown in the UI and returned by the
        API. It reports success, that there is no MQTT client, or an error
        description.
    """
    serial = DEFAULT_SERIAL
    logger.info(f"Sending parameters to {serial}: nozzle={nozzle_temp}, bed={bed_temp}, speed={print_speed}, fan={fan_speed}")
    if client is None:
        logger.warning("MQTT not connected, parameters not sent")
        return "MQTT not connected, parameters not sent"
    request_topic = f"bambu_a1_mini/request/{serial}"
    try:
        payload = json.dumps({
            'command': 'set_parameters',
            'parameters': {
                'nozzle_temp': nozzle_temp,
                'bed_temp': bed_temp,
                'print_speed': print_speed,
                'fan_speed': fan_speed,
            },
        })
        info = client.publish(request_topic, payload)
    except Exception as e:
        logger.error(f"Error sending parameters: {e}")
        return f"Error sending parameters: {e}"
    # publish() reports most failures through its return value rather than by
    # raising (e.g. MQTT_ERR_NO_CONN while the broker link is down). Check it
    # instead of always claiming success.
    if info.rc != mqtt.MQTT_ERR_SUCCESS:
        reason = mqtt.error_string(info.rc)
        logger.error(f"Error sending parameters: {reason}")
        return f"Error sending parameters: {reason}"
    logger.info("Parameters sent successfully")
    return "Parameters sent successfully"
def get_image_base64(image):
    """Return *image* encoded as a base64 PNG string, or None on failure.

    Accepts a PIL image or a NumPy array (converted with ``Image.fromarray``).
    ``None`` input and encoding errors are logged and yield ``None``.
    """
    if image is None:
        logger.warning("No image to encode")
        return None
    try:
        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        with io.BytesIO() as png_buffer:
            pil_image.save(png_buffer, format="PNG")
            encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        logger.info(f"Image encoded to base64 (length: {len(encoded)})")
        return encoded
    except Exception as e:
        logger.error(f"Error encoding image: {e}")
        return None
def get_test_image(image_name=None):
    """Load an image from the ``test_images`` folder next to this file.

    Args:
        image_name: File name inside ``test_images``. If it is empty, None, or
            not one of the image files found there, a random image is chosen
            instead.

    Returns:
        The opened (and fully decoded) PIL image, or None when the folder is
        missing, holds no .png/.jpg/.jpeg/.bmp files, or the chosen file
        cannot be read.
    """
    test_dir = os.path.join(os.path.dirname(__file__), "test_images")
    # isdir rather than exists: a plain file at this path would otherwise make
    # os.listdir raise instead of returning None.
    if not os.path.isdir(test_dir):
        logger.error(f"Test images directory not found: {test_dir}")
        return None
    image_files = [
        name for name in os.listdir(test_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))
    ]
    if not image_files:
        logger.error("No test images found")
        return None
    # Only names actually listed in the folder are accepted, so a
    # caller-supplied name such as "../secret.png" can never escape test_dir.
    if image_name and image_name in image_files:
        chosen = image_name
    else:
        if image_name:
            logger.warning(f"Test image {image_name!r} not found, using a random one")
        chosen = random.choice(image_files)
    image_path = os.path.join(test_dir, chosen)
    logger.info(f"Using test image: {image_path}")
    try:
        img = Image.open(image_path)
        # Image.open is lazy; decode now so corrupt files are reported here
        # (and fall back to None) rather than failing later in the caller.
        img.load()
    except Exception as e:
        logger.error(f"Failed to open test image: {e}")
        return None
    return img
def capture_image(url=None, use_test_image=False, test_image_name=None):
    """Obtain an image, falling back to a random test image.

    Tries, in order:
    1. ``get_test_image(test_image_name)`` when *use_test_image* is true.
    2. An HTTP GET of *url* (10 s timeout) when *url* is given.
    3. ``get_test_image()`` (random test image).

    Returns:
        A PIL image, or None if even the random test-image fallback fails.
    """
    if use_test_image:
        logger.info("Using test image instead of URL")
        test_img = get_test_image(test_image_name)
        if test_img is not None:
            return test_img
        logger.warning("Failed to get specified test image, trying URL")
    if url:
        # NOTE(review): url can come from public API callers (capture_frame /
        # lambda endpoints), so the server fetches arbitrary URLs. This is an
        # SSRF risk; consider restricting to an allow-list of camera hosts.
        try:
            logger.info(f"Capturing image from URL: {url}")
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                img = Image.open(io.BytesIO(response.content))
                # Image.open is lazy; decode now so a truncated or corrupt
                # body falls through to the test-image fallback below instead
                # of raising later in the caller.
                img.load()
                return img
            logger.error(f"Failed to get image from URL: {response.status_code}")
        except Exception as e:
            logger.error(f"Error capturing image from URL: {e}")
    logger.info("URL capture failed or not provided, using random test image")
    return get_test_image()
def health_check():
    """Return and log a small status dict for the app and its MQTT link.

    Keys:
        app: Always ``"running"``.
        time: Current local time as ``YYYY-mm-dd HH:MM:SS``.
        mqtt_connected: True only if a client exists and paho reports it as
            connected.
        latest_update: ``update_time`` of the last printer response (or the
            initial placeholder).
    """
    # A client object existing does not mean it is connected (the broker may
    # be unreachable or the link may have dropped), so ask paho directly.
    connected = client is not None and client.is_connected()
    status = {
        "app": "running",
        "time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "mqtt_connected": connected,
        "latest_update": latest_data["update_time"],
    }
    logger.info(f"Health check: {status}")
    return status
# Gradio UI: printer status read-outs, image capture, an embedded YouTube
# livestream and print-parameter sliders, plus the button event wiring.
demo = gr.Blocks(title="Bambu A1 Mini Print Control")
with demo:
    gr.Markdown("# Bambu A1 Mini Print Control")
    with gr.Row():
        refresh_btn = gr.Button("Refresh Status")
    # Read-only status fields, filled in by get_data() via refresh_btn below.
    with gr.Row():
        current_status = gr.Textbox(label="Printer Status", value="N/A", interactive=False)
        current_bed_temp = gr.Textbox(label="Current Bed Temperature", value="N/A", interactive=False)
        current_nozzle_temp = gr.Textbox(label="Current Nozzle Temperature", value="N/A", interactive=False)
        last_update = gr.Textbox(label="Last Update", value="N/A", interactive=False)
    with gr.Row():
        # Left column for image and capture button
        with gr.Column(scale=2):
            captured_image = gr.Image(label="Current Print Image", type="pil")
            capture_btn = gr.Button("Capture Image")
        # Right column: embedded YouTube livestream. The wrapper's
        # padding-top of 56.25% (= 9/16) keeps the iframe at a 16:9 ratio.
        with gr.Column(scale=1):
            gr.Markdown("### YouTube Livestream")
            iframe_html = '''
<div style="position: relative; width: 100%; padding-top: 56.25%;">
<iframe
style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
src="https://www.youtube.com/embed/mGlbI54SxQg"
title="Bambu A1mini Livestream"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
referrerpolicy="strict-origin-when-cross-origin"
allowfullscreen>
</iframe>
</div>
'''
            gr.HTML(iframe_html)
    # Print-parameter controls; their values are sent by send_params_btn below.
    with gr.Row():
        with gr.Column():
            nozzle_temp = gr.Slider(minimum=180, maximum=250, step=1, value=200, label="Nozzle Temperature (°C)")
            bed_temp = gr.Slider(minimum=40, maximum=100, step=1, value=60, label="Bed Temperature (°C)")
            print_speed = gr.Slider(minimum=20, maximum=150, step=1, value=60, label="Print Speed (mm/s)")
            fan_speed = gr.Slider(minimum=0, maximum=100, step=1, value=100, label="Fan Speed (%)")
            send_params_btn = gr.Button("Send Print Parameters")
    # No inputs are wired, so get_data() runs with its default serial.
    refresh_btn.click(
        fn=get_data,
        outputs=[current_status, current_bed_temp, current_nozzle_temp, last_update]
    )
    # No inputs are wired, so capture_image() runs without a URL and falls
    # back to a random test image.
    capture_btn.click(
        fn=capture_image,
        outputs=[captured_image]
    )
    # NOTE(review): the returned status message is written into the
    # "Printer Status" box, replacing the last polled status; confirm intended.
    send_params_btn.click(
        fn=send_print_parameters,
        inputs=[nozzle_temp, bed_temp, print_speed, fan_speed],
        outputs=[current_status]
    )
def api_get_data():
    """API endpoint wrapper: log the call, then delegate to ``get_data()``."""
    logger.info("API call: get_data")
    snapshot = get_data()
    return snapshot
def api_capture_frame(url=None, use_test_image=False, test_image_name=None):
    """API endpoint: capture an image and return it as a base64 JPEG.

    The image comes from ``capture_image`` (URL, named test image, or
    random test-image fallback).

    Returns:
        ``{"success": True, "image": <base64 JPEG>}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    logger.info(f"API call: capture_frame with URL: {url}, use_test_image: {use_test_image}")
    try:
        img = capture_image(url, use_test_image, test_image_name)
        if img is None:
            return {
                "success": False,
                "error": "Failed to capture image"
            }
        # Pillow's JPEG writer cannot store alpha or palette data. Saving
        # RGBA, LA, P, ... images raises OSError ("cannot write mode X as
        # JPEG"), so convert anything that is not a JPEG-native mode to RGB.
        if img.mode not in ("L", "RGB", "CMYK"):
            img = img.convert("RGB")
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return {
            "success": True,
            "image": img_str
        }
    except Exception as e:
        logger.error(f"Error in capture_frame: {e}")
        return {
            "success": False,
            "error": str(e)
        }
def api_lambda(img_data=None, param_1=200, param_2=60, param_3=60, param_4=100, use_test_image=False, test_image_name=None):
    """API endpoint: simulated print-quality evaluation for a parameter set.

    Steps:
    1. Loads an image via ``_load_lambda_image``.
    2. Derives a simulated quality level, defect rates and scores from the
       parameters via ``_simulated_print_metrics``.
    3. Writes the level and rates onto an RGB copy of the image in red.
    4. Returns the metrics plus the annotated image.

    Args:
        img_data: Image source. One of an ``http://``/``https://`` URL, a
            base64 string (a ``data:...;base64,`` URI prefix is accepted), or
            None.
        param_1: Nozzle temperature (°C).
        param_2: Bed temperature (°C); not used by the current scoring.
        param_3: Print speed (mm/s).
        param_4: Fan speed (%).
        use_test_image: Load *test_image_name* from the test-image folder
            instead of using *img_data*.
        test_image_name: File name inside the test-image folder.

    Returns:
        On success: ``{"success": True, <rates and scores>, "image": <base64
        JPEG>}``. On failure: ``{"success": False, "error": <message>}``.
    """
    logger.info(f"API call: lambda with params: {param_1}, {param_2}, {param_3}, {param_4}, use_test_image: {use_test_image}, test_image_name: {test_image_name}")
    try:
        img = _load_lambda_image(img_data, use_test_image, test_image_name)
        if img is None:
            return {
                "success": False,
                "error": "Failed to get image"
            }
        quality_level, metrics = _simulated_print_metrics(param_1, param_3, param_4)
        # Draw on an RGB copy. Single-band modes such as "L" do not accept an
        # (R, G, B) fill, and JPEG cannot store alpha/palette modes (RGBA, P,
        # LA, ...). For RGB/RGBA inputs the visible result is unchanged.
        img_draw = img.convert("RGB")
        draw = ImageDraw.Draw(img_draw)
        red = (255, 0, 0)
        draw.text((10, 10), f"Quality: {quality_level.upper()}", fill=red)
        draw.text((10, 30), f"Missing: {metrics['missing_rate']:.2f}", fill=red)
        draw.text((10, 50), f"Excess: {metrics['excess_rate']:.2f}", fill=red)
        draw.text((10, 70), f"Stringing: {metrics['stringing_rate']:.2f}", fill=red)
        result = {"success": True, **metrics}
        buffered = io.BytesIO()
        img_draw.save(buffered, format="JPEG")
        result["image"] = base64.b64encode(buffered.getvalue()).decode()
        return result
    except Exception as e:
        logger.error(f"Error in lambda: {e}")
        # Same shape as every other failure response of this API.
        return {
            "success": False,
            "error": str(e)
        }


def _load_lambda_image(img_data, use_test_image, test_image_name):
    """Resolve the image api_lambda should annotate; None if nothing loads."""
    img = None
    if use_test_image:
        logger.info(f"Lambda using test image: {test_image_name}")
        img = get_test_image(test_image_name)
    elif isinstance(img_data, str) and img_data:
        if img_data.startswith(('http://', 'https://')):
            logger.info(f"Lambda received image URL: {img_data}")
            img = capture_image(img_data)
        else:
            try:
                logger.info("Lambda received base64 image data")
                # Strip a data-URI header ("data:image/png;base64,"). Its
                # letters are valid base64 characters, so b64decode would
                # otherwise silently fold them into the image bytes.
                if img_data.startswith("data:") and "," in img_data:
                    img_data = img_data.split(",", 1)[1]
                img_bytes = base64.b64decode(img_data)
                img = Image.open(io.BytesIO(img_bytes))
            except Exception as e:
                logger.error(f"Failed to decode base64 image: {e}")
    if img is None:
        logger.info("No valid image data received, using default test image")
        img = get_test_image()
    return img


def _simulated_print_metrics(nozzle_temp, print_speed, fan_speed):
    """Map print parameters to a simulated quality level, defect rates and scores.

    Returns ``(quality_level, metrics)``. ``metrics`` holds, in this order:
    missing/excess/stringing rates, then the uniformity, print-quality,
    print-speed, material-efficiency and weighted total scores.
    """
    if 190 <= nozzle_temp <= 210 and print_speed <= 50 and fan_speed >= 80:
        quality_level = 'high'
        missing_rate, excess_rate, stringing_rate = 0.02, 0.01, 0.01
    elif 185 <= nozzle_temp <= 215 and print_speed <= 70 and fan_speed >= 60:
        quality_level = 'medium'
        missing_rate, excess_rate, stringing_rate = 0.05, 0.03, 0.02
    else:
        quality_level = 'low'
        missing_rate, excess_rate, stringing_rate = 0.10, 0.07, 0.05
    uniformity_score = 1.0 - (missing_rate + excess_rate + stringing_rate)
    print_quality_score = 1.0 - (missing_rate * 2.0 + excess_rate * 1.5 + stringing_rate * 1.0)
    print_quality_score = max(0, min(1, print_quality_score))
    # 150 matches the maximum of the UI's print-speed slider.
    print_speed_score = print_speed / 150.0
    print_speed_score = max(0, min(1, print_speed_score))
    material_efficiency_score = 1.0 - excess_rate * 3.0
    material_efficiency_score = max(0, min(1, material_efficiency_score))
    total_performance_score = (
        0.5 * print_quality_score +
        0.3 * print_speed_score +
        0.2 * material_efficiency_score
    )
    metrics = {
        "missing_rate": missing_rate,
        "excess_rate": excess_rate,
        "stringing_rate": stringing_rate,
        "uniformity_score": uniformity_score,
        "print_quality_score": print_quality_score,
        "print_speed_score": print_speed_score,
        "material_efficiency_score": material_efficiency_score,
        "total_performance_score": total_performance_score,
    }
    return quality_level, metrics
def api_send_print_parameters(nozzle_temp=200, bed_temp=60, print_speed=60, fan_speed=100):
    """API endpoint wrapper: log the call and forward to ``send_print_parameters``.

    Returns the status string produced by ``send_print_parameters``.
    """
    logger.info(
        "API call: send_print_parameters with "
        f"nozzle={nozzle_temp}, bed={bed_temp}, speed={print_speed}, fan={fan_speed}"
    )
    return send_print_parameters(
        nozzle_temp=nozzle_temp,
        bed_temp=bed_temp,
        print_speed=print_speed,
        fan_speed=fan_speed,
    )
# Programmatic API endpoints, reachable via gradio_client / the HTTP API under
# the api_name values below.
#
# They are bound to hidden buttons rather than demo.load(). A load event runs
# its handler every time a browser opens the page, so every page view would:
#   - capture a frame,
#   - run the lambda,
#   - block on get_data() for up to its timeout, and
#   - publish default print parameters to the printer.
# A hidden button's click event exposes the same named endpoint but runs only
# when explicitly called. The inputs/outputs are hidden so they do not clutter
# the UI.
with demo:
    # Shared hidden outputs for the JSON- and text-returning endpoints.
    api_json_output = gr.JSON(visible=False)
    api_text_output = gr.Textbox(visible=False)
    capture_frame_api = gr.Button(visible=False).click(
        fn=api_capture_frame,
        inputs=[
            gr.Textbox(label="Image URL", visible=False),
            gr.Checkbox(label="Use Test Image", value=False, visible=False),
            gr.Textbox(label="Test Image Name", value="", visible=False),
        ],
        outputs=api_json_output,
        api_name="capture_frame",
    )
    lambda_api = gr.Button(visible=False).click(
        fn=api_lambda,
        inputs=[
            gr.Textbox(label="Image Base64 or URL", visible=False),
            gr.Number(label="Nozzle Temperature", value=200, visible=False),
            gr.Number(label="Bed Temperature", value=60, visible=False),
            gr.Number(label="Print Speed", value=60, visible=False),
            gr.Number(label="Fan Speed", value=100, visible=False),
            gr.Checkbox(label="Use Test Image", value=False, visible=False),
            gr.Textbox(label="Test Image Name", value="", visible=False),
        ],
        outputs=api_json_output,
        api_name="lambda",
    )
    get_data_api = gr.Button(visible=False).click(
        fn=api_get_data,
        inputs=None,
        outputs=api_json_output,
        api_name="get_data",
    )
    send_params_api = gr.Button(visible=False).click(
        fn=api_send_print_parameters,
        inputs=[
            gr.Number(label="Nozzle Temperature", value=200, visible=False),
            gr.Number(label="Bed Temperature", value=60, visible=False),
            gr.Number(label="Print Speed", value=60, visible=False),
            gr.Number(label="Fan Speed", value=100, visible=False),
        ],
        outputs=api_text_output,
        api_name="send_print_parameters",
    )
if __name__ == "__main__":
    logger.info("Starting Bambu A1 Mini Print Control application")
    # Connect to the broker at startup. A failure is logged and the web UI is
    # launched regardless.
    try:
        logger.info("Initializing MQTT client")
        create_client(HOST, PORT, USERNAME, PASSWORD)
    except Exception as exc:
        logger.error(f"Failed to initialize MQTT: {exc}")
    queued_demo = demo.queue()
    queued_demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )