sbrandeis HF staff commited on
Commit
55d6386
·
verified ·
1 Parent(s): 583a982

✨ Implement Auto-Retrain

Browse files
Files changed (9) hide show
  1. .gitignore +3 -0
  2. Dockerfile +16 -0
  3. README.md +2 -2
  4. config.json +6 -0
  5. home.html +16 -0
  6. requirements.txt +4 -0
  7. src/main.py +142 -0
  8. src/models.py +28 -0
  9. style.css +28 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ .vscode
3
+ __pycache__
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+
6
+ ENV HOME=/home/user \
7
+ PATH=/home/user/.local/bin:$PATH
8
+
9
+ WORKDIR $HOME/app
10
+
11
+ COPY --chown=user requirements.txt requirements.txt
12
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
13
+
14
+ COPY --chown=user . .
15
+
16
+ CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Actvie Learning Webhook
3
- emoji: 🏢
4
  colorFrom: yellow
5
  colorTo: red
6
  sdk: docker
 
1
  ---
2
+ title: Auto Re-Train
3
+ emoji:
4
  colorFrom: yellow
5
  colorTo: red
6
  sdk: docker
config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "target_namespace": "sbrandeis-test-org",
3
+ "input_dataset": "sbrandeis-test-org/input-dataset",
4
+ "input_model": "microsoft/resnet-50",
5
+ "autotrain_project_prefix": "auto-retrain-"
6
+ }
home.html ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <meta name="viewport" content="width=device-width" />
6
+ <title>Auto Re-Train</title>
7
+ <link rel="stylesheet" href="style.css" />
8
+ </head>
9
+ <body>
10
+ <div class="card">
11
+ <h1>Auto Re-Train webhook</h1>
12
+
13
+ <p>This is a webhook space to auto-retrain on model when a dataset changes.</p>
14
+ </div>
15
+ </body>
16
+ </html>
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi==0.74.*
2
+ requests==2.27.*
3
+ huggingface_hub==0.11.*
4
+ uvicorn[standard]==0.17.*
src/main.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from typing import Optional
4
+
5
+ from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
6
+ from fastapi.responses import FileResponse
7
+ from huggingface_hub.hf_api import HfApi
8
+
9
+ from .models import config, WebhookPayload
10
+
11
+ WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
12
+ HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
13
+ AUTOTRAIN_API_URL = "https://api.autotrain.huggingface.co"
14
+ AUTOTRAIN_UI_URL = "https://ui.autotrain.huggingface.co"
15
+
16
+
17
+ app = FastAPI()
18
+
19
+ @app.get("/")
20
+ async def home():
21
+ return FileResponse("home.html")
22
+
23
+ @app.post("/webhook")
24
+ async def post_webhook(
25
+ payload: WebhookPayload,
26
+ task_queue: BackgroundTasks,
27
+ x_webhook_secret: Optional[str] = Header(default=None),
28
+ ):
29
+ if x_webhook_secret is None:
30
+ raise HTTPException(401)
31
+ if x_webhook_secret != WEBHOOK_SECRET:
32
+ raise HTTPException(403)
33
+ if not (
34
+ payload.event.action == "update"
35
+ and payload.event.scope.startswith("repo.content")
36
+ and payload.repo.name == config.input_dataset
37
+ and payload.repo.type == "dataset"
38
+ ):
39
+ # no-op
40
+ return {"processed": False}
41
+
42
+ task_queue.add_task(
43
+ schedule_retrain,
44
+ payload
45
+ )
46
+
47
+ return {"processed": True}
48
+
49
+
50
+ def schedule_retrain(payload: WebhookPayload):
51
+ # Create the autotrain project
52
+ try:
53
+ project = AutoTrain.create_project(payload)
54
+ AutoTrain.add_data(project_id=project["id"])
55
+ AutoTrain.start_processing(project_id=project["id"])
56
+ except requests.HTTPError as err:
57
+ print("ERROR while requesting AutoTrain API:")
58
+ print(f" code: {err.response.status_code}")
59
+ print(f" {err.response.json()}")
60
+ raise
61
+ # Notify in the community tab
62
+ notify_success(project["id"])
63
+
64
+ return {"processed": True}
65
+
66
+
67
+ class AutoTrain:
68
+ @staticmethod
69
+ def create_project(payload: WebhookPayload) -> dict:
70
+ project_resp = requests.post(
71
+ f"{AUTOTRAIN_API_URL}/projects/create",
72
+ json={
73
+ "username": config.target_namespace,
74
+ "proj_name": f"{config.autotrain_project_prefix}-{payload.repo.headSha[:7]}",
75
+ "task": 18, # image-multi-class-classification
76
+ "config": {
77
+ "hub-model": config.input_model,
78
+ "max_models": 1,
79
+ "language": "unk",
80
+ }
81
+ },
82
+ headers={
83
+ "Authorization": f"Bearer {HF_ACCESS_TOKEN}"
84
+ }
85
+ )
86
+ project_resp.raise_for_status()
87
+ return project_resp.json()
88
+
89
+ @staticmethod
90
+ def add_data(project_id:int):
91
+ requests.post(
92
+ f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/dataset",
93
+ json={
94
+ "dataset_id": config.input_dataset,
95
+ "dataset_split": "train",
96
+ "split": 4,
97
+ "col_mapping": {
98
+ "image": "image",
99
+ "label": "target",
100
+ }
101
+ },
102
+ headers={
103
+ "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
104
+ }
105
+ ).raise_for_status()
106
+
107
+ @staticmethod
108
+ def start_processing(project_id: int):
109
+ resp = requests.post(
110
+ f"{AUTOTRAIN_API_URL}/projects/{project_id}/data/start_processing",
111
+ headers={
112
+ "Authorization": f"Bearer {HF_ACCESS_TOKEN}",
113
+ }
114
+ )
115
+ resp.raise_for_status()
116
+ return resp
117
+
118
+
119
+ def notify_success(project_id: int):
120
+ message = NOTIFICATION_TEMPLATE.format(
121
+ input_model=config.input_model,
122
+ input_dataset=config.input_dataset,
123
+ project_id=project_id,
124
+ ui_url=AUTOTRAIN_UI_URL,
125
+ )
126
+ return HfApi(token=HF_ACCESS_TOKEN).create_discussion(
127
+ repo_id=config.input_dataset,
128
+ repo_type="dataset",
129
+ title="✨ Retraining started!",
130
+ description=message,
131
+ token=HF_ACCESS_TOKEN,
132
+ )
133
+
134
+ NOTIFICATION_TEMPLATE = """\
135
+ 🌸 Hello there!
136
+
137
+ Following an update of [{input_dataset}](https://huggingface.co/datasets/{input_dataset}), an automatic re-training of [{input_model}](https://huggingface.co/{input_model}) has been scheduled on AutoTrain!
138
+
139
+ Please review and approve the project [here]({ui_url}/{project_id}/trainings) to start the training job.
140
+
141
+ (This is an automated message)
142
+ """
src/models.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pydantic import BaseModel
3
+ from typing import Literal
4
+
5
+ class Config(BaseModel):
6
+ target_namespace: str
7
+ input_dataset: str
8
+ input_model: str
9
+ autotrain_project_prefix: str
10
+
11
+
12
+ class WebhookPayloadEvent(BaseModel):
13
+ action: Literal["create", "update", "delete"]
14
+ scope: str
15
+
16
+ class WebhookPayloadRepo(BaseModel):
17
+ type: Literal["dataset", "model", "space"]
18
+ name: str
19
+ id: str
20
+ private: bool
21
+ headSha: str
22
+
23
+ class WebhookPayload(BaseModel):
24
+ event: WebhookPayloadEvent
25
+ repo: WebhookPayloadRepo
26
+
27
+
28
+ config = Config.parse_file(os.path.join(os.getcwd(), "config.json"))
style.css ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ padding: 2rem;
3
+ font-family: -apple-system, BlinkMacSystemFont, "Arial", sans-serif;
4
+ }
5
+
6
+ h1 {
7
+ font-size: 16px;
8
+ margin-top: 0;
9
+ }
10
+
11
+ p {
12
+ color: rgb(107, 114, 128);
13
+ font-size: 15px;
14
+ margin-bottom: 10px;
15
+ margin-top: 5px;
16
+ }
17
+
18
+ .card {
19
+ max-width: 620px;
20
+ margin: 0 auto;
21
+ padding: 16px;
22
+ border: 1px solid lightgray;
23
+ border-radius: 16px;
24
+ }
25
+
26
+ .card p:last-child {
27
+ margin-bottom: 0;
28
+ }