File size: 9,853 Bytes
621f0bd
10fa1e9
 
 
 
d0ebf51
 
2b50e79
d0ebf51
 
 
 
2b50e79
 
 
 
 
 
d0ebf51
 
 
 
 
 
 
 
 
 
 
 
 
10fa1e9
 
 
 
 
 
 
 
 
4b30813
10fa1e9
 
 
 
 
 
 
 
 
 
 
 
 
 
2b50e79
 
 
 
 
 
 
d0ebf51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10fa1e9
 
 
 
 
 
 
 
 
 
 
 
 
d0ebf51
10fa1e9
 
 
 
 
 
d0ebf51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b50e79
d0ebf51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b50e79
d0ebf51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10fa1e9
d0ebf51
 
 
2b50e79
 
 
 
 
 
 
 
 
 
 
d0ebf51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b50e79
d0ebf51
 
 
 
 
2b50e79
d0ebf51
10fa1e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import gradio as gr
import pandas as pd
from Prediction import *
import os
from datetime import datetime
import re
import json
import hashlib

# Base directory for data that should outlive a single run: the Hugging Face
# cache and the registration log both live underneath it.
persistent_path = "/data"
# Point the Hugging Face cache at the persistent directory.
# NOTE(review): this assignment runs after `from Prediction import *`; if that
# module already imported Hugging Face libraries, they may have read HF_HOME
# before it was set here -- confirm the cache really ends up under /data.
os.environ['HF_HOME'] = os.path.join(persistent_path, ".huggingface")
# JSON-lines file that `check_save` appends one registration record to per line.
user_input_path = os.path.join(persistent_path, 'user.jsonl')
# MD5 hex digest that `logfile_query` compares against the MD5 of the token
# typed into the "Log File" tab.
# NOTE(review): hard-coded, unsalted MD5 of a credential -- consider loading it
# from an environment variable / Space secret instead.
secret = "2fc9ff032e027e8f23bb9fb693234899"

def get_md5(s):
    """Return the hexadecimal MD5 digest of ``s`` encoded as UTF-8.

    Args:
        s: Text to hash.

    Returns:
        The 32-character lowercase hex digest.
    """
    encoded = s.encode('utf-8')
    return hashlib.md5(encoded).hexdigest()

# Example sentences shown in the UI (gr.Examples and the CSV format preview).
# They are read from assets/examples.txt, one sentence per line; when that file
# is missing or contains no usable line, the built-in defaults are used so the
# UI never ends up with an empty example list.
_DEFAULT_EXAMPLES = (
    "Games of the imagination teach us actions have consequences in a realm that can be reset.",
    "But New Jersey farmers are retiring and all over the state, development continues to push out dwindling farmland.",
    "He also is the Head Designer of The Design Trust so-to-speak, besides his regular job ...",
)

examples = []
try:
    with open("assets/examples.txt", "r", encoding="utf8") as file:
        # Strip surrounding whitespace and drop blank lines so that empty rows
        # do not show up as selectable examples.
        examples = [sentence.strip() for sentence in file if sentence.strip()]
except FileNotFoundError:
    pass

if not examples:
    examples = list(_DEFAULT_EXAMPLES)

# Inference runs on the CPU. The tokenizer and the classifier weights are both
# loaded from the same Hugging Face Hub repository.
# (torch / BertTokenizer / BertForSequenceClassification are presumably made
# available by `from Prediction import *` -- verify against that module.)
device = torch.device('cpu')
tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
model = BertForSequenceClassification.from_pretrained(
    "Oliver12315/Brand_Tone_of_Voice"
).to(device)


def single_sentence(sentence):
    """Score one sentence and return its labels ranked by score.

    Args:
        sentence: Text passed to ``predict_single`` together with the
            module-level tokenizer, model and device.

    Returns:
        A list of ``(label, score)`` pairs built by pairing ``LABEL_COLUMNS``
        with the scores from ``predict_single``, ordered from the highest
        score to the lowest.
    """
    scores = predict_single(sentence, tokenizer, model, device)
    ranked = list(zip(LABEL_COLUMNS, scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked

def csv_process(csv_file, attr="content"):
    """Run batch prediction over an uploaded CSV and save the result to disk.

    Args:
        csv_file: The uploaded file. Either an object exposing the file path
            through a ``name`` attribute or a plain path string is accepted.
        attr: Name of the CSV column that holds the text to classify.

    Returns:
        A one-element list with the path of the CSV written under ``output/``.

    Raises:
        ValueError: If no file was provided or the CSV lacks the ``attr``
            column.
    """
    if csv_file is None:
        raise ValueError("No CSV file was provided.")

    # Fall back to the object itself when it has no `.name` (plain path string).
    csv_path = getattr(csv_file, "name", csv_file)
    data = pd.read_csv(csv_path)
    if attr not in data.columns:
        raise ValueError(f"The CSV file must contain a '{attr}' column.")
    data = data.reset_index()

    os.makedirs('output', exist_ok=True)
    # Microseconds are included so that two requests arriving within the same
    # second do not write to (and overwrite) the same output file.
    formatted_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
    output_path = f"output/prediction_Brand_Tone_of_Voice_{formatted_time}.csv"

    predictions = predict_csv(data, attr, tokenizer, model, device)
    predictions.to_csv(output_path)
    return [output_path]

def logfile_query(auth):
    """Return the registration log for download when the token is valid.

    Args:
        auth: Token typed by the user; its MD5 digest is compared with the
            module-level ``secret``.

    Returns:
        ``[user_input_path]`` when the token matches and the log file exists,
        otherwise ``None``.
    """
    # Imported locally so this function does not depend on the file's
    # top-level import block.
    import hmac

    # Empty or missing input can never match; returning early also avoids
    # calling ``.encode`` on ``None``.
    if not auth:
        return None

    # Constant-time comparison avoids leaking how many leading characters of
    # the digest matched.
    if not hmac.compare_digest(get_md5(auth), secret):
        return None

    if not os.path.exists(user_input_path):
        return None
    return [user_input_path]


def check_save(fname, lname, cnum, email, oname, position):
    """Validate the registration form and append a valid entry to the log.

    All values are stripped of surrounding whitespace before validation, and
    ``None`` is treated as an empty string.

    Args:
        fname: First name (required, must not be purely numerical).
        lname: Last name (required, must not be purely numerical).
        cnum: Contact number (optional, digits only).
        email: Email address (required, must look like ``user@host.tld``).
        oname: Organization name (required, must not be purely numerical).
        position: Position in the company (optional, must not be purely
            numerical).

    Returns:
        A list of error messages when validation fails (nothing is written).
        Otherwise a dict mapping each stored record's ``time`` value to the
        record, read back from ``user_input_path`` after appending the new
        entry.
    """
    fname, lname, cnum, email, oname, position = (
        (value or "").strip()
        for value in (fname, lname, cnum, email, oname, position)
    )

    errors = []
    valid_vars = {}

    if not fname or not lname:
        errors.append("Name cannot be empty")
    elif fname.isdigit() or lname.isdigit():
        errors.append("Name cannot be purely numerical")
    else:
        valid_vars["fname"] = fname
        valid_vars["lname"] = lname

    # Optional field: stored as '' when omitted.
    valid_vars["cnum"] = ''
    if cnum:
        if not cnum.isdigit():
            errors.append("The phone number must be a pure number")
        else:
            valid_vars["cnum"] = cnum

    if not email:
        errors.append("Email cannot be empty")
    # fullmatch: unlike match() with '$', a trailing newline is not accepted.
    elif not re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', email):
        errors.append("Incorrect email format")
    else:
        valid_vars["email"] = email

    if not oname:
        errors.append("Organization name cannot be empty")
    elif oname.isdigit():
        errors.append("Organization cannot be purely numerical")
    else:
        valid_vars["oname"] = oname

    # Optional field: stored as '' when omitted.
    valid_vars["position"] = ''
    if position:
        if position.isdigit():
            errors.append("Position in your company cannot be purely numerical")
        else:
            valid_vars["position"] = position

    if errors:
        return errors

    valid_vars['time'] = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    with open(user_input_path, 'a+', encoding="utf8") as file:
        file.write(json.dumps(valid_vars) + "\n")

    records = {}
    with open(user_input_path, 'r', encoding="utf8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            # One damaged line must not block every later registration.
            try:
                dct = json.loads(line)
                records[dct['time']] = dct
            except (json.JSONDecodeError, KeyError, TypeError):
                continue

    return records


# Build the Gradio UI. A registration form (`regis`) is shown first; once it
# validates, it is hidden and the prediction tabs (`mainrow`) are revealed.
my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
    # Page header: title plus badge links.
    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <a href="https://github.com/xxx" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
        </a>
        <div>
            <h1 >Place the title of the paper here</h1>
            <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <a href="https://arxiv.org/abs/xx.xx"><img src="https://img.shields.io/badge/Arxiv-xx.xx-red"></a>
                <a href='https://huggingface.co/spaces/Oliver12315/Brand_Tone_of_Voice_demo'><img src='https://img.shields.io/badge/Project_Page-Oliver12315/Brand_Tone_of_Voice_demo' alt='Project Page'></a>
                <a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
            </div>
        </div>
        </div>
        """)

    # Registration form, visible on load.
    with gr.Column(visible=True) as regis:
        # NOTE(review): the second sentence of this text reads like a leftover
        # drafting note and mentions a tick box that the form does not have --
        # confirm the intended consent wording with the project owners.
        gr.Markdown("# Welcome to BTV!  Please fill out the form below to continue.\nI’m assuming that you mention somewhere that this project/research is conducted by the University of Manchester/AMBS. By ticking this box, I consent to be approached by the research team of the University of Manchester.")
        with gr.Column(variant='panel'):
            fname_tb = gr.Textbox(label="First Name: ", type='text')
            lname_tb = gr.Textbox(label="Last Name: ", type='text')
            email_tb = gr.Textbox(label="Email: ", type='email')
            cnum_tb = gr.Textbox(label="Contact: (Optional)", type='text')
            oname_tb = gr.Textbox(label="Organization name: ", type='text')
            position_tb = gr.Textbox(label="Positions in your company: (Optional)", type='text')
        # Hidden until `submit` has validation errors to display.
        error_box = gr.HTML(value="", visible=False)
        # NOTE(review): the button label contains typos ("fullfill", "item");
        # left unchanged here because it is user-facing text.
        submit_btn = gr.Button("Click here to start if you have fullfill all the item!")

    # Main application, hidden until registration succeeds.
    with gr.Row(visible=False) as mainrow:

        with gr.Tab("Single Sentence"):
            with gr.Row():
                tbox_input = gr.Textbox(label="Input",
                                        info="Please input a sentence here:")
                gr.Markdown("""
                    # Detailed information about our model:
                    ...
                    """)
            tab_output = gr.DataFrame(label='Predictions:',
                                      headers=["Label", "Probability"],
                                      datatype=["str", "number"],
                                      interactive=False)
            with gr.Row():
                button_ss = gr.Button("Submit", variant="primary")
                button_ss.click(fn=single_sentence, inputs=[tbox_input], outputs=[tab_output])
                gr.ClearButton([tbox_input, tab_output])

            gr.Examples(
                examples=examples,
                inputs=tbox_input,
                examples_per_page=len(examples)
            )

        with gr.Tab("Csv File"):
            with gr.Row():
                csv_input = gr.File(label="CSV File:",
                                    file_types=['.csv'],
                                    file_count="single"
                                    )
                csv_output = gr.File(label="Predictions:")

            with gr.Row():
                button_cf = gr.Button("Submit", variant="primary")
                button_cf.click(fn=csv_process, inputs=[csv_input], outputs=[csv_output])
                gr.ClearButton([csv_input, csv_output])

            gr.Markdown("## Examples \n The incoming CSV must include the ``content`` field, which represents the text that needs to be predicted!")
            # Preview of the expected CSV layout, built from the example sentences.
            gr.DataFrame(label='Csv input format:',
                         value=[[index, sentence] for index, sentence in enumerate(examples)],
                         headers=["index", "content"],
                         datatype=["number", "str"],
                         interactive=False
                         )

        with gr.Tab("Readme"):
            gr.Markdown(
                """
                # Paper Name

                # Authors

                + First author
                + Corresponding author
                
                # Detailed Information

                ...
                """
            )

        with gr.Tab("Log File"):
            with gr.Row():
                auth_token = gr.Textbox(label="Authentication Tokens: ", info="Enter the key to download persistent stored log information.")
                log_output = gr.File(label="Log file: ")

            with gr.Row():
                button_lf = gr.Button("Validate", variant="primary")
                button_lf.click(fn=logfile_query, inputs=[auth_token], outputs=[log_output])
                gr.ClearButton([auth_token, log_output])


    def submit(*user_input):
        """Validate the registration form and switch views on success.

        Args:
            *user_input: Form values in the order expected by ``check_save``
                (first name, last name, contact, email, organization,
                position).

        Returns:
            A dict of component updates: only ``error_box`` (made visible with
            the joined messages) when ``check_save`` returns a list of errors;
            otherwise ``mainrow`` shown, ``regis`` hidden and ``error_box``
            hidden.
        """
        res = check_save(*user_input)
        if isinstance(res, list):
            # The messages are fixed strings produced by check_save, not user
            # input, so they are interpolated into the HTML as they are.
            message = "; ".join(res)
            return {
                error_box: gr.HTML(
                    value=f"""
                    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                    <div>
                        <p style="color:red;">{message}</p>
                    </div>
                    </div>
                    """,
                    visible=True)
            }
        return {
            mainrow: gr.Row(visible=True),
            # `regis` is a gr.Column, so it is updated with a Column.
            regis: gr.Column(visible=False),
            error_box: gr.HTML(visible=False)
        }

    # The input order must match the parameter order of check_save.
    submit_btn.click(
        submit,
        [fname_tb, lname_tb, cnum_tb, email_tb, oname_tb, position_tb],
        [mainrow, regis, error_box],
    )
demo.launch()