from fastapi import APIRouter
from datasets import load_dataset
import numpy as np
import os
import torch
import gc
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor, pipeline
from utils.emissions import tracker
from dotenv import load_dotenv
import logging
import torch.nn.utils.prune as prune
from pydantic import BaseModel, Field
from smolagents import Tool

# Configure logging
logging.basicConfig(level=logging.INFO)
logging.info("Start of the Python file")
load_dotenv()

router = APIRouter()

DESCRIPTION = "Wav2Vec2 audio classification"
ROUTE = "/audio"

# transformers pipeline device convention: 0 = first CUDA device, -1 = CPU
device = 0 if torch.cuda.is_available() else -1

def preprocess_function(example, feature_extractor):
    # Extract the raw waveforms from the batch and pad/truncate them to 1 s at 16 kHz
    return feature_extractor(
        [x["array"] for x in example["audio"]],
        sampling_rate=feature_extractor.sampling_rate,
        padding="longest",
        max_length=16000,
        truncation=True,
        return_tensors="pt",
    )
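# A minimal usage sketch (illustrative only; the 1 s of silence below is an
# assumption, not a sample from the challenge dataset):
#
#   fe = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
#   batch = {"audio": [{"array": np.zeros(16000)}]}
#   preprocess_function(batch, fe)["input_values"].shape  # torch.Size([1, 16000])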

def apply_pruning(model, amount=0.3):
    # L1-unstructured pruning zeroes the smallest-magnitude weights in every
    # Linear layer; prune.remove() then makes the sparsification permanent.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")
    return model
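# A quick sanity check of the pruning ratio on a throwaway layer (hypothetical
# sizes, not part of the pipeline):
#
#   tiny = torch.nn.Linear(8, 4)
#   apply_pruning(tiny, amount=0.3)
#   (tiny.weight == 0).float().mean()  # ~0.3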

class BaseEvaluationRequest(BaseModel):
    test_size: float = Field(0.2, ge=0.0, le=1.0, description="Size of the test split (between 0 and 1)")
    test_seed: int = Field(42, ge=0, description="Random seed for reproducibility")

class AudioEvaluationRequest(BaseEvaluationRequest):
    dataset_name: str = Field("rfcx/frugalai", description="The name of the dataset on the Hugging Face Hub")
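# With the defaults above, AudioEvaluationRequest() is equivalent to
# AudioEvaluationRequest(test_size=0.2, test_seed=42, dataset_name="rfcx/frugalai").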
    
class evaluate_consumption_example(Tool):
    name = "evaluate_consumption_example"
    description = "This is only an example. If a manager wants to know what you are capable of, use it: it runs CodeCarbon to evaluate the CO2 emissions of an example Python inference workload."
    inputs = {
        "code": {
            "type": "string",
            "description": "The code to evaluate. Here, it is an example, so just set it to 'None'."
        }
    }
    output_type = "string"

    def forward(self, code: str):
        request = AudioEvaluationRequest()
        logging.info("Chargement des données")
        dataset = load_dataset(request.dataset_name, streaming=True, token=os.getenv("HF_TOKEN"))
        logging.info("Données chargées")
        
        test_dataset = dataset["test"]
        del dataset
        
        # Start tracking emissions
        tracker.start()
        tracker.start_task("inference")
    
        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
    
        test_dataset = test_dataset.map(
            preprocess_function,
            fn_kwargs={"feature_extractor": feature_extractor},
            remove_columns="audio",
            batched=True,
            batch_size=32,
        )
        
        gc.collect()
    
        model_name = "CindyDelage/Challenge_HuggingFace_DFG_FrugalAI"
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_name)
        
        # Dynamic quantization and pruning (kept disabled here)
        model.eval()
        # model = torch.quantization.quantize_dynamic(model, dtype=torch.qint8)
        # model = apply_pruning(model, amount=0.3)  # prune 30% of linear-layer weights
        
        classifier = pipeline("audio-classification", model=model, feature_extractor=feature_extractor, device=device)
        predictions = []
        logging.info("Starting batched predictions")
        # For this example tool, only the first few samples are scored to keep
        # the measured footprint small.
        for i, data in enumerate(test_dataset):
            if i > 5:
                break
            with torch.no_grad():
                result = classifier(np.asarray(data["input_values"]), batch_size=64)
            predicted_label = result[0]["label"]
            predictions.append(1 if predicted_label == "environment" else 0)

            # Free memory after each iteration
            del result
            torch.cuda.empty_cache()
            gc.collect()
        logging.info("Predictions finished")
        del classifier 
        del feature_extractor 
        
        gc.collect()
        # Stop tracking emissions; stop_task() returns a CodeCarbon EmissionsData object
        emissions_data = tracker.stop_task()

        return emissions_data

class evaluate_consumption(Tool):
    name = "evaluate_consumption"
    description = "If the manager gave you their Python code, this tool uses CodeCarbon to evaluate the CO2 emissions of that code."
    inputs = {
        "code": {
            "type": "string",
            "description": "The code to evaluate."
        }
    }
    output_type = "string"

    def forward(self, code: str):
        # Start tracking emissions
        tracker.start()
        tracker.start_task("inference")
        # NOTE: exec runs arbitrary code in this process; only use it with trusted input
        exec(code)
        # Stop tracking emissions
        emissions_data = tracker.stop_task()
        
        return emissions_data
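# A minimal sketch of how these tools might be attached to a smolagents agent.
# The model wrapper below is an assumption about the caller's setup (depending
# on the smolagents version it may be named HfApiModel or InferenceClientModel):
#
#   from smolagents import CodeAgent, HfApiModel
#
#   agent = CodeAgent(
#       tools=[evaluate_consumption_example(), evaluate_consumption()],
#       model=HfApiModel(),
#   )
#   agent.run("Estimate the CO2 emissions of: print('hello')")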