wli3221134 committed on
Commit 4a81ee5 · verified · 1 Parent(s): f96cfa2

Update app.py

Files changed (1):
  1. app.py +34 -47
app.py CHANGED
@@ -10,41 +10,40 @@ from huggingface_hub import hf_hub_download
 def dummy(): # just a dummy
     pass

-# Initialize device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print('device:', device)
-
-# Initialize model
+# Revised load_model function
 def load_model():
-    model = Wav2Vec2BERT_Llama().to(device)
     checkpoint_path = hf_hub_download(
         repo_id="amphion/deepfake_detection",
         filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
         repo_type="model"
     )
-    # checkpoint_path = "ckpt/model_checkpoint.pth"
-    if os.path.exists(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
-        model_state_dict = checkpoint['model_state_dict']
-        threshold = 0.9996
-
-        # Normalize the keys of the model state dict
-        if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
-            model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
-        elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
-            model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
-
-        model.load_state_dict(model_state_dict)
-        model.eval()
-    else:
-        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
-    return model, threshold
-
-model, threshold = load_model()
-
-# Detection function
-def detect(dataset):
-    """Run audio deepfake detection"""
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+    return checkpoint_path
+
+checkpoint_path = load_model()
+
+# Move the detect function under the GPU decorator
+@spaces.GPU
+def detect_on_gpu(dataset):
+    """Run audio deepfake detection on the GPU"""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = Wav2Vec2BERT_Llama().to(device)
+
+    # Load the model weights
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+    model_state_dict = checkpoint['model_state_dict']
+    threshold = 0.9996
+
+    # Normalize the keys of the model state dict
+    if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
+        model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
+    elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
+        model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
+
+    model.load_state_dict(model_state_dict)
+    model.eval()
+
     with torch.no_grad():
         for batch in dataset:
             main_features = {
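The substantive change in this hunk is that model construction, checkpoint loading, and inference all move inside a function decorated with @spaces.GPU, so that on a Hugging Face ZeroGPU Space the GPU is attached only while that function runs, while import time stays CPU-only and merely downloads the checkpoint. A minimal sketch of that pattern (the classifier, feature tensor, and predict function are placeholders, not the repo's Wav2Vec2BERT_Llama code):

```python
# Minimal sketch (not the repo's actual code) of the ZeroGPU pattern this
# commit adopts: keep module import CPU-only, and do every CUDA operation
# inside the @spaces.GPU-decorated function so a GPU is attached only
# while that call runs. TinyClassifier stands in for Wav2Vec2BERT_Llama.
import spaces              # Hugging Face ZeroGPU helper available on Spaces
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 2)   # toy 2-class head

    def forward(self, x):
        return self.fc(x)


@spaces.GPU                # ZeroGPU grants a GPU for the duration of the call
def predict(features: torch.Tensor) -> torch.Tensor:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyClassifier().to(device)   # model is built inside the GPU call
    model.eval()
    with torch.no_grad():
        return model(features.to(device)).softmax(dim=-1).cpu()


# Example call: a single 16-dimensional feature vector.
scores = predict(torch.randn(1, 16))
```

Rebuilding and reloading the model on every request is the simplest form of this pattern; the committed code does the same, paying the checkpoint-load cost per call in exchange for never touching CUDA at import time.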
@@ -57,43 +56,31 @@ def detect(dataset)
             } for pf in batch['prompt_features']]

             prompt_labels = batch['prompt_labels'].to(device)
-            # Model forward pass (implementation to be filled in)
             outputs = model({
                 'main_features': main_features,
                 'prompt_features': prompt_features,
                 'prompt_labels': prompt_labels
             })

-            avg_scores = outputs['avg_logits'].softmax(dim=-1)  # [batch_size, 2]
-            deepfake_scores = avg_scores[:, 1].cpu()  # [batch_size]
-            is_fake = True if deepfake_scores[0] > threshold else False
-            # Assume result is what the model returns
-            result = {"is_fake": is_fake, "confidence": deepfake_scores[0]}  # example return value
+            avg_scores = outputs['avg_logits'].softmax(dim=-1)
+            deepfake_scores = avg_scores[:, 1].cpu()
+            is_fake = deepfake_scores[0] > threshold
+            result = {"is_fake": is_fake, "confidence": deepfake_scores[0]}
     return result

-# Main audio deepfake detection function
+# Revised main audio deepfake detection function
 def audio_deepfake_detection(demonstrations, query_audio_path):
-    """
-    Audio deepfake detection function
-    :param demonstrations: list of demonstration audio paths and labels
-    :param query_audio_path: path to the query audio
-    :return: detection result
-    """
     demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
     demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
     if len(demonstration_paths) != len(demonstration_labels):
         demonstration_labels = demonstration_labels[:len(demonstration_paths)]
-    print(f"Demonstration audio paths: {demonstration_paths}")
-    print(f"Demonstration audio labels: {demonstration_labels}")
-    print(f"Query audio path: {query_audio_path}")
-
+
     # Dataset processing
     audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
-
-    # Call the detection function
-    result = detect(audio_dataset)
-
-    # Return the result
+
+    # Call the GPU detection function
+    result = detect_on_gpu(audio_dataset)
+
     return {
         "Is AI Generated": result["is_fake"],
         "Confidence": f"{result['confidence']:.2f}%"