File size: 2,149 Bytes
caacd68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
from transformers import CamembertTokenizer, CamembertForSequenceClassification
import torch
import numpy as np

@st.cache_resource
def load_tokenizer():
    """Load and return the CamemBERT tokenizer for ``camembert-base``.

    The function is decorated with ``st.cache_resource`` so the loaded
    tokenizer is handed out through Streamlit's resource cache.
    """
    camembert_tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
    return camembert_tokenizer

@st.cache_resource
def load_model():
    """Load and return the sequence-classification model.

    The weights come from the ``herelles/camembert-base-lupan`` checkpoint.
    The function is decorated with ``st.cache_resource`` so the loaded model
    is handed out through Streamlit's resource cache.
    """
    checkpoint = "herelles/camembert-base-lupan"
    classifier = CamembertForSequenceClassification.from_pretrained(checkpoint)
    return classifier

# Module-level singletons used by prediction(): the classifier and its
# tokenizer are obtained once here through the cached loader functions.
model = load_model()
tokenizer = load_tokenizer()

# Inference is done on the CPU.
model.to('cpu')

def prediction(segment_text, max_length=512):
    """Classify one text segment and return the predicted class index.

    Args:
        segment_text: The text to classify (a single string, as typed into
            the UI text area).
        max_length: Maximum number of tokens fed to the model; longer inputs
            are truncated. The default of 512 is intended to match the input
            limit of CamemBERT-base (see the model documentation).

    Returns:
        int: Index of the largest logit produced by the model for the
        segment.
    """
    # Truncate explicitly: the text comes from a free-form text area, and an
    # encoding longer than the model's positional limit makes the forward
    # pass fail instead of producing a prediction.
    encoding = tokenizer(
        segment_text,
        padding="longest",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Forward pass without gradient tracking; only the logits are needed.
    with torch.no_grad():
        output = model(
            encoding['input_ids'].to('cpu'),
            token_type_ids=None,
            attention_mask=encoding['attention_mask'].to('cpu'),
        )

    # argmax without an axis works on the flattened logits, which for a
    # single segment is exactly the class index.
    return int(np.argmax(output.logits.cpu().numpy()))

# Display label for each class index returned by prediction().
_PRED_LABELS = {
    0: 'Not pertinent',
    1: 'Pertinent (Soft)',
    2: 'Pertinent (Strict, Non-verifiable)',
    3: 'Pertinent (Strict, Verifiable)',
}


def main():
    """Render the Streamlit page: a text area, a button and the prediction.

    When the 'Predict' button is pressed, the current text is passed to
    ``prediction`` and the label matching the returned class index is
    written to the page.
    """
    st.header('Textual segments Hérelles prediction tool', divider='rainbow')

    segment_text = st.text_area(
        "Text to classify:",
        "Article 1 : Occupations ou utilisations du sol interdites\n\n"
        "1) Dans l’ensemble de la zone sont interdits :\n\n"
        "Les constructions destinées à l’habitation ne dépendant pas d’une exploitation agricole autres\n"
        "que celles visées à l’article 2 paragraphe 1).",
        height=170,
    )

    if st.button('Predict'):
        pred_id = prediction(segment_text)

        # An index outside the known classes previously left the label
        # unbound and raised at st.write; show an explicit fallback instead.
        pred_label = _PRED_LABELS.get(pred_id, f'Unknown class ({pred_id})')

        st.write("Predicted Class: ", pred_label)
    
# Entry point: build the page only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()