auto-retrain / src /main.py
sbrandeis's picture
sbrandeis HF staff
✨ Implement Auto-Retrain
55d6386 verified
raw
history blame
3.7 kB
import hmac
import os
from typing import Optional

import requests
from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse
from huggingface_hub.hf_api import HfApi

from .models import config, WebhookPayload
# Expected value of the ``X-Webhook-Secret`` header checked by ``/webhook``.
# None when the environment variable is unset, in which case every webhook
# request carrying the header is rejected.
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
# Bearer token used both for the AutoTrain API and for Hub API calls.
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN")

# Base URLs of the AutoTrain backend API and of its web UI.
AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"

app = FastAPI()
@app.get("/")
async def home():
    """Serve the static ``home.html`` landing page."""
    landing_page = "home.html"
    return FileResponse(landing_page)
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Receive a Hub webhook and queue a retrain when the input dataset changes.

    The request must carry an ``X-Webhook-Secret`` header equal to
    ``WEBHOOK_SECRET``. Only an ``update`` event whose scope starts with
    ``repo.content`` on the configured input dataset triggers a retrain;
    every other event is acknowledged as a no-op.

    Args:
        payload: Parsed webhook body.
        task_queue: FastAPI background-task queue; ``schedule_retrain`` is
            queued on it rather than run inline.
        x_webhook_secret: Value of the ``X-Webhook-Secret`` header, if sent.

    Returns:
        ``{"processed": True}`` when a retrain was queued, otherwise
        ``{"processed": False}``.

    Raises:
        HTTPException: 401 when the secret header is missing; 403 when it does
            not match, or when no ``WEBHOOK_SECRET`` is configured.
    """
    if x_webhook_secret is None:
        raise HTTPException(401)
    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. Bytes are compared because compare_digest rejects
    # non-ASCII str. An unset WEBHOOK_SECRET must still yield 403 (as the
    # plain `!=` did) rather than a TypeError.
    if WEBHOOK_SECRET is None or not hmac.compare_digest(
        x_webhook_secret.encode("utf-8"), WEBHOOK_SECRET.encode("utf-8")
    ):
        raise HTTPException(403)

    is_input_dataset_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    )
    if not is_input_dataset_update:
        # Not a content update of the watched dataset: nothing to do.
        return {"processed": False}

    task_queue.add_task(schedule_retrain, payload)
    return {"processed": True}
def schedule_retrain(payload: WebhookPayload):
    """Create and launch an AutoTrain project, then announce it on the Hub.

    This is queued as a background task by ``post_webhook``. It performs these
    steps in order:

    1. Create the project.
    2. Attach the input dataset.
    3. Start data processing.
    4. Open a discussion on the dataset repo linking to the project.

    Args:
        payload: Webhook payload that triggered the retrain. It is forwarded to
            ``AutoTrain.create_project``, which names the project from it.

    Returns:
        ``{"processed": True}`` once every step has succeeded.

    Raises:
        requests.HTTPError: if any AutoTrain API call fails. The status code and
            error body are printed before re-raising. Other ``requests``
            exceptions (e.g. connection errors) propagate unchanged.
    """
    try:
        project = AutoTrain.create_project(payload)
        AutoTrain.add_data(project_id=project["id"])
        AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        print("ERROR while requesting AutoTrain API:")
        print(f" code: {err.response.status_code}")
        # The error body is not guaranteed to be JSON (e.g. an HTML gateway
        # error page). Fall back to the raw text so that reporting the failure
        # cannot itself raise and mask the original HTTPError.
        try:
            details = err.response.json()
        except ValueError:
            details = err.response.text
        print(f" {details}")
        raise

    # Notify in the community tab of the input dataset.
    notify_success(project["id"])
    return {"processed": True}
class AutoTrain:
    """Minimal client for the AutoTrain HTTP API used by the retrain flow.

    Every request authenticates with ``HF_ACCESS_TOKEN`` as a bearer token and
    raises ``requests.HTTPError`` on an error status.
    """

    # Seconds to wait on the AutoTrain API before giving up. Without a timeout,
    # ``requests`` can block indefinitely on a stalled connection, pinning the
    # background worker that runs the retrain.
    REQUEST_TIMEOUT = 60

    @staticmethod
    def _auth_headers() -> dict:
        """Return the bearer-token header shared by every AutoTrain request."""
        return {"Authorization": f"Bearer {HF_ACCESS_TOKEN}"}

    @staticmethod
    def create_project(payload: WebhookPayload) -> dict:
        """Create an AutoTrain project for the dataset revision in *payload*.

        The project is named ``<autotrain_project_prefix>-<sha7>``, where
        ``sha7`` is the first 7 characters of ``payload.repo.headSha``.

        Args:
            payload: Webhook payload describing the updated dataset repo.

        Returns:
            The decoded JSON response; its ``"id"`` is the project id used by
            the other methods.

        Raises:
            requests.HTTPError: if the API rejects the request.
        """
        project_resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/create",
            json={
                "username": config.target_namespace,
                "proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
                "task": 18,  # image-multi-class-classification
                "config": {
                    "hub-model": config.input_model,
                    "max_models": 1,
                    "language": "unk",
                },
            },
            headers=AutoTrain._auth_headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        project_resp.raise_for_status()
        return project_resp.json()

    @staticmethod
    def add_data(project_id: int):
        """Attach the ``train`` split of the configured input dataset to a project.

        Args:
            project_id: Id of a project returned by :meth:`create_project`.

        Raises:
            requests.HTTPError: if the API rejects the request.
        """
        requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
            json={
                "dataset_id": config.input_dataset,
                "dataset_split": "train",
                # NOTE(review): AutoTrain-specific split code; its meaning is
                # not visible here — confirm against the AutoTrain API.
                "split": 4,
                # Presumably maps dataset column -> AutoTrain field; verify.
                "col_mapping": {
                    "image": "image",
                    "label": "target",
                },
            },
            headers=AutoTrain._auth_headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        ).raise_for_status()

    @staticmethod
    def start_processing(project_id: int):
        """Ask AutoTrain to start processing the data attached to a project.

        Args:
            project_id: Id of a project returned by :meth:`create_project`.

        Returns:
            The successful ``requests.Response``.

        Raises:
            requests.HTTPError: if the API rejects the request.
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
            headers=AutoTrain._auth_headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp
def notify_success(project_id: int):
    """Open a discussion on the input dataset announcing the AutoTrain project.

    Args:
        project_id: AutoTrain project id; it is interpolated into the link to
            the project's trainings page in the AutoTrain UI.

    Returns:
        The value returned by ``HfApi.create_discussion``.
    """
    body = NOTIFICATION_TEMPLATE.format(
        input_model=config.input_model,
        input_dataset=config.input_dataset,
        project_id=project_id,
        ui_url=AUTOTRAIN_UI_URL,
    )
    hub = HfApi(token=HF_ACCESS_TOKEN)
    discussion = hub.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return discussion
# Markdown body of the discussion opened by ``notify_success``; filled with
# ``str.format``, so any literal brace added here must be doubled.
# Placeholders:
#   {input_dataset}, {input_model}: Hub repo ids used to build the repo links.
#   {ui_url}, {project_id}: combined into the link to the project's trainings
#       page, where the training job must be approved.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!
Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.
(This is an automated message)
"""