Mehmet Batuhan Duman commited on
Commit
b53b6d3
·
1 Parent(s): 4175bb8

Add Gradio app and requirements

Browse files
Files changed (2) hide show
  1. app.py +49 -4
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,52 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
  import gradio as gr
4
+ from PIL import Image, ImageOps
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchvision import transforms
10
+ import os
11
 
12
+ # Add your model classes (Net and Net2) here.
 
13
 
14
+ # Loading model
15
+ model = None
16
+ model2 = None
17
+ model2_path = "model4.pth"
18
+
19
+ if os.path.exists(model2_path):
20
+ state_dict = torch.load(model2_path, map_location=torch.device('cpu'))
21
+ new_state_dict = {}
22
+ for key, value in state_dict.items():
23
+ new_key = key.replace("module.", "")
24
+ new_state_dict[new_key] = value
25
+
26
+ model = Net2()
27
+ model.load_state_dict(new_state_dict)
28
+ model.eval()
29
+
30
+ else:
31
+ print("Model file not found at", model2_path)
32
+
33
+ # Add the scanmap function here.
34
+
35
+ def process_image(image: Image.Image):
36
+ image_np = np.array(image)
37
+ start_time = time.time()
38
+ heatmap = scanmap(image_np, model)
39
+ elapsed_time = time.time() - start_time
40
+ heatmap_img = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255)).convert('RGB')
41
+ heatmap_img = heatmap_img.resize(image.size)
42
+
43
+ return heatmap_img, elapsed_time
44
+
45
+ inputs = gr.inputs.Image(label="Upload Image")
46
+ outputs = [
47
+ gr.outputs.Image(label="Heatmap"),
48
+ gr.outputs.Textbox(label="Elapsed Time (seconds)")
49
+ ]
50
+
51
+ iface = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="ShipNet Heatmap")
52
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ numpy
5
+ Pillow
6
+ matplotlib