Kr1n3 committed on
Commit
87bb394
1 Parent(s): 77f5256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -30
app.py CHANGED
@@ -1,36 +1,28 @@
1
- import torch
2
  import gradio as gr
3
- from huggingface_hub import hf_hub_download
4
- from PIL import Image
5
-
6
- REPO_ID = "https://huggingface.co/spaces/Kr1n3/Fashion-Items-Classification"
7
- FILENAME = "best.pt"
8
-
9
- yolov5_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
10
 
11
- model = torch.hub.load('ultralytics/yolov5', 'custom', path=yolov5_weights, force_reload=True) # local repo
 
 
 
 
 
12
 
13
- def object_detection(im, size=640):
14
- results = model(im) # inference
15
- #results.print() # print results to screen
16
- #results.show() # display results
17
- #results.save() # save as results1.jpg, results2.jpg... etc.
18
- results.render() # updates results.imgs with boxes and labels
19
- return Image.fromarray(results.imgs[0])
20
 
21
- title = "Fashion Items Classification"
22
- description = """Esse modelo é uma pequena demonstração baseada em uma análise de cerca de 60 imagens somente. Para resultados mais confiáveis e genéricos, são necessários mais exemplos (imagens).
23
- """
24
 
25
- image = gr.inputs.Image(shape=(640, 640), image_mode="RGB", source="upload", label="Imagem", optional=False)
26
- outputs = gr.outputs.Image(type="pil", label="Output Image")
 
 
 
 
 
 
27
 
28
- gr.Interface(
29
- fn=object_classification,
30
- inputs=image,
31
- outputs=outputs,
32
- title=title,
33
- description=description,
34
- examples=[["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_01.jpg"], ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_18.JPG?raw=true"],
35
- ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_30.jpeg?raw=true"], ["https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_33.jpg?raw=true"]],
36
- ).launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ import torchvision
6
+ from torchvision import transforms
 
 
7
 
8
# Load the TorchScript classifier exported as best.pt.
# map_location='cpu' makes the load succeed on CPU-only hosts even if the
# model was traced/saved on a GPU; eval() switches dropout/batch-norm
# layers to inference behavior so predictions are deterministic.
model = torch.jit.load('best.pt', map_location='cpu')
model.eval()

# Preprocessing pipeline applied to every input image:
# resize to the 640x640 training resolution, convert PIL -> float tensor
# in [0, 1], then normalize with ImageNet channel statistics
# (presumably the statistics used during training — TODO confirm).
data_transform1 = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
14
 
15
# Heading shown above the demo (leading space reproduced from the original).
title = " Fashion Items Classification"

# Sample images (raw GitHub URLs) offered as one-click examples in the UI;
# each example is a one-element list matching the single image input.
examples = [
    ['https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_33.jpg?raw=true'],
    ['https://github.com/Kr1n3/MPC_2022/blob/main/dataset/pants_30.jpeg?raw=true'],
    ['https://github.com/Kr1n3/MPC_2022/blob/main/dataset/bag_01.jpg?raw=true'],
]

# Class names in the index order the model's output logits use
# (presumably the training label order — TODO confirm).
classes = ['Bags', 'Dress', 'Pants', 'Shoes', 'Skirt']
20
def predict(img):
    """Classify a PIL image into one of the fashion item classes.

    Args:
        img: PIL image supplied by the Gradio image input (type='pil').

    Returns:
        dict mapping class name -> softmax confidence, the format the
        Gradio 'label' output component expects.
    """
    tensor = data_transform1(img)
    batch = tensor.unsqueeze(0)  # add batch dimension: (1, C, H, W)
    # Inference only — disable autograd bookkeeping to save time/memory.
    with torch.no_grad():
        outputs = model(batch)
    probs = F.softmax(outputs[0], dim=0).cpu().numpy()
    # Pair every class with its probability instead of hard-coding 5,
    # so the mapping stays correct if the class list changes.
    return {name: float(p) for name, p in zip(classes, probs)}
27
 
28
# Build the Gradio demo (PIL image in, class-confidence label out) and
# start the server; debug=True surfaces tracebacks in the Space logs.
demo = gr.Interface(
    predict,
    gr.inputs.Image(type='pil'),
    title=title,
    examples=examples,
    outputs='label',
)
demo.launch(debug=True)