# File size: 2,149 Bytes
# caacd68
import streamlit as st
from transformers import CamembertTokenizer, CamembertForSequenceClassification
import torch
import numpy as np
# Streamlit re-executes the whole script on every widget interaction;
# st.cache_resource keeps a single shared tokenizer instead of reloading it
# on each rerun.
@st.cache_resource
def load_tokenizer():
    """Load and return the pretrained ``camembert-base`` tokenizer."""
    return CamembertTokenizer.from_pretrained("camembert-base")
# Streamlit re-executes the whole script on every widget interaction;
# st.cache_resource keeps a single shared model instead of reloading the
# weights on each rerun.
@st.cache_resource
def load_model():
    """Load and return the ``herelles/camembert-base-lupan`` classifier."""
    return CamembertForSequenceClassification.from_pretrained(
        "herelles/camembert-base-lupan"
    )
# Define tokenizer:
tokenizer = load_tokenizer()
# Load model and place it on the CPU:
model = load_model().to('cpu')
def prediction(segment_text):
    """Classify a text segment with the module-level ``model``.

    Args:
        segment_text: The text to classify.

    Returns:
        The index (as a Python ``int``) of the highest logit produced by
        the model.
    """
    # Tokenize. ``truncation=True`` clips the input to the tokenizer's
    # maximum model length; without it, a segment that is too long for the
    # model can make the forward pass fail.
    encoding = tokenizer(
        segment_text,
        padding="longest",
        truncation=True,
        return_tensors="pt",
    )
    # Send the inputs to whichever device the model lives on.
    input_ids = encoding["input_ids"].to(model.device)
    attention_mask = encoding["attention_mask"].to(model.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(
            input_ids,
            token_type_ids=None,
            attention_mask=attention_mask,
        )
    # torch.argmax without ``dim`` works on the flattened logits, like
    # np.argmax, so for a single segment this is the predicted class index.
    return int(torch.argmax(output.logits))
# Display names for the classifier's output indices.
_PRED_LABELS = {
    0: 'Not pertinent',
    1: 'Pertinent (Soft)',
    2: 'Pertinent (Strict, Non-verifiable)',
    3: 'Pertinent (Strict, Verifiable)',
}


def main():
    """Render the Streamlit UI and show the predicted class on demand."""
    st.header('Textual segments Hérelles prediction tool', divider='rainbow')
    segment_text = st.text_area(
        "Text to classify:",
        "Article 1 : Occupations ou utilisations du sol interdites\n\n"
        "1) Dans l’ensemble de la zone sont interdits :\n\n"
        "Les constructions destinées à l’habitation ne dépendant pas d’une exploitation agricole autres\n"
        "que celles visées à l’article 2 paragraphe 1).",
    )
    if st.button('Predict'):
        pred_id = prediction(segment_text)
        # An id outside the table is still shown (as a raw index) instead of
        # raising because no label was assigned.
        pred_label = _PRED_LABELS.get(pred_id, f'Unknown class ({pred_id})')
        st.write("Predicted Class: ", pred_label)
if __name__ == "__main__":
    # Script entry point: build the UI.
    main()