joselobenitezg committed · Commit b05da2d · 1 parent: d04cd0a

add pose inference

Files changed (3):
  1. app.py +10 -3
  2. inference/pose.py +6 -2
  3. load_and_test.ipynb +23 -21
app.py CHANGED

@@ -5,19 +5,26 @@ from PIL import Image
 import cv2
 import spaces
 
-from inference.seg import process_image_or_video
+from inference.seg import process_image_or_video as process_seg
+from inference.pose import process_image_or_video as process_pose
 from config import SAPIENS_LITE_MODELS_PATH
 
 def update_model_choices(task):
     model_choices = list(SAPIENS_LITE_MODELS_PATH[task.lower()].keys())
     return gr.Dropdown(choices=model_choices, value=model_choices[0] if model_choices else None)
 
-@spaces.GPU(duration=120)
+@spaces.GPU()
 def process_image(input_image, task, version):
     if isinstance(input_image, np.ndarray):
         input_image = Image.fromarray(input_image)
 
-    result = process_image_or_video(input_image, task=task.lower(), version=version)
+    if task.lower() == 'seg':
+        result = process_seg(input_image, task=task.lower(), version=version)
+    elif task.lower() == 'pose':
+        result = process_pose(input_image, task=task.lower(), version=version)
+    else:
+        result = None
+        print(f"Unsupported task: {task}")
 
     return result
 
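For orientation, a minimal usage sketch of the new dispatch in `process_image`, called directly rather than through the Gradio UI. The image path is a placeholder, `'sapiens_1b'` is the default version name from `inference/pose.py`, and running the `@spaces.GPU()`-decorated function outside a Space is an assumption:

```python
# Hedged usage sketch: exercises the task dispatch added in this commit.
from PIL import Image

from app import process_image

img = Image.open("person.jpg")  # placeholder path, not part of the repo

seg_out = process_image(img, task="seg", version="sapiens_1b")      # routed to process_seg
pose_out = process_image(img, task="pose", version="sapiens_1b")    # routed to process_pose
bad_out = process_image(img, task="unknown", version="sapiens_1b")  # prints a warning, returns None
```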
inference/pose.py CHANGED

@@ -90,6 +90,9 @@ def load_model(task, version):
     try:
         model_path = SAPIENS_LITE_MODELS_PATH[task][version]
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
         model = torch.jit.load(model_path)
         model.eval()
         model.to(device)
@@ -109,6 +112,7 @@ def preprocess_image(image, input_shape):
     return img.unsqueeze(0)
 
 def udp_decode(heatmap, img_size, heatmap_size):
+    # This is a simplified version. You might need to implement the full UDP decode logic
     h, w = heatmap_size
     keypoints = np.zeros((heatmap.shape[0], 2))
     keypoint_scores = np.zeros(heatmap.shape[0])
@@ -133,8 +137,6 @@ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
     if model is None or device is None:
         return None
 
-    input_shape = (3, 1024, 768)
-
     def process_frame(frame):
         if isinstance(frame, np.ndarray):
             frame = Image.fromarray(frame)
@@ -142,6 +144,8 @@ def process_image_or_video(input_data, task='pose', version='sapiens_1b'):
         if frame.mode == 'RGBA':
             frame = frame.convert('RGB')
 
+        input_shape = (3, frame.height, frame.width)
+
         img = preprocess_image(frame, input_shape)
 
         with torch.no_grad():
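The new comment in `udp_decode` flags it as a simplified decode. For context, a sketch of what an argmax-style decode with this signature typically looks like, consistent with the `keypoints` and `keypoint_scores` arrays initialized above; the loop body and the `(H, W)` ordering of `img_size` are assumptions, not the repository's exact implementation:

```python
import numpy as np

def udp_decode_sketch(heatmap, img_size, heatmap_size):
    # Simplified decode: one (x, y) location plus a confidence per keypoint.
    # heatmap: (K, h, w); img_size assumed (H, W); heatmap_size is (h, w).
    h, w = heatmap_size
    keypoints = np.zeros((heatmap.shape[0], 2))
    keypoint_scores = np.zeros(heatmap.shape[0])
    for k in range(heatmap.shape[0]):
        y, x = np.unravel_index(np.argmax(heatmap[k]), (h, w))
        # Rescale from heatmap coordinates back to image coordinates.
        keypoints[k] = [x * img_size[1] / w, y * img_size[0] / h]
        keypoint_scores[k] = heatmap[k, y, x]
    return keypoints, keypoint_scores
```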
load_and_test.ipynb CHANGED

@@ -2,7 +2,7 @@
 "cells": [
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 1,
 "metadata": {},
 "outputs": [
 {
@@ -32,7 +32,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -77,7 +77,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
 {
@@ -829,7 +829,7 @@
 ")"
 ]
 },
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -842,7 +842,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 6,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -856,7 +856,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 7,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -871,7 +871,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 8,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -888,7 +888,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 9,
 "metadata": {},
 "outputs": [
 {
@@ -899,7 +899,7 @@
 "<PIL.Image.Image image mode=RGB size=640x480>"
 ]
 },
-"execution_count": 8,
+"execution_count": 9,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -917,7 +917,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 10,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -926,7 +926,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
@@ -937,7 +937,7 @@
 "<PIL.Image.Image image mode=RGB size=1024x1024>"
 ]
 },
-"execution_count": 10,
+"execution_count": 11,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -955,7 +955,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
 {
@@ -977,7 +977,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 14,
 "metadata": {},
 "outputs": [
 {
@@ -2188,7 +2188,7 @@
 ")"
 ]
 },
-"execution_count": 6,
+"execution_count": 14,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -2230,7 +2230,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 67,
+"execution_count": 15,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -2292,7 +2292,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 68,
+"execution_count": 16,
 "metadata": {},
 "outputs": [
 {
@@ -2303,7 +2303,7 @@
 "<PIL.Image.Image image mode=RGB size=640x480>"
 ]
 },
-"execution_count": 68,
+"execution_count": 16,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -2321,16 +2321,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": 69,
+"execution_count": 18,
 "metadata": {},
 "outputs": [],
 "source": [
+"from PIL import Image, ImageDraw\n",
+"\n",
 "output_pose = get_pose(resized_pil_image, model)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 70,
+"execution_count": 20,
 "metadata": {},
 "outputs": [
 {
@@ -2341,7 +2343,7 @@
 "<PIL.Image.Image image mode=RGB size=640x480>"
 ]
 },
-"execution_count": 70,
+"execution_count": 20,
 "metadata": {},
 "output_type": "execute_result"
 }
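The notebook now imports `ImageDraw` alongside `Image` right before the `get_pose` call, which suggests the decoded keypoints are drawn onto the frame. A minimal overlay sketch under that assumption; the function name, threshold, and color are illustrative, not the notebook's actual code:

```python
from PIL import Image, ImageDraw

def draw_keypoints_sketch(image, keypoints, scores, threshold=0.3, radius=3):
    # Draw one dot per keypoint whose confidence clears the threshold.
    out = image.copy()
    draw = ImageDraw.Draw(out)
    for (x, y), score in zip(keypoints, scores):
        if score >= threshold:
            draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill="red")
    return out
```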