File size: 7,927 Bytes
4ee899c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd07025
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import streamlit as st
from spacy import displacy
from Model.NER.VLSP2021.Predict_Ner import ViTagger
import re
from thunghiemxuly import save_uploaded_image,convert_text_to_txt,add_string_to_txt

import os
from transformers import AutoTokenizer, BertConfig
from Model.MultimodelNER.VLSP2021.train_umt_2021 import load_model,predict
from Model.MultimodelNER.Ner_processing import format_predictions,process_predictions,combine_entities,remove_B_prefix,combine_i_tags

from Model.MultimodelNER.predict import get_test_examples_predict
from Model.MultimodelNER import resnet as resnet
from Model.MultimodelNER.resnet_utils import myResnet
import torch
import numpy as np
from Model.MultimodelNER.VLSP2021.dataset_roberta import MNERProcessor_2021


CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ResNet-152 image backbone, built once at import time and shared by every
# MNER request (it is passed to load_model in show_mner_2021).
net = resnet.resnet152()
# map_location="cpu": without it, a checkpoint saved from CUDA tensors cannot be
# deserialized on the CPU-only fallback path above. load_state_dict copies the
# loaded tensors into net's own parameters, so loading to CPU first is safe.
net.load_state_dict(
    torch.load(os.path.join('/Model/Resnet/', 'resnet152.pth'), map_location="cpu")
)
encoder = myResnet(net, True, device)
def process_text(text):
    """Normalise whitespace in *text*.

    Every run of whitespace characters (spaces, tabs, newlines, and other
    Unicode whitespace) becomes a single space, and leading/trailing
    whitespace is removed.
    """
    # str.split() with no separator splits on whitespace runs and discards
    # empty pieces, so re-joining with one space yields the collapsed text.
    pieces = text.split()
    return " ".join(pieces)



# Display colours passed to displacy for each entity label. Words labelled 'O'
# never become entities (see _entities_for_displacy), so the "O" entry is unused.
_MNER_ENTITY_COLORS = {
    "DATETIME-DATERANGE": "#66c2ff",
    "LOCATION-GPE": "#ffcc99",
    "O": None,
    "QUANTITY-NUM": "#ffdf80",
    "EVENT-CUL": "#bfbfbf",
    "DATETIME": "#80ff80",
    "PERSONTYPE": "#ff80ff",
    "PERSON": "#bf80ff",
    "QUANTITY-PER": "#80cccc",
    "ORGANIZATION": "#ff6666",
    "LOCATION-GEO": "#66cc66",
    "LOCATION-STRUC": "#cccc66",
    "PRODUCT-COM": "#ffff66",
    "DATETIME-DATE": "#66cccc",
    "QUANTITY-DIM": "#6666ff",
    "PRODUCT": "#cc6666",
    "QUANTITY": "#6666cc",
    "DATETIME-DURATION": "#9966ff",
    "QUANTITY-CUR": "#ff9966",
    "DATETIME-TIME": "#cdbf93",
    "QUANTITY-TEM": "#cc9966",
    "DATETIME-TIMERANGE": "#cc8566",
    "EVENT-GAMESHOW": "#8c8c5a",
    "QUANTITY-AGE": "#70db70",
    "QUANTITY-ORD": "#e699ff",
    "PRODUCT-LEGAL": "#806699",
    "LOCATION": "#993366",
    "ORGANIZATION-MED": "#339933",
    "URL": "#ff4d4d",
    "PHONENUMBER": "#99cc99",
    "ORGANIZATION-SPORTS": "#6666ff",
    "EVENT-SPORT": "#ffff80",
    "SKILL": "#b38f66",
    "EVENT-NATURAL": "#ff9966",
    "ADDRESS": "#cc9966",
    "IP": "#b38f66",
    "EMAIL": "#cc8566",
    "ORGANIZATION-STOCK": "#666633",
    "DATETIME-SET": "#70db70",
    "PRODUCT-AWARD": "#e699ff",
    "MISCELLANEOUS": "#806699",
    "LOCATION-GPE-GEO": "#99ffff",
}


def _build_trans_matrix(auxnum_labels, num_labels):
    """Return the (auxnum_labels, num_labels) auxiliary-tag -> label matrix for predict().

    The label names in the comments below are carried over from earlier code;
    NOTE(review): verify them against MNERProcessor_2021.get_labels() and
    get_auxlabels(). Requires auxnum_labels >= 7 and num_labels >= 13.
    """
    trans_matrix = np.zeros((auxnum_labels, num_labels), dtype=float)
    trans_matrix[0, 0] = 1  # pad to pad
    trans_matrix[1, 1] = 1  # O to O
    # The auxiliary B tag spreads its weight evenly over four B-* columns,
    # and the auxiliary I tag over four I-* columns.
    for col in (2, 4, 6, 8):
        trans_matrix[2, col] = 0.25  # B to B-MISC / B-PER / B-ORG / B-LOC
    for col in (3, 5, 7, 9):
        trans_matrix[3, col] = 0.25  # I to I-MISC / I-PER / I-ORG / I-LOC
    trans_matrix[4, 10] = 1  # X to X
    trans_matrix[5, 11] = 1  # [CLS] to [CLS]
    trans_matrix[6, 12] = 1  # presumably [SEP] to [SEP] -- TODO confirm
    return trans_matrix


def _entities_for_displacy(words_and_labels):
    """Convert (word, label) pairs into displacy manual-mode entity spans.

    Offsets index into ``" ".join(words)``. Each span covers exactly its word
    (end is exclusive and excludes the joining space). Words labelled 'O' get
    no span.
    """
    entities = []
    offset = 0
    for word, label in words_and_labels:
        if label != 'O':
            entities.append({'start': offset, 'end': offset + len(word), 'label': label})
        offset += len(word) + 1  # +1 for the space that joins consecutive words
    return entities


def show_mner_2021():
    """Render the VLSP2021 multimodal NER page and run prediction on demand.

    Shows a text area and a .jpg uploader. When "Process Multimodal NER" is
    clicked, the image is saved under VLSP2021/Image, and the text plus the
    image name are written to Filetxt/test.txt via the thunghiemxuly helpers.
    The UMT model and its image encoder are then loaded from best_model, run
    on the examples from that Filetxt directory, and the resulting entities
    are rendered with displacy.
    """
    multimodal_text = st.text_area("Enter your text for MNER:", height=300)
    multimodal_text = process_text(multimodal_text)  # collapse redundant whitespace
    image = st.file_uploader("Upload an image (only jpg):", type=["jpg"])
    if not st.button("Process Multimodal NER"):
        return

    # st.file_uploader returns None until a file is chosen, and image.name
    # below would raise AttributeError; report the problem instead.
    if image is None:
        st.warning("Please upload a .jpg image before processing.")
        return
    if not multimodal_text:
        st.warning("Please enter some text before processing.")
        return

    save_image = '/Model/MultimodelNER/VLSP2021/Image'
    save_txt = '/Model/MultimodelNER/VLSP2021/Filetxt/test.txt'
    image_name = image.name
    save_uploaded_image(image, save_image)
    convert_text_to_txt(multimodal_text, save_txt)
    add_string_to_txt(image_name, save_txt)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    bert_model = 'vinai/phobert-base-v2'
    output_dir = '/Model/MultimodelNER/VLSP2021/best_model'
    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
    output_encoder_file = os.path.join(output_dir, "pytorch_encoder.bin")
    processor = MNERProcessor_2021()
    # +1 on both sizes, presumably reserving index 0 for padding ("pad to pad").
    num_labels = len(processor.get_labels()) + 1
    auxnum_labels = len(processor.get_auxlabels()) + 1
    trans_matrix = _build_trans_matrix(auxnum_labels, num_labels)

    # NOTE: the tokenizer and model weights are reloaded on every click;
    # caching them would make repeated requests much faster.
    tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=False)
    model_umt, encoder_umt = load_model(output_model_file, output_encoder_file, encoder,
                                        num_labels, auxnum_labels)
    eval_examples = get_test_examples_predict(
        '/Model/MultimodelNER/VLSP2021/Filetxt/')

    y_pred, a = predict(model_umt, encoder_umt, eval_examples, tokenizer, device,
                        save_image, trans_matrix)
    formatted_output = format_predictions(a, y_pred[0])
    # Post-process the raw tags into (word, label) pairs for display.
    words_and_labels = combine_i_tags(
        remove_B_prefix(combine_entities(process_predictions(formatted_output)))
    )

    words = [word for word, _ in words_and_labels]
    html = displacy.render(
        {"text": " ".join(words), "ents": _entities_for_displacy(words_and_labels), "title": None},
        style="ent",
        manual=True,
        options={"colors": _MNER_ENTITY_COLORS},
    )
    st.markdown(html, unsafe_allow_html=True)


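# Example 1 (sample input): "Một trận hỗn chiến đã xảy ra tại trận đấu khúc côn cầu giữa Penguins và Islanders ở Mỹ" (image: penguin)
# English: "A brawl broke out at the ice hockey game between the Penguins and the Islanders in the US."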