= commited on
Commit
1e9917c
1 Parent(s): cb5412b

requirements

Browse files
Files changed (2) hide show
  1. app.py +35 -8
  2. requirements.txt +70 -0
app.py CHANGED
@@ -7,26 +7,53 @@ from torch import nn
7
  from PIL import Image
8
  from model import create_effnet_v2_model
9
 
 
 
 
 
10
  class_names = ['Honda', 'Hyundai', 'Toyota']
11
 
12
  effnet_v2, transforms = create_effnet_v2_model(num_classes=len(class_names), weights_path="efficient_net_s_carvision_3.pth")
13
 
14
- def predict(model, image_path, device):
 
 
15
 
16
- image = Image.open(image_path)
17
  image = transforms(image).unsqueeze(0)
18
- image = image.to(device)
19
- output = model(image)
20
 
21
- model.eval()
22
  with torch.inference_mode():
23
  probs = torch.softmax(output, dim=1)
24
 
25
  pred_labels_and_probs = {class_names[i]: float(probs[0, i]) for i in range(len(class_names))}
26
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- return pred_labels_and_probs
 
29
 
30
- print(predict(effnet_v2, "examples/Toyota_Tacoma_2017_36_18_270_35_6_75_70_212_19_RWD_5_4_Pickup_xQa.jpg", torch.device("cpu")))
 
 
 
 
 
 
 
 
 
31
 
32
- # print(predict(effnet_v2, "test.jpg", torch.device("cuda:0")))
 
 
7
  from PIL import Image
8
  from model import create_effnet_v2_model
9
 
10
+ import gradio as gr
11
+ import os
12
+ from timeit import default_timer as timer
13
+
14
  class_names = ['Honda', 'Hyundai', 'Toyota']
15
 
16
  effnet_v2, transforms = create_effnet_v2_model(num_classes=len(class_names), weights_path="efficient_net_s_carvision_3.pth")
17
 
18
+ def predict(image):
19
+
20
+ start_time = timer()
21
 
22
+ # image = Image.open(image_path)
23
  image = transforms(image).unsqueeze(0)
24
+ # image = image.to(device)
25
+ output = effnet_v2(image)
26
 
27
+ effnet_v2.eval()
28
  with torch.inference_mode():
29
  probs = torch.softmax(output, dim=1)
30
 
31
  pred_labels_and_probs = {class_names[i]: float(probs[0, i]) for i in range(len(class_names))}
32
 
33
+ pred_time = round(timer() - start_time, 5)
34
+ return pred_labels_and_probs, pred_time
35
+
36
+
37
+ ### 4. Gradio app ###
38
+
39
+ # Create title, description and article strings
40
+ title = "CarVision 🚗🚘🚙🏎️"
41
+ description = "An EfficientNetv2 model to classify cars as Honda, Hyundai or Toyota"
42
+ article = "Created by Akshay Ballal"
43
 
44
+ # Create examples list from "examples/" directory
45
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
46
 
47
+ # Create the Gradio demo
48
+ demo = gr.Interface(fn=predict, # mapping function from input to output
49
+ inputs=gr.Image(type="pil"), # what are the inputs?
50
+ outputs=[gr.Label(num_top_classes=3, label="Predictions"), # what are the outputs?
51
+ gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
52
+ # Create examples list from "examples/" directory
53
+ examples=example_list,
54
+ title=title,
55
+ description=description,
56
+ article=article)
57
 
58
+ # Launch the demo!
59
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.1
5
+ anyio==3.7.0
6
+ async-timeout==4.0.2
7
+ attrs==23.1.0
8
+ certifi==2023.5.7
9
+ charset-normalizer==3.1.0
10
+ click==8.1.3
11
+ colorama==0.4.6
12
+ contourpy==1.1.0
13
+ cycler==0.11.0
14
+ exceptiongroup==1.1.2
15
+ fastapi==0.99.1
16
+ ffmpy==0.3.0
17
+ filelock==3.12.2
18
+ fonttools==4.40.0
19
+ frozenlist==1.3.3
20
+ fsspec==2023.6.0
21
+ gradio==3.35.2
22
+ gradio_client==0.2.7
23
+ h11==0.14.0
24
+ httpcore==0.17.2
25
+ httpx==0.24.1
26
+ huggingface-hub==0.15.1
27
+ idna==3.4
28
+ Jinja2==3.1.2
29
+ jsonschema==4.17.3
30
+ kiwisolver==1.4.4
31
+ linkify-it-py==2.0.2
32
+ markdown-it-py==2.2.0
33
+ MarkupSafe==2.1.3
34
+ matplotlib==3.7.1
35
+ mdit-py-plugins==0.3.3
36
+ mdurl==0.1.2
37
+ mpmath==1.3.0
38
+ multidict==6.0.4
39
+ networkx==3.1
40
+ numpy==1.25.0
41
+ orjson==3.9.1
42
+ packaging==23.1
43
+ pandas==2.0.3
44
+ Pillow==10.0.0
45
+ pydantic==1.10.10
46
+ pydub==0.25.1
47
+ Pygments==2.15.1
48
+ pyparsing==3.1.0
49
+ pyrsistent==0.19.3
50
+ python-dateutil==2.8.2
51
+ python-multipart==0.0.6
52
+ pytz==2023.3
53
+ PyYAML==6.0
54
+ requests==2.31.0
55
+ semantic-version==2.10.0
56
+ six==1.16.0
57
+ sniffio==1.3.0
58
+ starlette==0.27.0
59
+ sympy==1.12
60
+ toolz==0.12.0
61
+ torch==2.0.1
62
+ torchvision==0.15.2
63
+ tqdm==4.65.0
64
+ typing_extensions==4.7.1
65
+ tzdata==2023.3
66
+ uc-micro-py==1.0.2
67
+ urllib3==2.0.3
68
+ uvicorn==0.22.0
69
+ websockets==11.0.3
70
+ yarl==1.9.2