koptelovmax committed on
Commit
caacd68
1 Parent(s): 3f4e0cd

Add application file

Browse files
Files changed (2) hide show
  1. app.py +68 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import CamembertTokenizer, CamembertForSequenceClassification
3
+ import torch
4
+ import numpy as np
5
+
6
@st.cache_resource
def load_tokenizer():
    """Return the CamemBERT tokenizer built from the "camembert-base" checkpoint.

    The function is wrapped in ``st.cache_resource``.
    """
    checkpoint = "camembert-base"
    camembert_tokenizer = CamembertTokenizer.from_pretrained(checkpoint)
    return camembert_tokenizer
9
+
10
@st.cache_resource
def load_model():
    """Return the sequence-classification model from "herelles/camembert-base-lupan".

    The function is wrapped in ``st.cache_resource``.
    """
    checkpoint = "herelles/camembert-base-lupan"
    classifier = CamembertForSequenceClassification.from_pretrained(checkpoint)
    return classifier
13
+
14
# Module-level tokenizer used by prediction().
tokenizer = load_tokenizer()

# Module-level classifier used by prediction(); kept on the CPU.
model = load_model().to('cpu')
20
+
21
def prediction(segment_text):
    """Classify one text segment and return the index of the predicted class.

    Args:
        segment_text: The text to classify, as a single string.

    Returns:
        int: Index of the highest-scoring logit for the segment.
    """
    # Truncate long segments: without this, inputs exceeding the model's
    # maximum sequence length make the forward pass fail.
    encoding = tokenizer(
        segment_text,
        padding="longest",
        truncation=True,
        max_length=512,  # CamemBERT's maximum input length, special tokens included
        return_tensors="pt",
    )

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        output = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
        )

    # Logits have one row per input; a single segment yields a single row.
    return int(output.logits.argmax(dim=-1)[0].item())
39
+
40
def main():
    """Render the page: a text area, a Predict button, and the predicted label."""
    st.header('Textual segments Hérelles prediction tool', divider='rainbow')

    segment_text = st.text_area(
        "Text to classify:",
        "Article 1 : Occupations ou utilisations du sol interdites\n\n"
        "1) Dans l’ensemble de la zone sont interdits :\n\n"
        "Les constructions destinées à l’habitation ne dépendant pas d’une exploitation agricole autres\n"
        "que celles visées à l’article 2 paragraphe 1).",
        height=170,
    )

    # Class index -> label shown to the user.
    labels = {
        0: 'Not pertinent',
        1: 'Pertinent (Soft)',
        2: 'Pertinent (Strict, Non-verifiable)',
        3: 'Pertinent (Strict, Verifiable)',
    }

    if st.button('Predict'):
        # Nothing meaningful to classify in a blank input.
        if not segment_text or not segment_text.strip():
            st.warning("Please enter some text to classify.")
            return

        pred_id = prediction(segment_text)

        # Fall back to an explicit message instead of hitting an unbound
        # variable when the model returns an index outside the known labels.
        pred_label = labels.get(pred_id, f'Unknown class ({pred_id})')

        st.write("Predicted Class: ", pred_label)


if __name__ == "__main__":
    main()
68
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ sentencepiece
4
+ torch
5
+ numpy