Spaces:
Running
on
Zero
Running
on
Zero
wli3221134
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -10,41 +10,40 @@ from huggingface_hub import hf_hub_download
|
|
10 |
def dummy():  # just a dummy
    """No-op placeholder; takes no arguments and returns ``None``."""
    pass
|
12 |
|
13 |
-
#
|
14 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
-
print('device:', device)
|
16 |
-
|
17 |
-
# Initialize the model
|
18 |
def load_model():
|
19 |
-
model = Wav2Vec2BERT_Llama().to(device)
|
20 |
checkpoint_path = hf_hub_download(
|
21 |
repo_id="amphion/deepfake_detection",
|
22 |
filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
|
23 |
repo_type="model"
|
24 |
)
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
model_state_dict = checkpoint['model_state_dict']
|
29 |
-
threshold = 0.9996
|
30 |
|
31 |
-
|
32 |
-
if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
|
33 |
-
model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
|
34 |
-
elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
|
35 |
-
model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
model
|
|
|
44 |
|
45 |
-
# Detection function
|
46 |
-
def detect(dataset):
|
47 |
-
"""进行音频伪造检测"""
|
48 |
with torch.no_grad():
|
49 |
for batch in dataset:
|
50 |
main_features = {
|
@@ -57,43 +56,31 @@ def detect(dataset):
|
|
57 |
} for pf in batch['prompt_features']]
|
58 |
|
59 |
prompt_labels = batch['prompt_labels'].to(device)
|
60 |
-
# 模型的前向传播逻辑 (需要补充具体实现)
|
61 |
outputs = model({
|
62 |
'main_features': main_features,
|
63 |
'prompt_features': prompt_features,
|
64 |
'prompt_labels': prompt_labels
|
65 |
})
|
66 |
|
67 |
-
avg_scores = outputs['avg_logits'].softmax(dim=-1)
|
68 |
-
deepfake_scores = avg_scores[:, 1].cpu()
|
69 |
-
is_fake =
|
70 |
-
|
71 |
-
result = {"is_fake": is_fake, "confidence": deepfake_scores[0]} # 示例返回值
|
72 |
return result
|
73 |
|
74 |
-
#
|
75 |
def audio_deepfake_detection(demonstrations, query_audio_path):
|
76 |
-
"""
|
77 |
-
音频伪造检测函数
|
78 |
-
:param demonstrations: 演示音频路径和标签的列表
|
79 |
-
:param query_audio_path: 查询音频路径
|
80 |
-
:return: 检测结果
|
81 |
-
"""
|
82 |
demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
|
83 |
demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
|
84 |
if len(demonstration_paths) != len(demonstration_labels):
|
85 |
demonstration_labels = demonstration_labels[:len(demonstration_paths)]
|
86 |
-
|
87 |
-
print(f"Demonstration audio labels: {demonstration_labels}")
|
88 |
-
print(f"Query audio path: {query_audio_path}")
|
89 |
-
|
90 |
# 数据集处理
|
91 |
audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
|
92 |
-
|
93 |
-
#
|
94 |
-
result =
|
95 |
-
|
96 |
-
# 返回结果
|
97 |
return {
|
98 |
"Is AI Generated": result["is_fake"],
|
99 |
"Confidence": f"{result['confidence']:.2f}%"
|
|
|
10 |
def dummy():  # just a dummy
    """No-op placeholder; takes no arguments and returns ``None``."""
    pass
|
12 |
|
13 |
+
# Modified load_model function
|
|
|
|
|
|
|
|
|
14 |
def load_model():
    """Download the detection checkpoint and return its local file path.

    Despite the name, no model object is built here: the function only
    fetches the checkpoint file through ``hf_hub_download`` and hands the
    resulting path back to the caller.

    Returns:
        The local path reported by ``hf_hub_download``.

    Raises:
        FileNotFoundError: If the returned path does not exist on disk.
    """
    checkpoint_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_wav2vec2bert_ft_llama_labels_ASVspoof2019_RandomPrompts_6/model_checkpoint.pth",
        repo_type="model"
    )
    # Fail early with a clear message if the download did not yield a file.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    return checkpoint_path
|
|
|
|
|
23 |
|
24 |
+
# Resolve the checkpoint path once at import time; detect_on_gpu reads this
# module-level name when it loads the weights with torch.load.
checkpoint_path = load_model()
|
|
|
|
|
|
|
|
|
25 |
|
26 |
+
# Move the detect function under the GPU decorator
|
27 |
+
@spaces.GPU
|
28 |
+
def detect_on_gpu(dataset):
|
29 |
+
"""在 GPU 上进行音频伪造检测"""
|
30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
model = Wav2Vec2BERT_Llama().to(device)
|
32 |
+
|
33 |
+
# 加载模型权重
|
34 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
35 |
+
model_state_dict = checkpoint['model_state_dict']
|
36 |
+
threshold = 0.9996
|
37 |
+
|
38 |
+
# 处理模型状态字典的 key
|
39 |
+
if hasattr(model, 'module') and not any(key.startswith('module.') for key in model_state_dict.keys()):
|
40 |
+
model_state_dict = {'module.' + key: value for key, value in model_state_dict.items()}
|
41 |
+
elif not hasattr(model, 'module') and any(key.startswith('module.') for key in model_state_dict.keys()):
|
42 |
+
model_state_dict = {key.replace('module.', ''): value for key, value in model_state_dict.items()}
|
43 |
|
44 |
+
model.load_state_dict(model_state_dict)
|
45 |
+
model.eval()
|
46 |
|
|
|
|
|
|
|
47 |
with torch.no_grad():
|
48 |
for batch in dataset:
|
49 |
main_features = {
|
|
|
56 |
} for pf in batch['prompt_features']]
|
57 |
|
58 |
prompt_labels = batch['prompt_labels'].to(device)
|
|
|
59 |
outputs = model({
|
60 |
'main_features': main_features,
|
61 |
'prompt_features': prompt_features,
|
62 |
'prompt_labels': prompt_labels
|
63 |
})
|
64 |
|
65 |
+
avg_scores = outputs['avg_logits'].softmax(dim=-1)
|
66 |
+
deepfake_scores = avg_scores[:, 1].cpu()
|
67 |
+
is_fake = deepfake_scores[0] > threshold
|
68 |
+
result = {"is_fake": is_fake, "confidence": deepfake_scores[0]}
|
|
|
69 |
return result
|
70 |
|
71 |
+
# Modified main audio deepfake detection function
|
72 |
def audio_deepfake_detection(demonstrations, query_audio_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
demonstration_paths = [audio[0] for audio in demonstrations if audio[0] is not None]
|
74 |
demonstration_labels = [audio[1] for audio in demonstrations if audio[1] is not None]
|
75 |
if len(demonstration_paths) != len(demonstration_labels):
|
76 |
demonstration_labels = demonstration_labels[:len(demonstration_paths)]
|
77 |
+
|
|
|
|
|
|
|
78 |
# 数据集处理
|
79 |
audio_dataset = dataset.DemoDataset(demonstration_paths, demonstration_labels, query_audio_path)
|
80 |
+
|
81 |
+
# 调用 GPU 检测函数
|
82 |
+
result = detect_on_gpu(audio_dataset)
|
83 |
+
|
|
|
84 |
return {
|
85 |
"Is AI Generated": result["is_fake"],
|
86 |
"Confidence": f"{result['confidence']:.2f}%"
|