Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
from services.huggingface import init_huggingface, update_dataset | |
from services.json_generator import generate_json | |
from ui.form_components import ( | |
create_header_tab, | |
create_task_tab, | |
create_measures_tab, | |
create_system_tab, | |
create_software_tab, | |
create_infrastructure_tab, | |
create_environment_tab, | |
create_quality_tab, | |
create_hash_tab | |
) | |
# Set up the Hugging Face connection before any UI is built
init_huggingface()

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## ML-related Data Collection Form")
    gr.Markdown("Welcome to this Huggingface space that helps you fill in a form for monitoring the energy consumption of an AI model.")
    csv_upload = gr.File(label="Upload CSV", file_types=[".csv"])
    gr.Label("Please upload a CSV file with the data you want to analyze.")

    # One tab per report section. Each factory's return value is flattened
    # below, so it is presumably a list of input components.
    header_components = create_header_tab()
    task_components = create_task_tab()
    measures_components = create_measures_tab()
    system_components = create_system_tab()
    software_components = create_software_tab()
    infrastructure_components = create_infrastructure_tab()
    environment_components = create_environment_tab()
    quality_components = create_quality_tab()
    hash_components = create_hash_tab()

    # Flatten every section into one list. The order here defines the
    # positional layout of the values handed to the event handlers, so any
    # handler that addresses fields by index depends on it.
    form_sections = (
        header_components,
        task_components,
        measures_components,
        system_components,
        software_components,
        infrastructure_components,
        environment_components,
        quality_components,
        hash_components,
    )
    all_form_components = [
        component for section in form_sections for component in section
    ]
# Parse CSV and update form values | |
# Position of the first field of each section inside the flat list of form
# values. The infrastructure range (59-68) is pinned by its neighbours: the
# software section ends at 58 and the environment section starts at 69.
# NOTE(review): these positions are hard-coded; confirm them against the
# actual order and length of `all_form_components`.
_MEASURES_START = 40
_SYSTEM_START = 54
_SOFTWARE_START = 57
_INFRA_START = 59
_ENVIRONMENT_START = 69

# Spellings of the `on_cloud` column that are treated as "running in the cloud".
_AFFIRMATIVE = {"yes", "y", "true", "1"}


def _csv_text(row, column, default=""):
    """Return ``row[column]``, or *default* when the column is absent or empty."""
    value = row.get(column, default)
    # pandas reports empty cells as NaN; do not let that leak into the form.
    if pd.isna(value):
        return default
    return value


def _csv_count(row, column):
    """Return ``row[column]`` as an int, or 0 when it is absent, empty or not numeric."""
    try:
        return int(float(row.get(column, 0)))
    except (TypeError, ValueError, OverflowError):
        return 0


def _split_model(model):
    """Split a hardware model string into ``(manufacturer, family, series)``."""
    parts = str(model).replace("(R)", "").replace("(TM)", "").split()
    manufacturer = parts[0] if parts else ""
    family = " ".join(parts[1:3]) if len(parts) >= 3 else ""
    series = " ".join(parts[3:]) if len(parts) > 3 else ""
    return manufacturer, family, series


def parse_csv_and_update_form(csv_file, *current_values):
    """Pre-fill the form from the first row of an uploaded CSV file.

    Args:
        csv_file: The uploaded file, either an object exposing the path as
            ``.name`` or the path itself. ``None`` means nothing is uploaded.
        *current_values: Current value of every form component, in the same
            order as the components are wired as outputs.

    Returns:
        A list with one value per form component. Fields that the CSV does
        not cover keep their current value. If the CSV cannot be processed
        the error is printed and the values updated so far are returned
        (best effort, the UI callback never raises).
    """
    updated_values = list(current_values)
    if csv_file is None:
        return updated_values

    # Accept both a file-like object carrying ``.name`` and a plain path.
    csv_path = getattr(csv_file, "name", csv_file)

    try:
        df = pd.read_csv(csv_path)
        csv_data = df.iloc[0].to_dict()

        # ========== HEADER ==========
        updated_values[3] = _csv_text(csv_data, 'run_id')        # reportId
        updated_values[4] = _csv_text(csv_data, 'timestamp')     # reportDatetime
        updated_values[8] = _csv_text(csv_data, 'project_name')  # publisher_projectName

        # ========== SYSTEM ==========
        updated_values[_SYSTEM_START] = _csv_text(csv_data, 'os')  # os
        updated_values[_SYSTEM_START + 1] = ""                     # distribution
        updated_values[_SYSTEM_START + 2] = ""                     # distributionVersion

        # ========== MEASURES ==========
        updated_values[_MEASURES_START] = _csv_text(csv_data, 'tracking_mode')         # measurementMethod
        updated_values[_MEASURES_START + 8] = "kWh"                                    # unit
        updated_values[_MEASURES_START + 11] = _csv_text(csv_data, 'energy_consumed')  # powerConsumption

        # Duration conversion (hours → seconds)
        # NOTE(review): confirm the CSV really reports `duration` in hours.
        if 'duration' in csv_data:
            try:
                hours = float(_csv_text(csv_data, 'duration'))
                duration = str(round(hours * 3600, 2))
            except (TypeError, ValueError):
                duration = ""
            updated_values[_MEASURES_START + 12] = duration                      # measurementDuration
        updated_values[_MEASURES_START + 13] = _csv_text(csv_data, 'timestamp')  # measurementDateTime

        # ========== SOFTWARE ==========
        updated_values[_SOFTWARE_START] = "Python"                                   # language
        updated_values[_SOFTWARE_START + 1] = _csv_text(csv_data, 'python_version')  # version_software

        # ========== INFRASTRUCTURE ==========
        on_cloud = str(_csv_text(csv_data, 'on_cloud', 'No')).lower().strip() in _AFFIRMATIVE
        # infraType, cloud provider and cloud region each get their own slot;
        # previously the provider overwrote infraType.
        updated_values[_INFRA_START] = "publicCloud" if on_cloud else "onPremise"
        updated_values[_INFRA_START + 1] = _csv_text(csv_data, 'cloud_provider') if on_cloud else ""
        updated_values[_INFRA_START + 2] = _csv_text(csv_data, 'cloud_region') if on_cloud else ""

        # Describe the GPU when one is present, otherwise fall back to the CPU.
        gpu_count = _csv_count(csv_data, 'gpu_count')
        cpu_count = _csv_count(csv_data, 'cpu_count')
        if gpu_count > 0:
            updated_values[_INFRA_START + 4] = str(gpu_count)  # nbComponent
            model = _csv_text(csv_data, 'gpu_model')
        elif cpu_count > 0:
            updated_values[_INFRA_START + 4] = str(cpu_count)  # nbComponent
            model = _csv_text(csv_data, 'cpu_model')
        else:
            model = ""

        # Memory size, only reported when it is a positive number
        ram_size = _csv_text(csv_data, 'ram_total_size')
        try:
            has_ram = float(ram_size) > 0
        except (TypeError, ValueError):
            has_ram = False
        updated_values[_INFRA_START + 5] = f"{ram_size} GB" if has_ram else ""

        # Split model into manufacturer / family / series
        manufacturer, family, series = _split_model(model) if model else ("", "", "")
        updated_values[_INFRA_START + 6] = manufacturer
        updated_values[_INFRA_START + 7] = family
        updated_values[_INFRA_START + 8] = series
        updated_values[_INFRA_START + 9] = ""  # share (own slot; no longer wipes series)

        # ========== ENVIRONMENT ==========
        updated_values[_ENVIRONMENT_START] = _csv_text(csv_data, 'country_name')  # country
        updated_values[_ENVIRONMENT_START + 1] = _csv_text(csv_data, 'latitude')  # latitude
        updated_values[_ENVIRONMENT_START + 2] = _csv_text(csv_data, 'longitude')  # longitude
        updated_values[_ENVIRONMENT_START + 3] = _csv_text(csv_data, 'region')    # location
    except Exception as e:
        # Best effort: a malformed CSV must not break the UI callback.
        print(f"CSV Processing Error: {str(e)}")
    return updated_values
# Event wiring. The widgets and listeners below have to be created inside the
# Blocks context, so `demo` is re-entered here.
with demo:
    # Refill the form whenever the uploaded CSV changes
    csv_upload.change(
        fn=parse_csv_and_update_form,
        inputs=[csv_upload] + all_form_components,
        outputs=all_form_components,
    )

    # Submit and Download Buttons
    submit_button = gr.Button("Submit")
    output = gr.Textbox(label="Output", lines=1)
    # Hidden textbox that carries the generated JSON from the first step of
    # the click chain to the second.
    json_output = gr.Textbox(visible=False)
    file_output = gr.File(label="Downloadable JSON")

    # On submit: generate the JSON from every form field (same order as
    # `all_form_components`), then pass the hidden JSON to `update_dataset`.
    submit_button.click(
        generate_json,
        inputs=all_form_components,
        outputs=[output, file_output, json_output],
    ).then(
        update_dataset,
        inputs=json_output,
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()