Update app.py
Browse files
app.py
CHANGED
@@ -1,57 +1,135 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import
|
|
|
3 |
import numpy as np
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
from
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
import gradio as gr
|
3 |
+
import cv2
|
4 |
+
import mediapipe as mp
|
5 |
import numpy as np
|
6 |
+
import time
|
7 |
+
from pathlib import Path
|
8 |
+
import tempfile
|
9 |
+
import os
|
10 |
+
from face_reconstruction_main import FaceMeshDetector
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
import json
|
13 |
+
|
14 |
+
# Initialize the detector
|
15 |
+
detector = FaceMeshDetector()
|
16 |
+
|
17 |
+
def process_frame(frame):
|
18 |
+
"""Process a single frame and return the image with face mesh overlay"""
|
19 |
+
if frame is None:
|
20 |
+
return None
|
21 |
+
|
22 |
+
img, faces = detector.detect_faces(frame)
|
23 |
+
return img, faces
|
24 |
+
|
25 |
+
def create_3d_plot(vertices):
|
26 |
+
"""Create a 3D plotly figure from vertices"""
|
27 |
+
x, y, z = zip(*vertices)
|
28 |
+
|
29 |
+
fig = go.Figure(data=[go.Scatter3d(
|
30 |
+
x=x, y=y, z=z,
|
31 |
+
mode='markers',
|
32 |
+
marker=dict(
|
33 |
+
size=2,
|
34 |
+
color=z,
|
35 |
+
colorscale='Viridis',
|
36 |
+
)
|
37 |
+
)])
|
38 |
+
|
39 |
+
fig.update_layout(
|
40 |
+
scene=dict(
|
41 |
+
xaxis_title='X',
|
42 |
+
yaxis_title='Y',
|
43 |
+
zaxis_title='Z'
|
44 |
+
),
|
45 |
+
margin=dict(l=0, r=0, b=0, t=0)
|
46 |
+
)
|
47 |
+
|
48 |
+
return fig
|
49 |
+
|
50 |
+
def save_obj_and_display(frame):
|
51 |
+
"""Save frame as OBJ and return both file path and 3D visualization"""
|
52 |
+
if frame is None:
|
53 |
+
return None, None, "No frame provided"
|
54 |
+
|
55 |
+
img, faces = process_frame(frame)
|
56 |
+
|
57 |
+
if not faces:
|
58 |
+
return None, None, "No face detected in the frame"
|
59 |
+
|
60 |
+
# Scale vertices
|
61 |
+
vertices = np.array(faces[0]) * 150
|
62 |
+
|
63 |
+
# Create temporary directory if it doesn't exist
|
64 |
+
temp_dir = Path('temp')
|
65 |
+
temp_dir.mkdir(exist_ok=True)
|
66 |
+
|
67 |
+
# Generate timestamp
|
68 |
+
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
69 |
+
filename = f"face_mesh_{timestamp}"
|
70 |
+
|
71 |
+
# Save OBJ file
|
72 |
+
obj_path = temp_dir / f"{filename}.obj"
|
73 |
+
success = detector.save_obj_file(vertices, filename)
|
74 |
+
|
75 |
+
if not success:
|
76 |
+
return None, None, "Failed to save OBJ file"
|
77 |
+
|
78 |
+
# Create 3D visualization
|
79 |
+
fig = create_3d_plot(vertices)
|
80 |
+
|
81 |
+
return str(obj_path), fig, "Model generated successfully"
|
82 |
+
|
83 |
+
def cleanup_old_files():
|
84 |
+
"""Remove files older than 1 hour"""
|
85 |
+
temp_dir = Path('temp')
|
86 |
+
if temp_dir.exists():
|
87 |
+
current_time = time.time()
|
88 |
+
for file in temp_dir.glob('*.obj'):
|
89 |
+
if current_time - file.stat().st_mtime > 3600: # 1 hour
|
90 |
+
file.unlink()
|
91 |
+
|
92 |
+
def create_interface():
|
93 |
+
"""Create the Gradio interface"""
|
94 |
+
with gr.Blocks() as interface:
|
95 |
+
gr.Markdown("""
|
96 |
+
# 3D Face Reconstruction
|
97 |
+
Upload an image or use your webcam to generate a 3D face model.
|
98 |
+
|
99 |
+
Instructions:
|
100 |
+
1. Upload an image or use the webcam
|
101 |
+
2. Click 'Generate 3D Model'
|
102 |
+
3. View the 3D model in the interactive viewer
|
103 |
+
4. Download the OBJ file for use in 3D software
|
104 |
+
""")
|
105 |
+
|
106 |
+
with gr.Row():
|
107 |
+
with gr.Column():
|
108 |
+
# Input methods
|
109 |
+
input_image = gr.Image(source="webcam", type="numpy")
|
110 |
+
generate_button = gr.Button("Generate 3D Model")
|
111 |
+
|
112 |
+
with gr.Column():
|
113 |
+
# Output displays
|
114 |
+
obj_file = gr.File(label="Download OBJ File")
|
115 |
+
plot_3d = gr.Plot(label="3D Preview")
|
116 |
+
status_text = gr.Textbox(label="Status")
|
117 |
+
|
118 |
+
# Set up event handler
|
119 |
+
generate_button.click(
|
120 |
+
fn=save_obj_and_display,
|
121 |
+
inputs=[input_image],
|
122 |
+
outputs=[obj_file, plot_3d, status_text]
|
123 |
+
)
|
124 |
+
|
125 |
+
# Cleanup old files periodically
|
126 |
+
cleanup_old_files()
|
127 |
+
|
128 |
+
return interface
|
129 |
+
|
130 |
+
# Create and launch the interface
|
131 |
+
interface = create_interface()
|
132 |
+
|
133 |
+
# Launch locally for testing
|
134 |
+
if __name__ == "__main__":
|
135 |
+
interface.launch()
|