willwade commited on
Commit
b817428
·
verified ·
1 Parent(s): da62e74

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -73
pipeline.py DELETED
@@ -1,73 +0,0 @@
1
- #! /usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
-
4
- # Copyright 2023 Imperial College London (Pingchuan Ma)
5
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
-
7
- import os
8
- import torch
9
- import pickle
10
- from configparser import ConfigParser
11
-
12
- from pipelines.model import AVSR
13
- from pipelines.data.data_module import AVSRDataLoader
14
-
15
-
16
- class InferencePipeline(torch.nn.Module):
17
- def __init__(self, config_filename, detector="retinaface", face_track=False, device="cuda:0"):
18
- super(InferencePipeline, self).__init__()
19
- assert os.path.isfile(config_filename), f"config_filename: {config_filename} does not exist."
20
-
21
- config = ConfigParser()
22
- config.read(config_filename)
23
-
24
- # modality configuration
25
- modality = config.get("input", "modality")
26
-
27
- self.modality = modality
28
- # data configuration
29
- input_v_fps = config.getfloat("input", "v_fps")
30
- model_v_fps = config.getfloat("model", "v_fps")
31
-
32
- # model configuration
33
- model_path = config.get("model","model_path")
34
- model_conf = config.get("model","model_conf")
35
-
36
- # language model configuration
37
- rnnlm = config.get("model", "rnnlm")
38
- rnnlm_conf = config.get("model", "rnnlm_conf")
39
- penalty = config.getfloat("decode", "penalty")
40
- ctc_weight = config.getfloat("decode", "ctc_weight")
41
- lm_weight = config.getfloat("decode", "lm_weight")
42
- beam_size = config.getint("decode", "beam_size")
43
-
44
- self.dataloader = AVSRDataLoader(modality, speed_rate=input_v_fps/model_v_fps, detector=detector)
45
- self.model = AVSR(modality, model_path, model_conf, rnnlm, rnnlm_conf, penalty, ctc_weight, lm_weight, beam_size, device)
46
- if face_track and self.modality in ["video", "audiovisual"]:
47
- if detector == "mediapipe":
48
- from pipelines.detectors.mediapipe.detector import LandmarksDetector
49
- self.landmarks_detector = LandmarksDetector()
50
- if detector == "retinaface":
51
- from pipelines.detectors.retinaface.detector import LandmarksDetector
52
- self.landmarks_detector = LandmarksDetector(device="cuda:0")
53
- else:
54
- self.landmarks_detector = None
55
-
56
-
57
- def process_landmarks(self, data_filename, landmarks_filename):
58
- if self.modality == "audio":
59
- return None
60
- if self.modality in ["video", "audiovisual"]:
61
- if isinstance(landmarks_filename, str):
62
- landmarks = pickle.load(open(landmarks_filename, "rb"))
63
- else:
64
- landmarks = self.landmarks_detector(data_filename)
65
- return landmarks
66
-
67
-
68
- def forward(self, data_filename, landmarks_filename=None):
69
- assert os.path.isfile(data_filename), f"data_filename: {data_filename} does not exist."
70
- landmarks = self.process_landmarks(data_filename, landmarks_filename)
71
- data = self.dataloader.load_data(data_filename, landmarks)
72
- transcript = self.model.infer(data)
73
- return transcript