user-agent commited on
Commit
9b67e53
·
verified ·
1 Parent(s): d4a7da4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -10,15 +10,12 @@ import os
10
 
11
  token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
12
 
13
-
14
-
15
  model_id = "thelabel/240903-image-tagging"
16
- config = AutoConfig.from_pretrained(model_id,token=token)
17
- model = AutoModelForImageClassification.from_pretrained(model_id,token=token)
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
20
 
21
- # Standard ViT image transforms
22
  image_transform = transforms.Compose([
23
  transforms.Resize((224, 224)),
24
  transforms.ToTensor(),
@@ -30,7 +27,7 @@ def load_image_from_url(url):
30
  response = requests.get(url, timeout=10)
31
  response.raise_for_status()
32
  return Image.open(BytesIO(response.content)).convert("RGB")
33
- except Exception as e:
34
  return None
35
 
36
  @spaces.GPU
@@ -55,9 +52,8 @@ def predict_tags(image_url, threshold=0.5):
55
  def gradio_predict(url, threshold):
56
  tags, error = predict_tags(url, threshold)
57
  if error:
58
- return error, None
59
- image = load_image_from_url(url)
60
- return "\n".join([f"{tag}: {score:.2f}" for tag, score in tags]), image
61
 
62
  demo = gr.Interface(
63
  fn=gradio_predict,
@@ -65,12 +61,10 @@ demo = gr.Interface(
65
  gr.Textbox(label="Image URL"),
66
  gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
67
  ],
68
- outputs=[
69
- gr.Textbox(label="Tags"),
70
- gr.Image(label="Preview", type="pil"),
71
- ],
72
  title="Image Tagging with ViT",
73
  description="Paste an image URL and get predicted tags using thelabel/240903-image-tagging model.",
74
  )
 
75
  if __name__ == "__main__":
76
  demo.launch()
 
10
 
11
  token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
12
 
 
 
13
  model_id = "thelabel/240903-image-tagging"
14
+ config = AutoConfig.from_pretrained(model_id, token=token)
15
+ model = AutoModelForImageClassification.from_pretrained(model_id, token=token)
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model.to(device)
18
 
 
19
  image_transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
 
27
  response = requests.get(url, timeout=10)
28
  response.raise_for_status()
29
  return Image.open(BytesIO(response.content)).convert("RGB")
30
+ except Exception:
31
  return None
32
 
33
  @spaces.GPU
 
52
  def gradio_predict(url, threshold):
53
  tags, error = predict_tags(url, threshold)
54
  if error:
55
+ return error
56
+ return "\n".join([f"{tag}: {score:.2f}" for tag, score in tags])
 
57
 
58
  demo = gr.Interface(
59
  fn=gradio_predict,
 
61
  gr.Textbox(label="Image URL"),
62
  gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
63
  ],
64
+ outputs=gr.Textbox(label="Tags"),
 
 
 
65
  title="Image Tagging with ViT",
66
  description="Paste an image URL and get predicted tags using thelabel/240903-image-tagging model.",
67
  )
68
+
69
  if __name__ == "__main__":
70
  demo.launch()