Spaces:
Runtime error
Runtime error
Delete pipeline.py
Browse files- pipeline.py +0 -73
pipeline.py
DELETED
@@ -1,73 +0,0 @@
|
|
1 |
-
#! /usr/bin/env python
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
|
4 |
-
# Copyright 2023 Imperial College London (Pingchuan Ma)
|
5 |
-
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
-
|
7 |
-
import os
|
8 |
-
import torch
|
9 |
-
import pickle
|
10 |
-
from configparser import ConfigParser
|
11 |
-
|
12 |
-
from pipelines.model import AVSR
|
13 |
-
from pipelines.data.data_module import AVSRDataLoader
|
14 |
-
|
15 |
-
|
16 |
-
class InferencePipeline(torch.nn.Module):
|
17 |
-
def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"):
|
18 |
-
super(InferencePipeline, self).__init__()
|
19 |
-
assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."
|
20 |
-
|
21 |
-
config = ConfigParser()
|
22 |
-
config.read(config_filename)
|
23 |
-
|
24 |
-
# modality configuration
|
25 |
-
modality = config.get("input", "modality")
|
26 |
-
|
27 |
-
self.modality = modality
|
28 |
-
# data configuration
|
29 |
-
input_v_fps = config.getfloat("input", "v_fps")
|
30 |
-
model_v_fps = config.getfloat("model", "v_fps")
|
31 |
-
|
32 |
-
# model configuration
|
33 |
-
model_path = config.get("model","model_path")
|
34 |
-
model_conf = config.get("model","model_conf")
|
35 |
-
|
36 |
-
# language model configuration
|
37 |
-
rnnlm = config.get("model", "rnnlm")
|
38 |
-
rnnlm_conf = config.get("model", "rnnlm_conf")
|
39 |
-
penalty = config.getfloat("decode", "penalty")
|
40 |
-
ctc_weight = config.getfloat("decode", "ctc_weight")
|
41 |
-
lm_weight = config.getfloat("decode", "lm_weight")
|
42 |
-
beam_size = config.getint("decode", "beam_size")
|
43 |
-
|
44 |
-
self.dataloader = AVSRDataLoader(modality, speed_rate=input_v_fps/model_v_fps, detector=detector)
|
45 |
-
self.model = AVSR(modality, model_path, model_conf, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size, device)
|
46 |
-
if face_track and self.modality in ["video", "audiovisual"]:
|
47 |
-
if detector == "mediapipe":
|
48 |
-
from pipelines.detectors.mediapipe.detector import LandmarksDetector
|
49 |
-
self.landmarks_detector = LandmarksDetector()
|
50 |
-
if detector == "retinaface":
|
51 |
-
from pipelines.detectors.retinaface.detector import LandmarksDetector
|
52 |
-
self.landmarks_detector = LandmarksDetector(device="cuda:0")
|
53 |
-
else:
|
54 |
-
self.landmarks_detector = None
|
55 |
-
|
56 |
-
|
57 |
-
def process_landmarks(self, data_filename, landmarks_filename):
|
58 |
-
if self.modality == "audio":
|
59 |
-
return None
|
60 |
-
if self.modality in ["video", "audiovisual"]:
|
61 |
-
if isinstance(landmarks_filename, str):
|
62 |
-
landmarks = pickle.load(open(landmarks_filename, "rb"))
|
63 |
-
else:
|
64 |
-
landmarks = self.landmarks_detector(data_filename)
|
65 |
-
return landmarks
|
66 |
-
|
67 |
-
|
68 |
-
def forward(self, data_filename, landmarks_filename=None):
|
69 |
-
assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."
|
70 |
-
landmarks = self.process_landmarks(data_filename, landmarks_filename)
|
71 |
-
data = self.dataloader.load_data(data_filename, landmarks)
|
72 |
-
transcript = self.model.infer(data)
|
73 |
-
return transcript
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|