Spaces:
Runtime error
Runtime error
✨ Implement Auto-Retrain
Browse files- .gitignore +3 -0
- Dockerfile +16 -0
- README.md +2 -2
- config.json +6 -0
- home.html +16 -0
- requirements.txt +4 -0
- src/main.py +142 -0
- src/models.py +28 -0
- style.css +28 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.venv
|
2 |
+
.vscode
|
3 |
+
__pycache__
|
Dockerfile
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.9
|
2 |
+
|
3 |
+
RUN useradd -m -u 1000 user
|
4 |
+
USER user
|
5 |
+
|
6 |
+
ENV HOME=/home/user \
|
7 |
+
PATH=/home/user/.local/bin:$PATH
|
8 |
+
|
9 |
+
WORKDIR $HOME/app
|
10 |
+
|
11 |
+
COPY --chown=user requirements.txt requirements.txt
|
12 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
13 |
+
|
14 |
+
COPY --chown=user . .
|
15 |
+
|
16 |
+
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: yellow
|
5 |
colorTo: red
|
6 |
sdk: docker
|
|
|
1 |
---
|
2 |
+
title: Auto Re-Train
|
3 |
+
emoji: ♻
|
4 |
colorFrom: yellow
|
5 |
colorTo: red
|
6 |
sdk: docker
|
config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"target_namespace": "sbrandeis-test-org",
|
3 |
+
"input_dataset": "sbrandeis-test-org/input-dataset",
|
4 |
+
"input_model": "microsoft/resnet-50",
|
5 |
+
"autotrain_project_prefix": "auto-retrain-"
|
6 |
+
}
|
home.html
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html>
|
3 |
+
<head>
|
4 |
+
<meta charset="utf-8" />
|
5 |
+
<meta name="viewport" content="width=device-width" />
|
6 |
+
<title>Auto Re-Train</title>
|
7 |
+
<link rel="stylesheet" href="style.css" />
|
8 |
+
</head>
|
9 |
+
<body>
|
10 |
+
<div class="card">
|
11 |
+
<h1>Auto Re-Train webhook</h1>
|
12 |
+
|
13 |
+
<p>This is a webhook space to automatically retrain a model when a dataset changes.</p>
|
14 |
+
</div>
|
15 |
+
</body>
|
16 |
+
</html>
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi==0.74.*
|
2 |
+
requests==2.27.*
|
3 |
+
huggingface_hub==0.11.*
|
4 |
+
uvicorn[standard]==0.17.*
|
src/main.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
|
6 |
+
from fastapi.responses import FileResponse
|
7 |
+
from huggingface_hub.hf_api import HfApi
|
8 |
+
|
9 |
+
from .models import config, WebhookPayload
|
10 |
+
|
11 |
+
# Shared secret that webhook callers must present in the X-Webhook-Secret header.
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET")
# Token sent as a bearer credential to AutoTrain and used for Hub API calls.
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN")

# Base URLs of the AutoTrain REST API and of its web UI.
AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"


# ASGI application; the Dockerfile starts it as "src.main:app".
app = FastAPI()
|
18 |
+
|
19 |
+
@app.get("/")
async def home():
    """Serve the static landing page.

    The path is relative, so it is resolved against the working directory
    of the server process.
    """
    return FileResponse("home.html")


@app.get("/style.css")
async def stylesheet():
    """Serve the stylesheet referenced by home.html.

    home.html links to "style.css"; without this route the browser's
    request for it gets a 404 and the landing page renders unstyled.
    """
    return FileResponse("style.css")
|
22 |
+
|
23 |
+
@app.post("/webhook")
async def post_webhook(
    payload: WebhookPayload,
    task_queue: BackgroundTasks,
    x_webhook_secret: Optional[str] = Header(default=None),
):
    """Authenticate a webhook call and queue a retrain when it is relevant.

    A retrain is queued only for an "update" event whose scope starts with
    "repo.content" and which targets the configured input dataset.

    Returns:
        {"processed": True} when a retrain was queued, {"processed": False}
        when the event was ignored.

    Raises:
        HTTPException: 401 when the secret header is missing, 403 when it
            does not match WEBHOOK_SECRET (or WEBHOOK_SECRET is not set).
    """
    # Local import: only this handler needs hmac.
    import hmac

    if x_webhook_secret is None:
        raise HTTPException(401)

    # Compare in constant time so response timing does not leak how much of
    # the secret a caller guessed. Encoding to bytes lets compare_digest
    # accept non-ASCII header values. An unset WEBHOOK_SECRET rejects every
    # caller, exactly as the plain inequality check did.
    if WEBHOOK_SECRET is None or not hmac.compare_digest(
        x_webhook_secret.encode("utf-8"),
        WEBHOOK_SECRET.encode("utf-8"),
    ):
        raise HTTPException(403)

    is_input_dataset_update = (
        payload.event.action == "update"
        and payload.event.scope.startswith("repo.content")
        and payload.repo.name == config.input_dataset
        and payload.repo.type == "dataset"
    )
    if not is_input_dataset_update:
        # no-op
        return {"processed": False}

    # Runs after the response is sent, so the webhook caller is not kept
    # waiting on the AutoTrain API.
    task_queue.add_task(schedule_retrain, payload)

    return {"processed": True}
|
48 |
+
|
49 |
+
|
50 |
+
def schedule_retrain(payload: WebhookPayload):
    """Create an AutoTrain project for the updated dataset and announce it.

    Creates the project, attaches the input dataset, starts data
    processing, then opens a discussion on the dataset repository.

    Returns:
        {"processed": True} once every step has succeeded.

    Raises:
        requests.HTTPError: re-raised after logging when an AutoTrain call
            returns an error status.
    """
    # Create the autotrain project
    try:
        project = AutoTrain.create_project(payload)
        AutoTrain.add_data(project_id=project["id"])
        AutoTrain.start_processing(project_id=project["id"])
    except requests.HTTPError as err:
        print("ERROR while requesting AutoTrain API:")
        response = err.response
        # Test against None explicitly: a requests.Response is falsy
        # whenever its status code is an error.
        if response is None:
            print(f" {err}")
        else:
            print(f" code: {response.status_code}")
            try:
                details = response.json()
            except ValueError:
                # The error body is not JSON (e.g. an HTML error page); log
                # the raw text rather than masking the original HTTPError
                # with a decoding error.
                details = response.text
            print(f" {details}")
        raise
    # Notify in the community tab
    notify_success(project["id"])

    return {"processed": True}
|
65 |
+
|
66 |
+
|
67 |
+
class AutoTrain:
    """Static helpers wrapping the AutoTrain HTTP API calls made by this app.

    Every call authenticates with HF_ACCESS_TOKEN as a bearer token and
    raises requests.HTTPError when the API answers with an error status.
    """

    # Upper bound, in seconds, on each HTTP call. Without a timeout,
    # requests waits indefinitely on an unresponsive server.
    REQUEST_TIMEOUT = 60

    @staticmethod
    def _auth_headers() -> dict:
        """Return the authorization header sent with every request."""
        return {"Authorization": f"Bearer {HF_ACCESS_TOKEN}"}

    @staticmethod
    def create_project(payload: WebhookPayload) -> dict:
        """Create an AutoTrain project named after the dataset's head commit.

        Returns:
            The decoded JSON body of the API response.
        """
        # The separator is added below, so drop any dash already ending the
        # configured prefix ("auto-retrain-" would otherwise produce
        # "auto-retrain--<sha>").
        prefix = config.autotrain_project_prefix.rstrip("-")
        short_sha = payload.repo.headSha[:7]
        project_resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/create",
            json={
                "username": config.target_namespace,
                "proj_name": f"{prefix}-{short_sha}",
                "task": 18,  # image-multi-class-classification
                "config": {
                    "hub-model": config.input_model,
                    "max_models": 1,
                    "language": "unk",
                },
            },
            headers=AutoTrain._auth_headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        project_resp.raise_for_status()
        return project_resp.json()

    @staticmethod
    def add_data(project_id: int):
        """Attach the configured input dataset to the given project."""
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
            json={
                "dataset_id": config.input_dataset,
                "dataset_split": "train",
                # NOTE(review): the meaning of split=4 is defined by the
                # AutoTrain API - confirm against its documentation.
                "split": 4,
                "col_mapping": {
                    "image": "image",
                    "label": "target",
                },
            },
            headers=AutoTrain._auth_headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()

    @staticmethod
    def start_processing(project_id: int):
        """Ask AutoTrain to start processing the project's data.

        Returns:
            The requests.Response of the call.
        """
        resp = requests.post(
            f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
            headers=AutoTrain._auth_headers(),
            timeout=AutoTrain.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp
|
117 |
+
|
118 |
+
|
119 |
+
def notify_success(project_id: int):
    """Open a discussion on the input dataset announcing the new project.

    Returns whatever HfApi.create_discussion returns.
    """
    template_values = {
        "input_model": config.input_model,
        "input_dataset": config.input_dataset,
        "project_id": project_id,
        "ui_url": AUTOTRAIN_UI_URL,
    }
    body = NOTIFICATION_TEMPLATE.format(**template_values)

    hub = HfApi(token=HF_ACCESS_TOKEN)
    discussion = hub.create_discussion(
        repo_id=config.input_dataset,
        repo_type="dataset",
        title="✨ Retraining started!",
        description=body,
        token=HF_ACCESS_TOKEN,
    )
    return discussion


# Markdown body of the discussion opened by notify_success. Placeholders:
# input_dataset, input_model, ui_url, project_id.
NOTIFICATION_TEMPLATE = """\
🌸 Hello there!

Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!

Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.

(This is an automated message)
"""
|
src/models.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
class Config(BaseModel):
    # Settings loaded from config.json.

    # Namespace sent as "username" when creating the AutoTrain project.
    target_namespace: str
    # Dataset repo whose content updates trigger a retrain.
    input_dataset: str
    # Model passed to AutoTrain as "hub-model".
    input_model: str
    # Leading part of the generated AutoTrain project name.
    autotrain_project_prefix: str
|
10 |
+
|
11 |
+
|
12 |
+
# Actions accepted in the "event" part of a webhook payload.
_EventAction = Literal["create", "update", "delete"]


class WebhookPayloadEvent(BaseModel):
    # The "event" object of an incoming webhook payload.
    action: _EventAction
    # The webhook handler matches this by prefix (e.g. "repo.content").
    scope: str
|
15 |
+
|
16 |
+
# Repository kinds accepted in the "repo" part of a webhook payload.
_RepoType = Literal["dataset", "model", "space"]


class WebhookPayloadRepo(BaseModel):
    # The "repo" object of an incoming webhook payload.
    type: _RepoType
    # Compared against the configured input dataset by the webhook handler.
    name: str
    id: str
    private: bool
    # Commit hash; its first 7 characters go into the AutoTrain project name.
    headSha: str
|
22 |
+
|
23 |
+
class WebhookPayload(BaseModel):
    # Request body accepted by the /webhook endpoint.

    # What happened.
    event: WebhookPayloadEvent
    # The repository it happened to.
    repo: WebhookPayloadRepo
|
26 |
+
|
27 |
+
|
28 |
+
# Parsed once at import time; config.json is looked up in the working
# directory of the process.
config = Config.parse_file(os.path.abspath("config.json"))
|
style.css
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
body {
|
2 |
+
padding: 2rem;
|
3 |
+
font-family: -apple-system, BlinkMacSystemFont, "Arial", sans-serif;
|
4 |
+
}
|
5 |
+
|
6 |
+
h1 {
|
7 |
+
font-size: 16px;
|
8 |
+
margin-top: 0;
|
9 |
+
}
|
10 |
+
|
11 |
+
p {
|
12 |
+
color: rgb(107, 114, 128);
|
13 |
+
font-size: 15px;
|
14 |
+
margin-bottom: 10px;
|
15 |
+
margin-top: 5px;
|
16 |
+
}
|
17 |
+
|
18 |
+
.card {
|
19 |
+
max-width: 620px;
|
20 |
+
margin: 0 auto;
|
21 |
+
padding: 16px;
|
22 |
+
border: 1px solid lightgray;
|
23 |
+
border-radius: 16px;
|
24 |
+
}
|
25 |
+
|
26 |
+
.card p:last-child {
|
27 |
+
margin-bottom: 0;
|
28 |
+
}
|