wli3221134 committed on
Commit
719b808
·
verified ·
1 Parent(s): 0a63b23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -36
app.py CHANGED
@@ -1,42 +1,40 @@
1
- import torch
2
-
3
- import spaces
4
- import os
5
  import gradio as gr
6
- from huggingface_hub import HfApi
7
- from gradio_client.exceptions import AuthenticationError
8
-
9
  from model import Wav2Vec2BERT_Llama # 自定义模型模块
10
  import dataset # 自定义数据集模块
 
11
 
12
- @spaces.GPU
13
- def dummy(): # just a dummy
14
- pass
15
 
16
  # 初始化设备
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
  # 初始化模型
20
- # def load_model():
21
- # model = Wav2Vec2BERT_Llama().to(device)
22
- # checkpoint_path = "ckpt/model_checkpoint.pth"
23
- # if os.path.exists(checkpoint_path):
24
- # checkpoint = torch.load(checkpoint_path)
25
- # model_state_dict = checkpoint['model_state_dict']
26
-
27
- # # 处理模型状态字典的 key
28
- # if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
29
- # model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
30
- # elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
31
- # model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
32
-
33
- # model.load_state_dict(model_state_dict)
34
- # model.eval()
35
- # else:
36
- # raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
37
- # return model
38
-
39
- # model = load_model()
 
 
 
 
 
40
 
41
  # 检测函数
42
  def detect(dataset, model):
@@ -96,13 +94,13 @@ def gradio_ui():
96
  interface = gr.Interface(
97
  fn=detection_wrapper, # 主函数
98
  inputs=[
99
- gr.Audio(type="filepath", label="Demonstration Audio 1"),
100
  gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
101
- gr.Audio(type="filepath", label="Demonstration Audio 2"),
102
  gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
103
- gr.Audio(type="filepath", label="Demonstration Audio 3"),
104
  gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
105
- gr.Audio(type="filepath", label="Query Audio (Audio for Detection)")
106
  ],
107
  outputs=gr.JSON(label="Detection Results"),
108
  title="Audio Deepfake Detection System",
@@ -110,7 +108,6 @@ def gradio_ui():
110
  )
111
  return interface
112
 
113
-
114
  if __name__ == "__main__":
115
  demo = gradio_ui()
116
- demo.launch(share=False)
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import torch
 
4
  from model import Wav2Vec2BERT_Llama # 自定义模型模块
5
  import dataset # 自定义数据集模块
6
+ from huggingface_hub import hf_hub_download
7
 
 
 
 
8
 
9
  # 初始化设备
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # 初始化模型
13
def load_model():
    """Build the detection model and load its fine-tuned weights.

    The checkpoint file is fetched from the Hugging Face Hub repository
    ``amphion/deepfake_detection`` via ``hf_hub_download`` and its
    ``'model_state_dict'`` entry is loaded into a fresh
    ``Wav2Vec2BERT_Llama`` instance placed on the module-level ``device``.

    Returns:
        A ``(model, threshold)`` tuple: the model switched to eval mode and
        the hard-coded value ``0.9996`` (presumably the decision threshold
        applied to the model's score -- confirm against the caller).

    Raises:
        FileNotFoundError: if the path returned by the download does not
            exist on disk.
    """
    model = Wav2Vec2BERT_Llama().to(device)
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth"
    )

    # Guard clause: fail early with a clear message if the file is missing.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # map_location remaps the stored tensors onto the active device; without
    # it, a checkpoint saved from CUDA tensors cannot be deserialized on a
    # CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    threshold = 0.9996

    # Reconcile the 'module.' key prefix between the checkpoint and the model:
    # add it when the model exposes a `.module` attribute but the keys lack
    # the prefix, strip it in the opposite situation.
    prefix = 'module.'
    model_is_wrapped = hasattr(model, 'module')
    keys_are_prefixed = any(key.startswith(prefix) for key in model_state_dict)

    if model_is_wrapped and not keys_are_prefixed:
        model_state_dict = {
            prefix + key: value for key, value in model_state_dict.items()
        }
    elif not model_is_wrapped and keys_are_prefixed:
        # Strip only the leading prefix; a blanket str.replace would also
        # mangle keys that contain 'module.' further inside their name.
        model_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in model_state_dict.items()
        }

    model.load_state_dict(model_state_dict)
    model.eval()
    return model, threshold
36
+
37
+ model, threshold = load_model()
38
 
39
  # 检测函数
40
  def detect(dataset, model):
 
94
  interface = gr.Interface(
95
  fn=detection_wrapper, # 主函数
96
  inputs=[
97
+ gr.Audio(source="upload", type="filepath", label="Demonstration Audio 1"),
98
  gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 1"),
99
+ gr.Audio(source="upload", type="filepath", label="Demonstration Audio 2"),
100
  gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 2"),
101
+ gr.Audio(source="upload", type="filepath", label="Demonstration Audio 3"),
102
  gr.Dropdown(choices=["bonafide", "spoof"], value="bonafide", label="Label 3"),
103
+ gr.Audio(source="upload", type="filepath", label="Query Audio (Audio for Detection)")
104
  ],
105
  outputs=gr.JSON(label="Detection Results"),
106
  title="Audio Deepfake Detection System",
 
108
  )
109
  return interface
110
 
 
111
  if __name__ == "__main__":
112
  demo = gradio_ui()
113
+ demo.launch()