Oliver12315 commited on
Commit
10fa1e9
1 Parent(s): 621f0bd

Upload core files

Browse files
Files changed (6) hide show
  1. .gitignore +3 -0
  2. Prediction.py +83 -0
  3. README.md +5 -5
  4. app.py +122 -4
  5. convert.py +30 -0
  6. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /output/*
2
+ .vscode
3
+ __pycache__
Prediction.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from tqdm.auto import tqdm
3
+ import torch
4
+ from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification
5
+ import os
6
+ import glob
7
+
8
+
9
+ RANDOM_SEED = 42
10
+ pd.RANDOM_SEED = 42
11
+ LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone", "None"]
12
+
13
+
14
+ @torch.no_grad()
15
+ def predict_csv(data, text_col, tokenizer, model, device, text_bs=16, max_token_len=128):
16
+ predictions = []
17
+ post = data[text_col]
18
+ num_text = len(post)
19
+ generator = range(0, num_text, text_bs)
20
+ for i in tqdm(generator, total=len(generator), desc="Processing..."):
21
+ texts = post[i: min(num_text, i+text_bs)].tolist()
22
+ encoding = tokenizer(
23
+ texts,
24
+ add_special_tokens=True,
25
+ max_length=max_token_len,
26
+ return_token_type_ids=False,
27
+ padding="max_length",
28
+ truncation=True,
29
+ return_attention_mask=True,
30
+ return_tensors='pt',
31
+ )
32
+ logits = model(
33
+ encoding["input_ids"].to(device),
34
+ encoding["attention_mask"].to(device),
35
+ return_dict=True
36
+ ).logits
37
+ prediction = torch.softmax(logits, dim=1)
38
+ predictions.append(prediction.detach().cpu())
39
+
40
+ final_pred = torch.cat(predictions, dim=0)
41
+ y_inten = final_pred.numpy().T
42
+
43
+ for i in range(len(LABEL_COLUMNS)):
44
+ data[LABEL_COLUMNS[i]] = [round(i, 8) for i in y_inten[i].tolist()]
45
+ return data
46
+
47
+ @torch.no_grad()
48
+ def predict_single(sentence, tokenizer, model, device, max_token_len=128):
49
+ encoding = tokenizer(
50
+ sentence,
51
+ add_special_tokens=True,
52
+ max_length=max_token_len,
53
+ return_token_type_ids=False,
54
+ padding="max_length",
55
+ truncation=True,
56
+ return_attention_mask=True,
57
+ return_tensors='pt',
58
+ )
59
+ logits = model(
60
+ encoding["input_ids"].to(device),
61
+ encoding["attention_mask"].to(device),
62
+ return_dict=True
63
+ ).logits
64
+ prediction = torch.softmax(logits, dim=1)
65
+ y_inten = prediction.flatten().cpu().numpy().T.tolist()
66
+ y_inten = [round(i, 8) for i in y_inten]
67
+ return y_inten
68
+
69
+
70
+
71
+ if __name__ == "__main__":
72
+
73
+ Data = pd.read_csv("assets/Kickstarter_sentence_level_5000.csv")
74
+ Data = Data[:20]
75
+ device = torch.device('cpu')
76
+
77
+ # Load model directly
78
+ tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
79
+ model = BertForSequenceClassification.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
80
+ model = model.to(device)
81
+ fk_doc_result = predict_csv(Data,"content", tokenizer, model, device)
82
+ single_response = predict_single("Games of the imagination teach us actions have consequences in a realm that can be reset.", tokenizer, model, device)
83
+ fk_doc_result.to_csv(f"output/prediction_Brand_Tone_of_Voice.csv")
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Brand Tone Of Voice Online Demo
3
- emoji: 🐠
4
- colorFrom: gray
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: Murphy
3
+ emoji: 📊
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.10.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -1,7 +1,125 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from Prediction import *
5
+ import os
6
+ from datetime import datetime
7
 
 
 
8
 
9
+ examples = []
10
+ if os.path.exists("assets/examples.txt"):
11
+ with open("assets/examples.txt", "r", encoding="utf8") as file:
12
+ for sentence in file:
13
+ sentence = sentence.strip()
14
+ examples.append(sentence)
15
+ else:
16
+ examples = [
17
+ "Games of the imagination teach us actions have consequences in a realm that can be reset.",
18
+ "But New Jersey farmers are retiring and all over the state, development continues to push out dwindling farmland.",
19
+ "He also is the Head Designer of The Design Trust so-to-speak, besides his regular job ..."
20
+ ]
21
+
22
+ device = torch.device('cpu')
23
+ tokenizer = BertTokenizer.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
24
+ model = BertForSequenceClassification.from_pretrained("Oliver12315/Brand_Tone_of_Voice")
25
+ model = model.to(device)
26
+
27
+
28
+ def single_sentence(sentence):
29
+ predictions = predict_single(sentence, tokenizer, model, device)
30
+ predictions.sort(reverse=True)
31
+ return list(zip(LABEL_COLUMNS, predictions))
32
+
33
+ def csv_process(csv_file, attr="content"):
34
+ current_time = datetime.now()
35
+ formatted_time = current_time.strftime("%Y_%m_%d_%H_%M_%S")
36
+ data = pd.read_csv(csv_file.name)
37
+ data = data.reset_index()
38
+ os.makedirs('output', exist_ok=True)
39
+ outputs = []
40
+ predictions = predict_csv(data, attr, tokenizer, model, device)
41
+ output_path = f"output/prediction_Brand_Tone_of_Voice_{formatted_time}.csv"
42
+ predictions.to_csv(output_path)
43
+ outputs.append(output_path)
44
+ return outputs
45
+
46
+
47
+ my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
48
+ with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
49
+ gr.HTML(
50
+ """
51
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
52
+ <a href="https://github.com/xxx" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
53
+ </a>
54
+ <div>
55
+ <h1 >Place the title of the paper here</h1>
56
+ <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
57
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;>
58
+ <a href="https://arxiv.org/abs/xx.xx"><img src="https://img.shields.io/badge/Arxiv-xx.xx-red"></a>
59
+ <a href='https://huggingface.co/spaces/Oliver12315/Brand_Tone_of_Voice_demo'><img src='https://img.shields.io/badge/Project_Page-Oliver12315/Brand_Tone_of_Voice_demo' alt='Project Page'></a>
60
+ <a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
61
+ </div>
62
+ </div>
63
+ </div>
64
+ """)
65
+
66
+ with gr.Tab("Single Sentence"):
67
+ with gr.Row():
68
+ tbox_input = gr.Textbox(label="Input",
69
+ info="Please input a sentence here:")
70
+ gr.Markdown("""
71
+ # Detailed information about our model:
72
+ ...
73
+ """)
74
+ tab_output = gr.DataFrame(label='Predictions:',
75
+ headers=["Label", "Probability"],
76
+ datatype=["str", "number"],
77
+ interactive=False)
78
+ with gr.Row():
79
+ button_ss = gr.Button("Submit", variant="primary")
80
+ button_ss.click(fn=single_sentence, inputs=[tbox_input], outputs=[tab_output])
81
+ gr.ClearButton([tbox_input, tab_output])
82
+
83
+ gr.Examples(
84
+ examples=examples,
85
+ inputs=tbox_input,
86
+ examples_per_page=len(examples)
87
+ )
88
+
89
+ with gr.Tab("Csv File"):
90
+ with gr.Row():
91
+ csv_input = gr.File(label="CSV File:",
92
+ file_types=['.csv'],
93
+ file_count="single"
94
+ )
95
+ csv_output = gr.File(label="Predictions:")
96
+
97
+ with gr.Row():
98
+ button = gr.Button("Submit", variant="primary")
99
+ button.click(fn=csv_process, inputs=[csv_input], outputs=[csv_output])
100
+ gr.ClearButton([csv_input, csv_output])
101
+
102
+ gr.Markdown("## Examples \n The incoming CSV must include the ``content`` field, which represents the text that needs to be predicted!")
103
+ gr.DataFrame(label='Csv input format:',
104
+ value=[[i, examples[i]] for i in range(len(examples))],
105
+ headers=["index", "content"],
106
+ datatype=["number","str"],
107
+ interactive=False
108
+ )
109
+
110
+ with gr.Tab("Readme"):
111
+ gr.Markdown(
112
+ """
113
+ # Paper Name
114
+
115
+ # Authors
116
+
117
+ + First author
118
+ + Corresponding author
119
+
120
+ # Detailed Information
121
+
122
+ ...
123
+ """
124
+ )
125
+ demo.launch()
convert.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import glob
3
+ import os
4
+ from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification
5
+
6
+ os.environ['https_proxy'] = "127.0.0.1:1081"
7
+
8
+ LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone", "None"]
9
+
10
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
11
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
12
+ id2label = {i:label for i,label in enumerate(LABEL_COLUMNS)}
13
+ label2id = {label:i for i,label in enumerate(LABEL_COLUMNS)}
14
+
15
+ for ckpt in glob.glob('checkpoints/*.ckpt'):
16
+ base_name = os.path.basename(ckpt)
17
+ # 去除文件后缀
18
+ model_name = os.path.splitext(base_name)[0]
19
+ params = torch.load(ckpt, map_location="cpu")['state_dict']
20
+ msg = model.load_state_dict(params, strict=True)
21
+ path = f'models/{model_name}'
22
+ os.makedirs(path, exist_ok=True)
23
+
24
+ torch.save(model.state_dict(), f'{path}/pytorch_model.bin')
25
+ config = model.config
26
+ config.architectures = ['BertForSequenceClassification']
27
+ config.label2id = label2id
28
+ config.id2label = id2label
29
+ model.config.to_json_file(f'{path}/config.json')
30
+ tokenizer.save_vocabulary(path)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ tqdm
4
+ pandas
5
+ datetime