Prashanthsrn commited on
Commit
e95699b
·
verified ·
1 Parent(s): 7cd7f83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -55
app.py CHANGED
@@ -1,57 +1,135 @@
 
1
  import gradio as gr
2
- import torch
 
3
  import numpy as np
4
- from PIL import Image
5
- import torchvision.transforms as transforms
6
- from decalib.deca import DECA
7
- from decalib.utils import util
8
- from decalib.utils.config import cfg as deca_cfg
9
-
10
- # Initialize DECA
11
- deca_cfg.model.use_tex = False
12
- deca = DECA(config=deca_cfg)
13
-
14
- def preprocess_image(image):
15
- transform = transforms.Compose([
16
- transforms.Resize(224),
17
- transforms.ToTensor(),
18
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
- ])
20
- return transform(image).unsqueeze(0)
21
-
22
- def reconstruct_face(image):
23
- # Preprocess the image
24
- input_image = preprocess_image(Image.fromarray(image))
25
-
26
- # Run DECA
27
- with torch.no_grad():
28
- codedict = deca.encode(input_image)
29
- opdict, visdict = deca.decode(codedict)
30
-
31
- # Get the reconstructed face
32
- reconstructed_face = util.tensor2image(visdict['shape_images'][0])
33
-
34
- # Get the 3D mesh
35
- vertices = opdict['vertices'][0].cpu().numpy()
36
- faces = deca.flame.faces_tensor.cpu().numpy()
37
-
38
- return reconstructed_face, vertices, faces
39
-
40
- def process_image(input_image):
41
- reconstructed_face, vertices, faces = reconstruct_face(input_image)
42
- return reconstructed_face, [vertices, faces]
43
-
44
- # Define the Gradio interface
45
- iface = gr.Interface(
46
- fn=process_image,
47
- inputs=gr.Image(),
48
- outputs=[
49
- gr.Image(label="Reconstructed Face"),
50
- gr.Model3D(label="3D Face Model")
51
- ],
52
- title="3D Face Reconstruction from a Single Image",
53
- description="Upload an image of a face to generate a 3D reconstruction."
54
- )
55
-
56
- # Launch the app
57
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
  import gradio as gr
3
+ import cv2
4
+ import mediapipe as mp
5
  import numpy as np
6
+ import time
7
+ from pathlib import Path
8
+ import tempfile
9
+ import os
10
+ from face_reconstruction_main import FaceMeshDetector
11
+ import plotly.graph_objects as go
12
+ import json
13
+
14
+ # Initialize the detector
15
+ detector = FaceMeshDetector()
16
+
17
+ def process_frame(frame):
18
+ """Process a single frame and return the image with face mesh overlay"""
19
+ if frame is None:
20
+ return None
21
+
22
+ img, faces = detector.detect_faces(frame)
23
+ return img, faces
24
+
25
+ def create_3d_plot(vertices):
26
+ """Create a 3D plotly figure from vertices"""
27
+ x, y, z = zip(*vertices)
28
+
29
+ fig = go.Figure(data=[go.Scatter3d(
30
+ x=x, y=y, z=z,
31
+ mode='markers',
32
+ marker=dict(
33
+ size=2,
34
+ color=z,
35
+ colorscale='Viridis',
36
+ )
37
+ )])
38
+
39
+ fig.update_layout(
40
+ scene=dict(
41
+ xaxis_title='X',
42
+ yaxis_title='Y',
43
+ zaxis_title='Z'
44
+ ),
45
+ margin=dict(l=0, r=0, b=0, t=0)
46
+ )
47
+
48
+ return fig
49
+
50
+ def save_obj_and_display(frame):
51
+ """Save frame as OBJ and return both file path and 3D visualization"""
52
+ if frame is None:
53
+ return None, None, "No frame provided"
54
+
55
+ img, faces = process_frame(frame)
56
+
57
+ if not faces:
58
+ return None, None, "No face detected in the frame"
59
+
60
+ # Scale vertices
61
+ vertices = np.array(faces[0]) * 150
62
+
63
+ # Create temporary directory if it doesn't exist
64
+ temp_dir = Path('temp')
65
+ temp_dir.mkdir(exist_ok=True)
66
+
67
+ # Generate timestamp
68
+ timestamp = time.strftime("%Y%m%d-%H%M%S")
69
+ filename = f"face_mesh_{timestamp}"
70
+
71
+ # Save OBJ file
72
+ obj_path = temp_dir / f"{filename}.obj"
73
+ success = detector.save_obj_file(vertices, filename)
74
+
75
+ if not success:
76
+ return None, None, "Failed to save OBJ file"
77
+
78
+ # Create 3D visualization
79
+ fig = create_3d_plot(vertices)
80
+
81
+ return str(obj_path), fig, "Model generated successfully"
82
+
83
+ def cleanup_old_files():
84
+ """Remove files older than 1 hour"""
85
+ temp_dir = Path('temp')
86
+ if temp_dir.exists():
87
+ current_time = time.time()
88
+ for file in temp_dir.glob('*.obj'):
89
+ if current_time - file.stat().st_mtime > 3600: # 1 hour
90
+ file.unlink()
91
+
92
+ def create_interface():
93
+ """Create the Gradio interface"""
94
+ with gr.Blocks() as interface:
95
+ gr.Markdown("""
96
+ # 3D Face Reconstruction
97
+ Upload an image or use your webcam to generate a 3D face model.
98
+
99
+ Instructions:
100
+ 1. Upload an image or use the webcam
101
+ 2. Click 'Generate 3D Model'
102
+ 3. View the 3D model in the interactive viewer
103
+ 4. Download the OBJ file for use in 3D software
104
+ """)
105
+
106
+ with gr.Row():
107
+ with gr.Column():
108
+ # Input methods
109
+ input_image = gr.Image(source="webcam", type="numpy")
110
+ generate_button = gr.Button("Generate 3D Model")
111
+
112
+ with gr.Column():
113
+ # Output displays
114
+ obj_file = gr.File(label="Download OBJ File")
115
+ plot_3d = gr.Plot(label="3D Preview")
116
+ status_text = gr.Textbox(label="Status")
117
+
118
+ # Set up event handler
119
+ generate_button.click(
120
+ fn=save_obj_and_display,
121
+ inputs=[input_image],
122
+ outputs=[obj_file, plot_3d, status_text]
123
+ )
124
+
125
+ # Cleanup old files periodically
126
+ cleanup_old_files()
127
+
128
+ return interface
129
+
130
+ # Create and launch the interface
131
+ interface = create_interface()
132
+
133
+ # Launch locally for testing
134
+ if __name__ == "__main__":
135
+ interface.launch()