Spaces:
Runtime error
Runtime error
Oliver12315
committed on
Commit
•
10fa1e9
1
Parent(s):
621f0bd
Upload core files
Browse files- .gitignore +3 -0
- Prediction.py +83 -0
- README.md +5 -5
- app.py +122 -4
- convert.py +30 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
/output/*
|
2 |
+
.vscode
|
3 |
+
__pycache__
|
Prediction.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from tqdm.auto import tqdm
|
3 |
+
import torch
|
4 |
+
from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification
|
5 |
+
import os
|
6 |
+
import glob
|
7 |
+
|
8 |
+
|
9 |
+
# Seed constant; no code visible in this file reads it.
RANDOM_SEED = 42
# NOTE(review): this only sets an attribute on the pandas module object;
# pandas does not appear to consult it -- confirm whether it can be removed.
pd.RANDOM_SEED = 42
# Names of the output columns. predict_csv() stores the i-th model output
# under the i-th name, so the order must match the model's class order.
LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone", "None"]
|
12 |
+
|
13 |
+
|
14 |
+
@torch.no_grad()
def predict_csv(data, text_col, tokenizer, model, device, text_bs=16, max_token_len=128):
    """Score every row of ``data[text_col]`` and add one column per label.

    The texts are tokenized and run through ``model`` in batches of
    ``text_bs``. The softmax probabilities are rounded to 8 decimals and
    stored in ``data`` under the names in ``LABEL_COLUMNS``.

    Args:
        data: DataFrame holding the texts. It is modified in place.
        text_col: Name of the column in ``data`` that contains the text.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        model: Sequence-classification model whose output exposes ``.logits``.
        device: Device the input tensors are moved to before the forward pass.
        text_bs: Number of texts per batch.
        max_token_len: Length every text is padded/truncated to.

    Returns:
        The same ``data`` object, with the label columns added.
    """
    post = data[text_col]
    num_text = len(post)

    if num_text == 0:
        # torch.cat() raises on an empty list, so an empty frame is handled
        # up front: the label columns are still created, just without rows.
        for label in LABEL_COLUMNS:
            data[label] = []
        return data

    predictions = []
    batch_starts = range(0, num_text, text_bs)
    for start in tqdm(batch_starts, total=len(batch_starts), desc="Processing..."):
        # .iloc keeps the slice positional whatever index the frame carries;
        # slicing past the end is clipped automatically.
        texts = post.iloc[start:start + text_bs].tolist()
        encoding = tokenizer(
            texts,
            add_special_tokens=True,
            max_length=max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        logits = model(
            encoding["input_ids"].to(device),
            encoding["attention_mask"].to(device),
            return_dict=True,
        ).logits
        predictions.append(torch.softmax(logits, dim=1).cpu())

    # Transpose (rows, classes) -> (classes, rows) so each class is one column.
    y_inten = torch.cat(predictions, dim=0).numpy().T

    for class_idx, label in enumerate(LABEL_COLUMNS):
        data[label] = [round(prob, 8) for prob in y_inten[class_idx].tolist()]
    return data
|
46 |
+
|
47 |
+
@torch.no_grad()
def predict_single(sentence, tokenizer, model, device, max_token_len=128):
    """Return the class probabilities the model assigns to one sentence.

    Args:
        sentence: The text to classify.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        model: Sequence-classification model whose output exposes ``.logits``.
        device: Device the input tensors are moved to before the forward pass.
        max_token_len: Length the sentence is padded/truncated to.

    Returns:
        A flat list of softmax probabilities, rounded to 8 decimals, in the
        model's output order.
    """
    encoding = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=max_token_len,
        return_token_type_ids=False,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    # Pass the tensors by keyword so the call does not depend on the
    # positional order of the model's forward() parameters.
    logits = model(
        input_ids=encoding["input_ids"].to(device),
        attention_mask=encoding["attention_mask"].to(device),
        return_dict=True,
    ).logits
    probabilities = torch.softmax(logits, dim=1).flatten().cpu().numpy().tolist()
    return [round(prob, 8) for prob in probabilities]
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
if __name__ == "__main__":
    # Smoke test: score the first 20 rows of the sample CSV plus one sentence.
    Data = pd.read_csv("assets/Kickstarter_sentence_level_5000.csv")
    # predict_csv() adds columns in place, so work on an independent copy
    # rather than on a slice of the full frame.
    Data = Data[:20].copy()
    device = torch.device('cpu')

    # Load model directly
    tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
    model = BertForSequenceClassification.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
    model = model.to(device)

    fk_doc_result = predict_csv(Data, "content", tokenizer, model, device)
    single_response = predict_single("Games of the imagination teach us actions have consequences in a realm that can be reset.", tokenizer, model, device)
    print(single_response)

    # to_csv() does not create missing directories.
    os.makedirs("output", exist_ok=True)
    fk_doc_result.to_csv("output/prediction_Brand_Tone_of_Voice.csv")
|
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
1 |
---
|
2 |
+
title: Murphy
|
3 |
+
emoji: 📊
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.10.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
CHANGED
@@ -1,7 +1,125 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
def greet(name):
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from Prediction import *
|
5 |
+
import os
|
6 |
+
from datetime import datetime
|
7 |
|
|
|
|
|
8 |
|
9 |
+
# Example sentences offered in the UI. They are read from
# assets/examples.txt (one per line) when that file exists.
examples = []
if os.path.exists("assets/examples.txt"):
    with open("assets/examples.txt", "r", encoding="utf8") as file:
        # Skip blank lines so they do not show up as empty examples.
        examples = [line.strip() for line in file if line.strip()]

if not examples:
    # Fall back to built-in sentences when the file is missing or has no
    # usable lines; the UI below sizes its example list from len(examples).
    examples = [
        "Games of the imagination teach us actions have consequences in a realm that can be reset.",
        "But New Jersey farmers are retiring and all over the state, development continues to push out dwindling farmland.",
        "He also is the Head Designer of The Design Trust so-to-speak, besides his regular job ..."
    ]

# Load the tokenizer and model once at import time; the callbacks reuse them.
device = torch.device('cpu')
tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
model = BertForSequenceClassification.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
model = model.to(device)
|
26 |
+
|
27 |
+
|
28 |
+
def single_sentence(sentence):
    """Classify one sentence and return its labels ranked by probability.

    Args:
        sentence: The text entered by the user.

    Returns:
        A list of ``(label, probability)`` tuples, highest probability first.
    """
    predictions = predict_single(sentence, tokenizer, model, device)
    # Pair each probability with its label BEFORE sorting. Sorting the bare
    # probabilities first would attach them to the wrong labels.
    labelled = zip(LABEL_COLUMNS, predictions)
    return sorted(labelled, key=lambda pair: pair[1], reverse=True)
|
32 |
+
|
33 |
+
def csv_process(csv_file, attr="content"):
    """Run the model over an uploaded CSV and write the predictions to disk.

    Args:
        csv_file: Uploaded file object; its ``.name`` is the path that is read.
        attr: Name of the CSV column containing the text to classify.

    Returns:
        A one-element list with the path of the written predictions CSV.

    Raises:
        gr.Error: If no file was uploaded or the CSV lacks the ``attr`` column.
    """
    if csv_file is None:
        raise gr.Error("Please upload a CSV file before submitting.")

    formatted_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    data = pd.read_csv(csv_file.name)
    if attr not in data.columns:
        raise gr.Error(f"The uploaded CSV must contain a '{attr}' column.")

    # reset_index() inserts an "index" column and raises ValueError when one
    # already exists, so only insert it when the CSV does not bring its own.
    data = data.reset_index(drop="index" in data.columns)

    os.makedirs('output', exist_ok=True)
    predictions = predict_csv(data, attr, tokenizer, model, device)
    output_path = f"output/prediction_Brand_Tone_of_Voice_{formatted_time}.csv"
    predictions.to_csv(output_path)
    return [output_path]
|
45 |
+
|
46 |
+
|
47 |
+
my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
|
48 |
+
with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
|
49 |
+
gr.HTML(
|
50 |
+
"""
|
51 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
52 |
+
<a href="https://github.com/xxx" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
|
53 |
+
</a>
|
54 |
+
<div>
|
55 |
+
<h1 >Place the title of the paper here</h1>
|
56 |
+
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
|
57 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
58 |
+
<a href="https://arxiv.org/abs/xx.xx"><img src="https://img.shields.io/badge/Arxiv-xx.xx-red"></a>
|
59 |
+
<a href='https://huggingface.co/spaces/Oliver12315/Brand_Tone_of_Voice_demo'><img src='https://img.shields.io/badge/Project_Page-Oliver12315/Brand_Tone_of_Voice_demo' alt='Project Page'></a>
|
60 |
+
<a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
|
61 |
+
</div>
|
62 |
+
</div>
|
63 |
+
</div>
|
64 |
+
""")
|
65 |
+
|
66 |
+
with gr.Tab("Single Sentence"):
|
67 |
+
with gr.Row():
|
68 |
+
tbox_input = gr.Textbox(label="Input",
|
69 |
+
info="Please input a sentence here:")
|
70 |
+
gr.Markdown("""
|
71 |
+
# Detailed information about our model:
|
72 |
+
...
|
73 |
+
""")
|
74 |
+
tab_output = gr.DataFrame(label='Predictions:',
|
75 |
+
headers=["Label", "Probability"],
|
76 |
+
datatype=["str", "number"],
|
77 |
+
interactive=False)
|
78 |
+
with gr.Row():
|
79 |
+
button_ss = gr.Button("Submit", variant="primary")
|
80 |
+
button_ss.click(fn=single_sentence, inputs=[tbox_input], outputs=[tab_output])
|
81 |
+
gr.ClearButton([tbox_input, tab_output])
|
82 |
+
|
83 |
+
gr.Examples(
|
84 |
+
examples=examples,
|
85 |
+
inputs=tbox_input,
|
86 |
+
examples_per_page=len(examples)
|
87 |
+
)
|
88 |
+
|
89 |
+
with gr.Tab("Csv File"):
|
90 |
+
with gr.Row():
|
91 |
+
csv_input = gr.File(label="CSV File:",
|
92 |
+
file_types=['.csv'],
|
93 |
+
file_count="single"
|
94 |
+
)
|
95 |
+
csv_output = gr.File(label="Predictions:")
|
96 |
+
|
97 |
+
with gr.Row():
|
98 |
+
button = gr.Button("Submit", variant="primary")
|
99 |
+
button.click(fn=csv_process, inputs=[csv_input], outputs=[csv_output])
|
100 |
+
gr.ClearButton([csv_input, csv_output])
|
101 |
+
|
102 |
+
gr.Markdown("## Examples \n The incoming CSV must include the ``content`` field, which represents the text that needs to be predicted!")
|
103 |
+
gr.DataFrame(label='Csv input format:',
|
104 |
+
value=[[i, examples[i]] for i in range(len(examples))],
|
105 |
+
headers=["index", "content"],
|
106 |
+
datatype=["number","str"],
|
107 |
+
interactive=False
|
108 |
+
)
|
109 |
+
|
110 |
+
with gr.Tab("Readme"):
|
111 |
+
gr.Markdown(
|
112 |
+
"""
|
113 |
+
# Paper Name
|
114 |
+
|
115 |
+
# Authors
|
116 |
+
|
117 |
+
+ First author
|
118 |
+
+ Corresponding author
|
119 |
+
|
120 |
+
# Detailed Information
|
121 |
+
|
122 |
+
...
|
123 |
+
"""
|
124 |
+
)
|
125 |
+
demo.launch()
|
convert.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification
|
5 |
+
|
6 |
+
# Use the local proxy only when the environment does not already configure
# one, so the script does not clobber settings on other machines.
os.environ.setdefault('https_proxy', "127.0.0.1:1081")

LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone", "None"]

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
id2label = {i: label for i, label in enumerate(LABEL_COLUMNS)}
label2id = {label: i for i, label in enumerate(LABEL_COLUMNS)}

# Convert every checkpoint under checkpoints/ into a models/<name>/ folder
# holding the weights, the config and the vocabulary.
for ckpt in glob.glob('checkpoints/*.ckpt'):
    base_name = os.path.basename(ckpt)
    # Strip the file extension
    model_name = os.path.splitext(base_name)[0]
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    params = torch.load(ckpt, map_location="cpu")['state_dict']
    # strict=True makes a key mismatch raise instead of loading partially.
    model.load_state_dict(params, strict=True)
    path = f'models/{model_name}'
    os.makedirs(path, exist_ok=True)

    torch.save(model.state_dict(), f'{path}/pytorch_model.bin')
    config = model.config
    config.architectures = ['BertForSequenceClassification']
    config.label2id = label2id
    config.id2label = id2label
    config.to_json_file(f'{path}/config.json')
    tokenizer.save_vocabulary(path)
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
tqdm
|
4 |
+
pandas
|
5 |
+
datetime
|