Spaces:
Runtime error
Runtime error
Commit
·
a776cb5
1
Parent(s):
de33a84
raise error if model id not valid
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import os
|
|
11 |
import backoff
|
12 |
from functools import lru_cache
|
13 |
from huggingface_hub import list_models, ModelFilter
|
|
|
14 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
15 |
|
16 |
|
@@ -64,17 +65,23 @@ def return_random_sample(k=27):
|
|
64 |
images = dataset[sample]["image"]
|
65 |
return [resize_image(image).convert("RGB") for image in images]
|
66 |
|
|
|
67 |
@lru_cache()
|
68 |
def get_valid_hub_image_classification_model_ids():
|
69 |
models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
|
70 |
return {model.id for model in models}
|
71 |
|
|
|
72 |
def predict_subset(model_id, token):
|
73 |
-
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
74 |
-
headers = {"Authorization": f"Bearer {token}"}
|
75 |
valid_model_ids = get_valid_hub_image_classification_model_ids()
|
76 |
if model_id not in valid_model_ids:
|
77 |
-
gr.Error(
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
79 |
def _query(url):
|
80 |
r = requests.post(API_URL, headers=headers, data=url)
|
|
|
11 |
import backoff
|
12 |
from functools import lru_cache
|
13 |
from huggingface_hub import list_models, ModelFilter
|
14 |
+
|
15 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
16 |
|
17 |
|
|
|
65 |
images = dataset[sample]["image"]
|
66 |
return [resize_image(image).convert("RGB") for image in images]
|
67 |
|
68 |
+
|
69 |
@lru_cache()
|
70 |
def get_valid_hub_image_classification_model_ids():
|
71 |
models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
|
72 |
return {model.id for model in models}
|
73 |
|
74 |
+
|
75 |
def predict_subset(model_id, token):
|
|
|
|
|
76 |
valid_model_ids = get_valid_hub_image_classification_model_ids()
|
77 |
if model_id not in valid_model_ids:
|
78 |
+
raise gr.Error(
|
79 |
+
f"model_id {model_id} is not a valid image classification model id"
|
80 |
+
)
|
81 |
+
|
82 |
+
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
83 |
+
headers = {"Authorization": f"Bearer {token}"}
|
84 |
+
|
85 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
86 |
def _query(url):
|
87 |
r = requests.post(API_URL, headers=headers, data=url)
|