์ •์ •๋ฏผ commited on
Commit
34871a0
ยท
1 Parent(s): 486c8da

Add : Application file and Req. file

Browse files
Files changed (2) hide show
  1. app.py +63 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ from transformers import AutoImageProcessor, ResNetForImageClassification
7
+
8
+ target_folder = "JungminChung/India_ResNet"
9
+
10
+ def load_model_and_preprocessor(target_folder):
11
+ model = ResNetForImageClassification.from_pretrained(target_folder)
12
+ image_processor = AutoImageProcessor.from_pretrained(target_folder)
13
+ return model, image_processor
14
+
15
+ def fetch_image(url):
16
+ headers = {
17
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36'
18
+ }
19
+ image_raw = requests.get(url, headers=headers, stream=True).raw
20
+ image = Image.open(image_raw)
21
+
22
+ return image
23
+
24
+ def infer_image(image, model, image_processor, k):
25
+ processed_img = image_processor(images=image.convert("RGB"), return_tensors="pt")
26
+
27
+ with torch.no_grad():
28
+ outputs = model(**processed_img)
29
+ logits = outputs.logits
30
+
31
+ prob = torch.nn.functional.softmax(logits, dim=-1)
32
+ topk_prob, topk_indices = torch.topk(prob, k=k)
33
+
34
+ res = ""
35
+ for idx, (prob, index) in enumerate(zip(topk_prob[0], topk_indices[0])):
36
+ res += f"{idx+1}. {model.config.id2label[index.item()]:<15} ({prob.item()*100:.2f} %) \n"
37
+ return res
38
+
39
+ def infer(url, k, target_folder=target_folder):
40
+ try :
41
+ image = fetch_image(url)
42
+ model, image_processor = load_model_and_preprocessor(target_folder)
43
+ res = infer_image(image, model, image_processor, k)
44
+ except :
45
+ image = Image.new('RGB', (224, 224))
46
+ res = "์ด๋ฏธ์ง€๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋Š”๋ฐ ๋ฌธ์ œ๊ฐ€ ์žˆ๋‚˜๋ด์š”. ๋‹ค๋ฅธ ์ด๋ฏธ์ง€ url๋กœ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
47
+ return image, res
48
+
49
+ demo = gr.Interface(
50
+ fn=infer,
51
+ inputs=[
52
+ gr.Textbox(value="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRpE-UHBp8ZufNUd3BKw8gtIxSe3IUwspOfqw&s",
53
+ label="Image URL"),
54
+ gr.Slider(minimum=0, maximum=20, step=1, value=3, label="์ƒ์œ„ ๋ช‡๊ฐœ๊นŒ์ง€ ๋ณด์—ฌ์ค„๊นŒ์š”?")
55
+ ],
56
+ outputs=[
57
+ gr.Image(type="pil", label="์ž…๋ ฅ ์ด๋ฏธ์ง€"),
58
+ gr.Textbox(label="์ข…๋ฅ˜ (ํ™•๋ฅ )")
59
+ ],
60
+ )
61
+
62
+ demo.launch()
63
+ # demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ Pillow