phuochungus commited on
Commit
1fd5f60
1 Parent(s): febdf5d

Refactor video handling and add user artifact association

Browse files
Files changed (1) hide show
  1. app/routers/video.py +18 -12
app/routers/video.py CHANGED
@@ -9,7 +9,9 @@ import cv2
9
  from multiprocessing import Process
10
  from fastapi import (
11
  APIRouter,
 
12
  HTTPException,
 
13
  UploadFile,
14
  BackgroundTasks,
15
  status,
@@ -18,8 +20,9 @@ from firebase_admin import messaging
18
  from app import db
19
  from app import supabase
20
  from app.dependencies import get_current_user
21
- from app.routers.image import inference_image
22
  from google.cloud.firestore_v1.base_query import FieldFilter
 
23
  from app import logger
24
 
25
  router = APIRouter(prefix="/video", tags=["Video"])
@@ -30,6 +33,7 @@ async def handleVideoRequest(
30
  file: UploadFile,
31
  background_tasks: BackgroundTasks,
32
  threshold: float = 0.3,
 
33
  ):
34
  if re.search("^video\/", file.content_type) is None:
35
  raise HTTPException(
@@ -42,11 +46,12 @@ async def handleVideoRequest(
42
  _, artifact_ref = db.collection("artifacts").add(
43
  {"name": id + ".mp4", "status": "pending"}
44
  )
 
45
  os.mkdir(id)
46
  async with aiofiles.open(os.path.join(id, "input.mp4"), "wb") as out_file:
47
- while content := await file.read(1024):
48
  await out_file.write(content)
49
- await inference_video(artifact_ref.id, id, threshold)
50
  return id + ".mp4"
51
  except ValueError as err:
52
  logger.error(err)
@@ -64,7 +69,7 @@ def createThumbnail(thumbnail, inputDir):
64
  cv2.imwrite(os.path.join(inputDir, "thumbnail.jpg"), thumbnail)
65
 
66
 
67
- def inference_frame(inputDir, threshold: float = 0.3):
68
  cap = cv2.VideoCapture(
69
  filename=os.path.join(inputDir, "input.mp4"), apiPreference=cv2.CAP_FFMPEG
70
  )
@@ -91,7 +96,7 @@ def inference_frame(inputDir, threshold: float = 0.3):
91
  if res == False:
92
  break
93
 
94
- resFram = inference_image(frame, threshold)
95
  result.write(resFram)
96
  cap.release()
97
  result.release()
@@ -100,10 +105,11 @@ def inference_frame(inputDir, threshold: float = 0.3):
100
  return thumbnail
101
 
102
 
103
- async def inference_video(artifactId: str, inputDir: str, threshold: float):
 
104
  try:
105
- Process(update_artifact(artifactId, {"status": "processing"})).start()
106
- thumbnail = inference_frame(inputDir, threshold=threshold)
107
  createThumbnail(thumbnail, inputDir)
108
 
109
  async def uploadVideo():
@@ -125,9 +131,9 @@ async def inference_video(artifactId: str, inputDir: str, threshold: float):
125
  _, _ = await asyncio.gather(uploadVideo(), uploadThumbnail())
126
  print(now() - n)
127
  except Exception as e:
128
- print(e)
129
 
130
- update_artifact(
131
  artifactId,
132
  {
133
  "status": "success",
@@ -141,7 +147,7 @@ async def inference_video(artifactId: str, inputDir: str, threshold: float):
141
  )
142
  except:
143
  Process(
144
- update_artifact(
145
  artifactId,
146
  {
147
  "status": "fail",
@@ -155,7 +161,7 @@ async def inference_video(artifactId: str, inputDir: str, threshold: float):
155
  print(e)
156
 
157
 
158
- def update_artifact(artifactId: str, body):
159
  artifact_ref = db.collection("artifacts").document(artifactId)
160
  artifact_snapshot = artifact_ref.get()
161
  if artifact_snapshot.exists:
 
9
  from multiprocessing import Process
10
  from fastapi import (
11
  APIRouter,
12
+ Depends,
13
  HTTPException,
14
+ Request,
15
  UploadFile,
16
  BackgroundTasks,
17
  status,
 
20
  from app import db
21
  from app import supabase
22
  from app.dependencies import get_current_user
23
+ from app.routers.image import inferenceImage
24
  from google.cloud.firestore_v1.base_query import FieldFilter
25
+ from google.cloud.firestore import ArrayUnion
26
  from app import logger
27
 
28
  router = APIRouter(prefix="/video", tags=["Video"])
 
33
  file: UploadFile,
34
  background_tasks: BackgroundTasks,
35
  threshold: float = 0.3,
36
+ user=Depends(get_current_user)
37
  ):
38
  if re.search("^video\/", file.content_type) is None:
39
  raise HTTPException(
 
46
  _, artifact_ref = db.collection("artifacts").add(
47
  {"name": id + ".mp4", "status": "pending"}
48
  )
49
+ db.collection("user").document(user["sub"]).update({"artifacts": ArrayUnion(['artifact/' + artifact_ref.id])})
50
  os.mkdir(id)
51
  async with aiofiles.open(os.path.join(id, "input.mp4"), "wb") as out_file:
52
+ while content := await file.read(102400):
53
  await out_file.write(content)
54
+ background_tasks.add_task(inferenceVideo, artifact_ref.id, id, threshold)
55
  return id + ".mp4"
56
  except ValueError as err:
57
  logger.error(err)
 
69
  cv2.imwrite(os.path.join(inputDir, "thumbnail.jpg"), thumbnail)
70
 
71
 
72
+ def inferenceFrame(inputDir, threshold: float = 0.3):
73
  cap = cv2.VideoCapture(
74
  filename=os.path.join(inputDir, "input.mp4"), apiPreference=cv2.CAP_FFMPEG
75
  )
 
96
  if res == False:
97
  break
98
 
99
+ resFram = inferenceImage(frame, threshold, False)
100
  result.write(resFram)
101
  cap.release()
102
  result.release()
 
105
  return thumbnail
106
 
107
 
108
+ async def inferenceVideo(artifactId: str, inputDir: str, threshold: float):
109
+ logger.info("Start inference video")
110
  try:
111
+ Process(updateArtifact(artifactId, {"status": "processing"})).start()
112
+ thumbnail = inferenceFrame(inputDir, threshold=threshold)
113
  createThumbnail(thumbnail, inputDir)
114
 
115
  async def uploadVideo():
 
131
  _, _ = await asyncio.gather(uploadVideo(), uploadThumbnail())
132
  print(now() - n)
133
  except Exception as e:
134
+ logger.error(e)
135
 
136
+ updateArtifact(
137
  artifactId,
138
  {
139
  "status": "success",
 
147
  )
148
  except:
149
  Process(
150
+ updateArtifact(
151
  artifactId,
152
  {
153
  "status": "fail",
 
161
  print(e)
162
 
163
 
164
+ def updateArtifact(artifactId: str, body):
165
  artifact_ref = db.collection("artifacts").document(artifactId)
166
  artifact_snapshot = artifact_ref.get()
167
  if artifact_snapshot.exists: