File size: 3,541 Bytes
049266b
 
2c5db9f
 
049266b
 
 
 
 
 
 
ec78bff
 
049266b
 
2c5db9f
049266b
64f92ba
 
 
 
 
 
 
 
 
 
 
049266b
ec78bff
204126b
2c5db9f
049266b
 
e6c30c5
049266b
 
 
ec78bff
049266b
 
 
 
 
 
 
 
 
 
 
 
601925f
049266b
 
601925f
049266b
 
64f92ba
 
 
2c5db9f
ec78bff
c5643ee
64f92ba
dbcd12b
8dffc3b
64f92ba
049266b
 
 
 
 
64f92ba
049266b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import time
import streamlit as st
from annotated_text import annotated_text

from flair.data import Sentence
from flair.models import SequenceTagger

checkpoints = [
    "qanastek/pos-french",
]

colors = {'DET': '#b9d9a6', 'NFP': '#eddc92', 'ADJFP': '#95e9d7', 'AUX': '#e797db', 'VPPMS': '#9ff48b', 'ADV': '#ed92b4', 'PREP': '#decfa1', 'PDEMMS': '#ada7d7', 'NMS': '#85fad8', 'COSUB': '#8ba4f4', 'PINDMS': '#e7a498', 'PPOBJMS': '#e5c79a', 'VERB': '#eb94b6', 'DETFS': '#e698ae', 'NFS': '#d9d1a6', 'YPFOR': '#96e89f', 'VPPFS': '#e698c6', 'PUNCT': '#ddbfa2', 'DETMS': '#f788cd', 'PROPN': '#f19c8d', 'ADJMS': '#8ed5f0', 'PPER3FS': '#c4d8a6', 'ADJFS': '#e39bdc', 'COCO': '#8df1e2', 'NMP': '#d7f787', 'PREL': '#f986f0', 'PPER1S': '#878df8', 'ADJMP': '#83fe80', 'VPPMP': '#a6d8c9', 'DINTMS': '#d9a6cc', 'PPER3MS': '#a1deda', 'PPER3MP': '#8fefe1', 'PREF': '#e3c79b', 'ADJ': '#fb81fe', 'DINTFS': '#d5fe81', 'CHIF': '#8084ff', 'XFAMIL': '#dd80fe', 'PRELFS': '#9ce3e3', 'SYM': '#9fbddf', 'NOUN': '#dea1b5', 'MOTINC': '#93b8ec', 'PINDFS': '#f787a5', 'PPOBJMP': '#dca3d2', 'NUM': '#b2e897', 'PREFP': '#e39cd0', 'PDEMFS': '#d8a7cb', 'VPPFP': '#83d9fb', 'PPER3FP': '#a1ddaa', 'PPOBJFS': '#e9ca95', 'PINDMP': '#e897e3', 'PRON': '#e29dcc', 'PPOBJFP': '#86f9dc', 'PART': '#aa96e8', 'PDEMMP': '#b2d7a8', 'PRELMS': '#e39bde', 'PDEMFP': '#b1e599', 'PRELFP': '#bbe39b', 'INTJ': '#bde996', 'PREFS': '#b39be4', 'PINDFP': '#e2e897', 'PRELMP': '#a5c0da', 'PINTFS': '#ceff80', 'PPER2S': '#d5a2dd', 'VPPRE': '#e78af4', '<START>': '#e6a899', '<STOP>': '#9adde5'}

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_model(model_name):
    return SequenceTagger.load(model_name) # Load the model

def getPos(s: Sentence):
    texts = []
    labels = []
    for t in s.tokens:
        for label in t.annotation_layers.keys():
            texts.append(t.text)
            labels.append(t.get_labels(label)[0].value)          
    return texts, labels

def getDictFromPOS(texts, labels):
    return [{ "text": t, "label": l } for t, l in zip(texts, labels)]

def getAnnotatedFromPOS(texts, labels):
    return [(t,l,colors[l]) for t, l in zip(texts, labels)]

def main():

    st.title("🥖 French Part-Of-Speech Tagging")

    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)
    
    default_text = "George Washington est allé à Washington"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    start = None
    if st.button("🧠 Compute"):
        start = time.time()
        with st.spinner("Search for Part-Of-Speech Tags 🔍"):
            
            # Build Sentence
            s = Sentence(input_text)

            # predict tags
            model.predict(s)

            try:

                texts, labels = getPos(s)
                
                st.header("Labels:")
                anns = getAnnotatedFromPOS(texts, labels)
                annotated_text(*anns)

                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))

            except Exception as e:
                st.error("Some error occured!" + str(e))
                st.stop()

    st.write("---")

    st.markdown(
        "Built by [Yanis Labrak](https://www.linkedin.com/in/yanis-labrak-8a7412145/) 🚀"
    )
    st.markdown(
        "_Source code made with [FlairNLP](https://github.com/flairNLP/flair)_"
    )

    if start is not None:
        st.text(f"prediction took {time.time() - start:.2f}s")


if __name__ == "__main__":
    main()