Spaces:
Sleeping
Sleeping
import hmac
import os
from typing import Optional

import requests
from fastapi import BackgroundTasks, FastAPI, Header, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from .models import WebhookPayload, config
# Secrets come from the environment; each is None when the variable is unset.
# WEBHOOK_SECRET is checked against the X-Webhook-Secret request header.
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
# Bearer token for the AutoTrain API and for posting Hub discussions.
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN")

# AutoTrain endpoints: the REST API used for automation, and the web UI
# linked from the notification message.
AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"

app = FastAPI()
@app.get("/")
async def home():
    """Serve the static landing page ``home.html``.

    The route decorator is required: without it the function is never
    registered on ``app`` and the page is unreachable.
    """
    return FileResponse("home.html")
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Receive a Hub webhook and queue a re-training when the input dataset changes.

    Args:
        payload: Parsed webhook body.
        task_queue: FastAPI background-task queue; the retrain runs after the
            response is sent.
        x_webhook_secret: Value of the ``X-Webhook-Secret`` header.

    Returns:
        ``{"processed": True}`` when a retrain was queued, otherwise
        ``{"processed": False}`` (the event is ignored).

    Raises:
        HTTPException(401): the secret header is missing.
        HTTPException(403): the secret does not match ``WEBHOOK_SECRET``
            (including when ``WEBHOOK_SECRET`` is not configured).
    """
    if x_webhook_secret is None:
        raise HTTPException(401)
    # Constant-time comparison so the secret cannot be probed via response
    # timing. An unset WEBHOOK_SECRET rejects every request, as before.
    if WEBHOOK_SECRET is None or not hmac.compare_digest(
        x_webhook_secret.encode(), WEBHOOK_SECRET.encode()
    ):
        raise HTTPException(403)

    # Only content updates to the configured input dataset trigger a retrain.
    is_relevant_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    )
    if not is_relevant_update:
        # no-op
        return {"processed": False}

    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}
def _http_error_detail(response):
    """Return the JSON body of an error *response*, or its raw text if not JSON."""
    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors subclass ValueError. Error pages are
        # not guaranteed to be JSON (e.g. an HTML 502 from a proxy), and a
        # decode error here would mask the original HTTPError.
        return response.text


def schedule_retrain(payload: WebhookPayload):
    """Create an AutoTrain project for *payload*, attach the dataset, start
    processing, then announce it on the dataset's community tab.

    Queued as a background task by the webhook handler.

    Returns:
        ``{"processed": True}`` once the notification has been posted.

    Raises:
        requests.HTTPError: if any AutoTrain API call fails. The status code
            and response body are printed before the error is re-raised.
    """
    # Create the autotrain project
    try:
        project = AutoTrain.create_project(payload)
        AutoTrain.add_data(project_id=project["id"])
        AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        print("ERROR while requesting AutoTrain API:")
        if err.response is not None:
            print(f" code: {err.response.status_code}")
            print(f" {_http_error_detail(err.response)}")
        raise
    # Notify in the community tab
    notify_success(project["id"])
    return {"processed": True}
class AutoTrain:
    """Minimal client for the AutoTrain REST API at ``AUTOTRAIN_API_URL``.

    Every method is static, authenticates with ``HF_ACCESS_TOKEN``, and raises
    ``requests.HTTPError`` on a non-2xx response.
    """

    # Seconds to wait for the AutoTrain API. requests has no default timeout,
    # so without one a stalled connection would block the background task
    # forever.
    REQUEST_TIMEOUT = 60

    @staticmethod
    def _headers() -> dict:
        """Return the bearer-token Authorization header for API calls."""
        return {"Authorization": f"Bearer {HF_ACCESS_TOKEN}"}

    @staticmethod
    def create_project(payload: WebhookPayload) -> dict:
        """Create an AutoTrain project named after the payload's head commit.

        Returns:
            The decoded JSON response; callers read its ``"id"`` key.
        """
        project_resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/create",
            json={
                "username": config.target_namespace,
                # Project name is "<prefix>-<first 7 chars of the head SHA>".
                "proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
                "task": 18,  # image-multi-class-classification
                "config": {
                    "hub-model": config.input_model,
                    "max_models": 1,
                    "language": "unk",
                },
            },
            headers=AutoTrain._headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        project_resp.raise_for_status()
        return project_resp.json()

    @staticmethod
    def add_data(project_id: int):
        """Attach the ``train`` split of ``config.input_dataset`` to the project.

        Returns:
            The ``requests.Response`` (previously discarded; returned for
            consistency with ``start_processing``).
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
            json={
                "dataset_id": config.input_dataset,
                "dataset_split": "train",
                "split": 4,  # AutoTrain split code, passed through as-is
                # NOTE(review): create_project uses task 18 (image
                # classification), but this maps a "text" column. Confirm
                # the task id matches the dataset's modality.
                "col_mapping": {
                    "text": "contents",
                    "label": "key issue",
                },
            },
            headers=AutoTrain._headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp

    @staticmethod
    def start_processing(project_id: int):
        """Ask AutoTrain to start processing the project's data.

        Returns:
            The ``requests.Response`` from the start_processing endpoint.
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
            headers=AutoTrain._headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp
def notify_success(project_id: int):
    """Open a discussion on the input dataset announcing the new AutoTrain project.

    Args:
        project_id: Id of the AutoTrain project that was just created; it is
            interpolated into the UI link in the message.

    Returns:
        Whatever ``HfApi.create_discussion`` returns.
    """
    body = NOTIFICATION_TEMPLATE.format(
        ui_url=AUTOTRAIN_UI_URL,
        project_id=project_id,
        input_dataset=config.input_dataset,
        input_model=config.input_model,
    )
    hub = HfApi(token=HF_ACCESS_TOKEN)
    discussion_kwargs = dict(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return hub.create_discussion(**discussion_kwargs)
# Markdown body of the "Retraining started" discussion. Placeholders are filled
# by str.format with: input_dataset, input_model, ui_url, project_id.
# Paragraphs are separated by blank lines because Markdown joins lines that
# are only separated by a single newline into one paragraph.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.

(This is an automated message)
"""