NimaKL committed on
Commit
35a1c4a
1 Parent(s): 6d90163

Upload 2 files

Files changed (2)
  1. app-st.py +192 -0
  2. requirements.txt +2 -1
app-st.py ADDED
@@ -0,0 +1,192 @@
+ import streamlit as st
+ import torch
+ import numpy as np
+ from transformers import AutoTokenizer
+ from transformers import BertForSequenceClassification
+
+
+ st.set_page_config(layout='wide', initial_sidebar_state='expanded')
+ col1, col2 = st.columns(2)
+
+ with col1:
+     st.title("FireWatch")
+     st.markdown("PREDICT WHETHER HEAT SIGNATURES AROUND THE GLOBE ARE LIKELY TO BE FIRES!")
+     st.markdown("Training Code at:")
+     st.markdown("https://colab.research.google.com/drive/1-IfOMJ-X8MKzwm3UjbJbK6RmhT7tk_ye?usp=sharing")
+     st.markdown("Try the Model Yourself at:")
+     st.markdown("https://colab.research.google.com/drive/1GmweeQrkzs0OXQ_KNZsWd1PQVRLCWDKi?usp=sharing")
+
+     st.markdown("## Sample Table")
+
+     table_html = """
+     <table style="border-collapse: collapse; width: 100%;">
+         <tr style="border: 1px solid orange;">
+             <th style="border: 1px solid orange; font-weight: bold;">Category</th>
+             <th style="border: 1px solid orange; font-weight: bold;">Latitude, Longitude, Brightness, FRP</th>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Likely</td>
+             <td style="border: 1px solid orange;">-26.76123, 147.15512, 393.02, 203.63</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Likely</td>
+             <td style="border: 1px solid orange;">-26.7598, 147.14514, 361.54, 79.4</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Unlikely</td>
+             <td style="border: 1px solid orange;">-25.70059, 149.48932, 313.9, 5.15</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Unlikely</td>
+             <td style="border: 1px solid orange;">-24.4318, 151.83102, 307.98, 8.79</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Unlikely</td>
+             <td style="border: 1px solid orange;">-23.21878, 148.91298, 314.08, 7.4</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Likely</td>
+             <td style="border: 1px solid orange;">7.87518, 19.9241, 316.32, 39.63</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Unlikely</td>
+             <td style="border: 1px solid orange;">-20.10942, 148.14326, 314.39, 8.8</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Unlikely</td>
+             <td style="border: 1px solid orange;">7.87772, 19.9048, 304.14, 13.43</td>
+         </tr>
+         <tr style="border: 1px solid orange;">
+             <td style="border: 1px solid orange;">Likely</td>
+             <td style="border: 1px solid orange;">-20.79866, 124.46834, 366.74, 89.06</td>
+         </tr>
+     </table>
+     """
+
+     st.markdown(table_html, unsafe_allow_html=True)
+     tree = """
+     <div class="pine-tree" style="width: 50%; margin: 0 auto;">
+         <div class="tree-top"></div>
+         <div class="tree-top2"></div>
+         <div class="tree-bottom">
+             <div class="trunk"></div>
+         </div>
+     </div>
+     <style>
+     .pine-tree {
+         width: 15vw;
+         height: 20vw;
+         position: relative;
+         display: flex;
+         justify-content: center;
+         align-items: center;
+     }
+     .tree-top {
+         width: 0;
+         height: 0;
+         border-left: 8vw solid transparent;
+         border-right: 8vw solid transparent;
+         border-bottom: 13vw solid green;
+         position: absolute;
+         top: 0;
+         left: 0;
+         right: 0;
+         margin: auto;
+     }
+     .tree-top2 {
+         width: 0;
+         height: 0;
+         border-left: 8vw solid transparent;
+         border-right: 8vw solid transparent;
+         border-bottom: 13vw solid green;
+         position: absolute;
+         top: 3vw;
+         left: 0;
+         right: 0;
+         margin: auto;
+     }
+     .tree-bottom {
+         width: 8vw;
+         height: 10vw;
+         background-color: brown;
+         position: absolute;
+         bottom: 0;
+         left: 0;
+         right: 0;
+         top: 21vw;
+         margin: auto;
+     }
+     .trunk {
+         width: 3vw;
+         height: 10vw;
+         background-color: brown;
+         position: absolute;
+         bottom: 0;
+         left: 0;
+         right: 0;
+         margin: auto;
+     }
+     </style>
+     """
+
+
+ with col2:
+     @st.cache(suppress_st_warning=True, allow_output_mutation=True)
+     def load_model(show_spinner=True):
+         MODEL_PATH = "NimaKL/FireWatch_tiny_75k"
+         model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
+         return model
+
+
+
+     token_id = []
+     attention_masks = []
+     def preprocessing(input_text, tokenizer):
+         '''
+         Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
+         - input_ids: list of token ids
+         - token_type_ids: list of token type ids
+         - attention_mask: list of indices (0, 1) specifying which tokens should be considered by the model (return_attention_mask = True).
+         '''
+         return tokenizer.encode_plus(
+             input_text,
+             add_special_tokens=True,
+             max_length=16,
+             pad_to_max_length=True,
+             return_attention_mask=True,
+             return_tensors='pt'
+         )
+
+     def predict(new_sentence):
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # We need Token IDs and Attention Mask for inference on the new sentence
+         test_ids = []
+         test_attention_mask = []
+         # Apply the tokenizer
+         encoding = preprocessing(new_sentence, tokenizer)
+         # Extract IDs and Attention Mask
+         test_ids.append(encoding['input_ids'])
+         test_attention_mask.append(encoding['attention_mask'])
+         test_ids = torch.cat(test_ids, dim=0)
+         test_attention_mask = torch.cat(test_attention_mask, dim=0)
+         # Forward pass, calculate logit predictions
+         with torch.no_grad():
+             output = model(test_ids.to(device), token_type_ids=None, attention_mask=test_attention_mask.to(device))
+         prediction = 'Likely' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Unlikely'
+         pred = 'Predicted Class: ' + prediction
+         return pred
+
+ model = load_model()
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ with col2:
+     st.markdown('## Enter Prediction Data in Correct Format "Latitude, Longitude, Brightness, FRP"')
+     text = st.text_input('Prediction Data: ', 'Example: 8.81064, -65.07661, 328.04, 18.76')
+     aButton = st.button('Predict')
+
+ if text or aButton:
+     with st.spinner('Wait for it...'):
+         st.success(predict(text))
+         st.markdown(tree, unsafe_allow_html=True)
+
+
+
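For a quick sanity check outside the Streamlit UI, the same checkpoint can be queried directly with transformers. The sketch below is not part of this commit; it mirrors the app's preprocessing (bert-base-uncased tokenizer, max_length of 16) and its label convention (class index 1 means "Likely"), with the sample row copied from the table in app-st.py. It uses the current padding/truncation arguments instead of the deprecated pad_to_max_length call used in the app.

# Standalone sketch (not part of this commit): query NimaKL/FireWatch_tiny_75k directly.
import torch
from transformers import AutoTokenizer, BertForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("NimaKL/FireWatch_tiny_75k")
model.eval()

# "latitude, longitude, brightness, FRP" -- sample row from the app's table.
row = "-26.76123, 147.15512, 393.02, 203.63"
enc = tokenizer(row, max_length=16, padding="max_length", truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits
# The app treats class index 1 as "Likely".
print("Likely" if logits.argmax(dim=-1).item() == 1 else "Unlikely")

If the checkpoint's label order differs from this assumed mapping, the "Likely"/"Unlikely" branch would need to be flipped.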
requirements.txt CHANGED
@@ -1,2 +1,3 @@
  transformers
- torch
+ torch
+ streamlit