Commit b05da2d · joselobenitezg committed
Parent(s): d04cd0a

add pose inference

Files changed:
- app.py (+10 -3)
- inference/pose.py (+6 -2)
- load_and_test.ipynb (+23 -21)
app.py CHANGED

@@ -5,19 +5,26 @@ from PIL import Image
 import cv2
 import spaces
 
-from inference.seg import process_image_or_video
+from inference.seg import process_image_or_video as process_seg
+from inference.pose import process_image_or_video as process_pose
 from config import SAPIENS_LITE_MODELS_PATH
 
 def update_model_choices(task):
     model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
     return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
 
-@spaces.GPU(
+@spaces.GPU()
 def process_image(input_image, task, version):
     if isinstance(input_image, np.ndarray):
         input_image = Image.fromarray(input_image)
 
-
+    if task.lower() == 'seg':
+        result = process_seg(input_image, task=task.lower(), version=version)
+    elif task.lower() == 'pose':
+        result = process_pose(input_image, task=task.lower(), version=version)
+    else:
+        result = None
+        print(f"Tarea no soportada: {task}")
 
     return result
 
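Note on the app.py change above: process_image now dispatches on the task string, so the Gradio UI only needs to pass the selected task and model version through. The following is a minimal wiring sketch, not code from this commit; the component names (task_dd, version_dd, input_img, output_img, run_btn) are assumptions, and update_model_choices / process_image are the functions defined in app.py above.

import gradio as gr

# update_model_choices and process_image are assumed to be defined earlier in app.py.
with gr.Blocks() as demo:
    task_dd = gr.Dropdown(choices=["seg", "pose"], value="seg", label="Task")
    version_dd = gr.Dropdown(choices=[], label="Model version")
    input_img = gr.Image(type="numpy", label="Input image")
    output_img = gr.Image(label="Result")
    run_btn = gr.Button("Run")

    # Refresh the available model versions whenever the task changes.
    task_dd.change(update_model_choices, inputs=task_dd, outputs=version_dd)
    # process_image routes to process_seg or process_pose based on the selected task.
    run_btn.click(process_image, inputs=[input_img, task_dd, version_dd], outputs=output_img)

demo.launch()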
inference/pose.py CHANGED

@@ -90,6 +90,9 @@ def load_model(task, version):
     try:
         model_path = SAPIENS_LITE_MODELS_PATH[task][version]
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
         model = torch.jit.load(model_path)
         model.eval()
         model.to(device)
@@ -109,6 +112,7 @@ def preprocess_image(image, input_shape):
     return img.unsqueeze(0)
 
 def udp_decode(heatmap, img_size, heatmap_size):
+    # This is a simplified version. You might need to implement the full UDP decode logic
     h, w = heatmap_size
     keypoints = np.zeros((heatmap.shape[0], 2))
     keypoint_scores = np.zeros(heatmap.shape[0])
@@ -133,8 +137,6 @@ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
     if model is None or device is None:
         return None
 
-    input_shape = (3, 1024, 768)
-
     def process_frame(frame):
         if isinstance(frame, np.ndarray):
             frame = Image.fromarray(frame)
@@ -142,6 +144,8 @@ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
         if frame.mode == 'RGBA':
             frame = frame.convert('RGB')
 
+        input_shape = (3, frame.height, frame.width)
+
         img = preprocess_image(frame, input_shape)
 
         with torch.no_grad():
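Note on the inference/pose.py change above: the TF32 lines added to load_model are a performance toggle; on GPUs with compute capability >= 8 (Ampere or newer) they let matmul and cuDNN kernels use TensorFloat-32, trading a little precision for speed. The new comment in udp_decode flags that function as a simplified decoder, and this diff only shows its first few lines, so the repository's actual logic is not visible here. As a rough illustration of what a simplified, argmax-based heatmap decode with that signature can look like (full UDP decoding would add sub-pixel refinement), consider:

import numpy as np

# Hedged sketch only: one plausible "simplified" decoder consistent with the
# udp_decode signature shown above, not the repository's actual implementation.
def simple_heatmap_decode(heatmap, img_size, heatmap_size):
    # heatmap: (K, h, w) array of per-keypoint heatmaps
    # img_size: (H, W) of the input image; heatmap_size: (h, w)
    h, w = heatmap_size
    num_keypoints = heatmap.shape[0]
    keypoints = np.zeros((num_keypoints, 2))
    keypoint_scores = np.zeros(num_keypoints)
    for k in range(num_keypoints):
        # Take the location of the peak response for keypoint k.
        idx = np.argmax(heatmap[k])
        y, x = np.unravel_index(idx, (h, w))
        # Map heatmap coordinates back to image coordinates.
        keypoints[k] = [x * img_size[1] / w, y * img_size[0] / h]
        keypoint_scores[k] = heatmap[k, y, x]
    return keypoints, keypoint_scores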
load_and_test.ipynb CHANGED

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -32,7 +32,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -77,7 +77,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -829,7 +829,7 @@
       ")"
      ]
     },
-    "execution_count":
+    "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -842,7 +842,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -856,7 +856,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -871,7 +871,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -888,7 +888,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
@@ -899,7 +899,7 @@
       "<PIL.Image.Image image mode=RGB size=640x480>"
      ]
     },
-    "execution_count":
+    "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -917,7 +917,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -926,7 +926,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -937,7 +937,7 @@
       "<PIL.Image.Image image mode=RGB size=1024x1024>"
      ]
     },
-    "execution_count":
+    "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -955,7 +955,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
@@ -977,7 +977,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
@@ -2188,7 +2188,7 @@
       ")"
      ]
     },
-    "execution_count":
+    "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -2230,7 +2230,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -2292,7 +2292,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
@@ -2303,7 +2303,7 @@
       "<PIL.Image.Image image mode=RGB size=640x480>"
      ]
     },
-    "execution_count":
+    "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -2321,16 +2321,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
+    "from PIL import Image, ImageDraw\n",
+    "\n",
    "output_pose = get_pose(resized_pil_image, model)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
@@ -2341,7 +2343,7 @@
       "<PIL.Image.Image image mode=RGB size=640x480>"
      ]
     },
-    "execution_count":
+    "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
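Note on the notebook change above: it adds from PIL import Image, ImageDraw before the get_pose call, presumably so the predicted keypoints can be drawn onto the image. The drawing code itself is not part of this diff; as a hypothetical illustration (the (x, y, score) keypoint format is an assumption), an overlay helper could look like:

from PIL import ImageDraw

# Illustrative only: assumes `keypoints` is an iterable of (x, y, score) tuples
# in image coordinates, e.g. as produced by a decode step like the sketch above.
def draw_keypoints(image, keypoints, radius=3, threshold=0.3):
    # Return a copy of `image` with a dot at each sufficiently confident keypoint.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for x, y, score in keypoints:
        if score < threshold:
            continue  # skip low-confidence detections
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=(255, 0, 0))
    return annotated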