Spaces:
Running
Running
File size: 2,546 Bytes
47aeb66 21cb4dd 47aeb66 a45c62c 47aeb66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
"""
File: model.py
Author: Elena Ryumina and Dmitry Ryumin
Description: This module provides functions for loading and processing a pre-trained deep learning model
for facial expression recognition.
License: MIT License
"""
import torch
import requests
# Importing necessary components for the Gradio app
from app.config import config_data
from app.model_architectures import ResNet50, LSTMPyTorch, ExprModelV3
from transformers import AutoFeatureExtractor
# Device that every model and dummy input in this module is moved to via .to(device).
device = "cpu"
def load_model(model_url, model_path):
    """Download the file at ``model_url`` and save it to ``model_path``.

    The response body is streamed to disk in 8192-byte chunks.

    Args:
        model_url: URL fetched with an HTTP GET request.
        model_path: Local path the body is written to. It is opened in
            binary write mode, so an existing file is overwritten.

    Returns:
        ``model_path`` on success, or ``None`` when any step fails (the
        error is printed, not raised).
    """
    try:
        # (connect, read) timeout: a dead server can no longer hang the
        # caller forever. The read timeout applies between received bytes,
        # so it does not cap the total download time of a large file.
        with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
            # Reject 4xx/5xx responses *before* opening the target file;
            # otherwise the error page would be written to disk and the
            # path returned as if it were a valid checkpoint.
            response.raise_for_status()
            with open(model_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        return model_path
    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure with None.
        print(f"Error loading model: {e}")
        return None
# name -> most recent third hook argument stored by a hook from get_gradients().
gradients = {}


def get_gradients(name):
    """Build a hook that stores its third argument under ``gradients[name]``.

    The returned callable takes three positional arguments and returns
    ``None``. Further down this module it is registered on ``layer4`` with
    ``register_full_backward_hook``.

    Args:
        name: Key under which the received value is stored in ``gradients``.

    Returns:
        The hook callable.
    """
    def _record_grad(module, grad_input, grad_output):
        # Each call overwrites the previous entry for this name.
        gradients[name] = grad_output

    return _record_grad
# name -> most recent detached output stored by a hook from get_activations().
activations = {}


def get_activations(name):
    """Build a forward hook that stores the detached output in ``activations``.

    The returned callable takes three positional arguments, calls
    ``.detach()`` on the third one and stores the result under
    ``activations[name]``. It returns ``None``.

    Args:
        name: Key under which the detached output is stored.

    Returns:
        The hook callable.
    """
    def _record_output(module, inputs, output):
        # Each call overwrites the previous entry for this name.
        activations[name] = output.detach()

    return _record_output
# Random dummy inputs. Each one is passed once through the matching model
# right after that model is loaded further down in this module.
# NOTE(review): axis meanings are not visible here; the shapes presumably
# match what ResNet50 / LSTMPyTorch / ExprModelV3 expect - confirm there.
test_static = torch.rand((1, 3, 224, 224))
test_dynamic = torch.rand((1, 10, 512))
test_audio = torch.rand((1, 64000))
# --- Static model: download the weights, load them, warm up, attach hooks ---
path_static = load_model(config_data.model_static_url, config_data.model_static_path)
if path_static is None:
    # load_model() has already printed the cause; stop with a clear message
    # instead of an obscure failure inside torch.load(None).
    raise RuntimeError("Static model could not be downloaded.")
# NOTE(review): the meaning of the first argument (7) is defined by ResNet50,
# which is not visible here - presumably the number of output classes.
pth_model_static = ResNet50(7, channels=3)
# map_location keeps loading working on hosts without CUDA even if the
# checkpoint was saved from GPU tensors (same as the audio checkpoint load).
# NOTE(review): torch.load unpickles a downloaded file; consider
# weights_only=True if the checkpoint is a plain state dict.
pth_model_static.load_state_dict(
    torch.load(path_static, map_location=torch.device(device))
)
pth_model_static.to(device)
pth_model_static.eval()
# One forward pass on the dummy input, run before any hook is registered,
# so nothing from it ends up in `gradients` / `activations`.
pth_model_static(test_static.to(device))
# Capture layer4 gradients and activations, and the fc1 output as 'features'.
pth_model_static.layer4.register_full_backward_hook(get_gradients('layer4'))
pth_model_static.layer4.register_forward_hook(get_activations('layer4'))
pth_model_static.fc1.register_forward_hook(get_activations('features'))
# --- Dynamic model: download the weights, load them, warm up ---
path_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
if path_dynamic is None:
    # load_model() has already printed the cause; stop with a clear message
    # instead of an obscure failure inside torch.load(None).
    raise RuntimeError("Dynamic model could not be downloaded.")
pth_model_dynamic = LSTMPyTorch()
# map_location keeps loading working on hosts without CUDA even if the
# checkpoint was saved from GPU tensors (same as the audio checkpoint load).
# NOTE(review): torch.load unpickles a downloaded file; consider
# weights_only=True if the checkpoint is a plain state dict.
pth_model_dynamic.load_state_dict(
    torch.load(path_dynamic, map_location=torch.device(device))
)
pth_model_dynamic.to(device)
pth_model_dynamic.eval()
# One forward pass on the dummy input right after loading.
pth_model_dynamic(test_dynamic.to(device))
# --- Audio model: pretrained base plus downloaded checkpoint, then warm up ---
# Identifier passed to from_pretrained() for both the feature extractor and
# the model.
path_audio_model_1 = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
path_audio_model_2 = load_model(config_data.model_audio_url, config_data.model_audio_path)
if path_audio_model_2 is None:
    # load_model() has already printed the cause; stop with a clear message
    # instead of an obscure failure inside torch.load(None).
    raise RuntimeError("Audio model could not be downloaded.")
audio_processor = AutoFeatureExtractor.from_pretrained(path_audio_model_1)
audio_model = ExprModelV3.from_pretrained(path_audio_model_1)
# The downloaded checkpoint is a dict; its "model_state_dict" entry replaces
# the weights loaded by from_pretrained() above.
# NOTE(review): torch.load unpickles a downloaded file - only use trusted URLs.
audio_checkpoint = torch.load(path_audio_model_2, map_location=torch.device(device))
audio_model.load_state_dict(audio_checkpoint["model_state_dict"])
audio_model.to(device)
audio_model.eval()
# One forward pass on the dummy input right after loading.
audio_model(test_audio.to(device))