wldmr committed on
Commit
42a2568
·
1 Parent(s): 9b150b6
Files changed (6) hide show
  1. app.py +94 -4
  2. myrpunct/__init__.py +2 -0
  3. myrpunct/punctuate.py +174 -0
  4. myrpunct/utils.py +34 -0
  5. requirements.txt +5 -0
  6. sample.srt +20 -0
app.py CHANGED
@@ -1,7 +1,97 @@
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from myrpunct import RestorePuncts
2
+ from youtube_transcript_api import YouTubeTranscriptApi
3
  import gradio as gr
4
+ import re
5
 
6
def get_srt(input_link):
    """
    Fetch the transcript of a YouTube video as plain text, one caption per line.

    Args:
        - input_link (str): YouTube URL that carries the video id in a `v=` query
          parameter, e.g. "https://www.youtube.com/watch?v=<id>&t=10s".

    Returns:
        The caption texts joined with newlines, or a message starting with
        "Error:" when the link has no `v=` parameter.
    """
    if "v=" not in input_link:
        return "Error: Invalid Link, it does not have the pattern 'v=' in it."

    # Keep only the id itself: anything after it (further query parameters such as
    # "&t=10s" / "&list=...", or a "#fragment") is not part of the video id.
    video_id = input_link.split("v=")[1]
    video_id = video_id.split("&", 1)[0].split("#", 1)[0].strip()
    print("video_id: ", video_id)

    transcript_raw = YouTubeTranscriptApi.get_transcript(video_id)
    # every transcript entry is a mapping; only its 'text' field is used here
    transcript_text = '\n'.join(entry['text'] for entry in transcript_raw)
    return transcript_text
15
 
16
def predict(input_text, input_file, input_link, input_checkbox):
    """
    Gradio entry point: punctuate the input selected by `input_checkbox`.

    Args:
        - input_text (str): Text typed by the user.
        - input_file: Uploaded file object; only its `.name` (a path) is read.
        - input_link (str): YouTube link whose transcript should be punctuated.
        - input_checkbox (str): Which input to process: "Text", "File" or "Link".

    Returns:
        The punctuated text, or a message starting with "Error:" when the
        selected input is missing or the link is invalid.
    """
    if input_checkbox == "File" and input_file is not None:
        print("Input File ...")
        with open(input_file.name) as file:
            input_file_read = file.read()
        return run_predict(input_file_read)

    # Truthiness instead of len(): an unset field may be None, not only "".
    if input_checkbox == "Text" and input_text:
        print("Input Text ...")
        return run_predict(input_text)

    if input_checkbox == "Link" and input_link:
        print("Input Link ...", input_link)
        input_link_text = get_srt(input_link)
        # get_srt reports failures as a message that *starts* with "Error:".
        # Checking the prefix (not mere containment) keeps transcripts that
        # happen to contain the word "Error" from being returned unpunctuated.
        if input_link_text.startswith("Error:"):
            return input_link_text
        return run_predict(input_link_text)

    return "Error: Please provide either an input text or file and select an option accordingly."
35
+
36
# Shared model instance; created on first use so the model is loaded only once
# instead of on every request.
_RPUNCT_MODEL = None


def _get_rpunct():
    """Return the shared RestorePuncts instance, creating it on first use."""
    global _RPUNCT_MODEL
    if _RPUNCT_MODEL is None:
        _RPUNCT_MODEL = RestorePuncts()
    return _RPUNCT_MODEL


def _restore_line_breaks(original, punctuated):
    """Re-wrap `punctuated` so each line holds as many words as the matching line of `original`."""
    # One list of words per non-empty input line; runs of blank lines collapse
    # into a single break and repeated spaces do not produce empty words.
    lines = [
        [word for word in line.split(' ') if word]
        for line in re.split(r'\s*\n\s*', original.strip())
    ]
    punct_words = [word for word in punctuated.split(' ') if word]

    expected = sum(len(line) for line in lines)
    if expected != len(punct_words):
        return ("AssertError: The length of the transcript and the punctuated file "
                f"should be the same: {expected} != {len(punct_words)}")

    restored = []
    position = 0
    for line in lines:
        restored.append(' '.join(punct_words[position:position + len(line)]))
        position += len(line)
    return '\n'.join(restored)


def run_predict(input_text):
    """
    Punctuate `input_text` and give the result the same line breaks as the input.

    Each line of an srt-style transcript corresponds to a frame of the video, so
    the punctuated text must keep the same number of lines as the input.

    Args:
        - input_text (str): Unpunctuated text, possibly spanning several lines.

    Returns:
        The punctuated text with the original line structure, or a message
        starting with "AssertError:" when the word counts of input and model
        output differ and the line breaks therefore cannot be mapped back.
    """
    output_text = _get_rpunct().punctuate(input_text)
    print("Punctuation finished...")

    # restore the carriage returns
    return _restore_line_breaks(input_text, output_text)
65
+
66
if __name__ == "__main__":
    # Texts shown on the Gradio page.
    app_title = "Rpunct Gradio App"
    app_description = """
<b>Description</b>: <br>
Model restores punctuation and case i.e. of the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words. <br>
<b>Usage</b>: <br>
There are three input types any text, a file that can be uploaded or a YouTube video. <br>
Because all three options can be provided by the user (that is you) at the same time <br>
the user has to decisde which input type has to be processed.
"""
    app_article = "Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation)"

    # One pre-filled example row: text, file, link and the selected input type.
    sample_link = "https://www.youtube.com/watch?v=6MI0f6YjJIk"
    example_rows = [
        ["my name is clara and i live in berkeley california", "sample.srt", sample_link, "Text"],
    ]

    # Input widgets in the order of predict()'s parameters.
    input_type_selector = gr.Radio(["Text", "File", "Link"], type="value", label='Input Type')
    input_components = ["text", "file", "text", input_type_selector]

    interface = gr.Interface(
        fn=predict,
        inputs=input_components,
        outputs=["text"],
        title=app_title,
        description=app_description,
        article=app_article,
        examples=example_rows,
        allow_flagging="never",
    )
    interface.launch()

    # Flagging is disabled above. To store flagged samples in a Hugging Face
    # dataset instead, see https://github.com/gradio-app/gradio/issues/914:
    # the recommended option is the HuggingFaceDatasetSaver() flagging handler.
    # Example Space using it: https://huggingface.co/spaces/abidlabs/crowd-speech
myrpunct/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .punctuate import RestorePuncts
2
+ print("init executed ...")
myrpunct/punctuate.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # 💾⚙️🔮
3
+
4
+ __author__ = "Daulet N."
5
+ __email__ = "[email protected]"
6
+
7
+ import logging
8
+ from langdetect import detect
9
+ from simpletransformers.ner import NERModel, NERArgs
10
+
11
+
12
class RestorePuncts:
    """
    Restores punctuation and capitalisation of English text with a BERT
    token-classification model loaded through simpletransformers.

    Every model label has two characters: the first is the punctuation mark to
    append to the word ('O' for none), the second is 'U' when the word is to be
    capitalised and 'O' otherwise.
    """

    def __init__(self, wrds_per_pred=250, use_cuda=False):
        """
        Args:
            - wrds_per_pred (int): Number of words per slice passed to the model.
            - use_cuda (bool): Forwarded to NERModel.
        """
        self.wrds_per_pred = wrds_per_pred
        # number of words each slice shares with the following one
        self.overlap_wrds = 30
        self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
        self.model_hf = "wldmr/felflare-bert-restore-punctuation"
        self.model_args = NERArgs()
        self.model_args.silent = True
        self.model_args.max_seq_length = 512
        self.model = NERModel("bert", self.model_hf, labels=self.valid_labels, use_cuda=use_cuda, args=self.model_args)
        print("class init ...")
        print("use_multiprocessing: ", self.model_args.use_multiprocessing)

    def status(self):
        """Print a fixed message; a trivial check that the instance is callable."""
        print("function called")

    def punctuate(self, text: str, lang: str = ''):
        """
        Performs punctuation restoration on arbitrarily large text.
        Detects if input is not English, if non-English was detected terminates predictions.
        Override by supplying `lang='en'`

        Args:
            - text (str): Text to punctuate, can be few words to as large as you want.
            - lang (str): Explicit language of input text.

        Returns:
            The punctuated text as a single string.

        Raises:
            Exception: If the given or detected language is not English.
        """
        # Detection is only attempted on inputs longer than 10 characters.
        if not lang and len(text) > 10:
            lang = detect(text)
        # An empty `lang` means detection was skipped for a short input; such
        # input is processed rather than rejected as "non English".
        if lang and lang != 'en':
            raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
            If you are certain the input is English, pass argument lang='en' to this function.
            Punctuate received: {text}""")

        # split up large text into bert digestable chunks
        splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
        # predict slices
        # full_preds_lst contains tuple of labels and logits
        full_preds_lst = [self.predict(i['text']) for i in splits]
        # extract predictions, and discard logits
        preds_lst = [i[0][0] for i in full_preds_lst]
        # join text slices
        combined_preds = self.combine_results(text, preds_lst)
        # create punctuated prediction
        punct_text = self.punctuate_texts(combined_preds)
        return punct_text

    def predict(self, input_slice):
        """
        Passes the unpunctuated text to the model for punctuation.

        Returns:
            The `(predictions, raw_outputs)` pair returned by the model for the
            single-element batch `[input_slice]`.
        """
        predictions, raw_outputs = self.model.predict([input_slice])
        return predictions, raw_outputs

    @staticmethod
    def split_on_toks(text, length, overlap):
        """
        Splits text into predefined slices of overlapping text with indexes (offsets)
        that tie-back to original text.
        This is done to bypass 512 token limit on transformer models by sequentially
        feeding chunks of < 512 toks.
        Example output:
        [{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]

        Args:
            - text (str): Full text; newlines are treated as spaces.
            - length (int): Number of words per slice, not counting the overlap.
            - overlap (int): Number of words of the next slice appended to each slice.
        """
        wrds = text.replace('\n', ' ').split(" ")
        resp = []
        lst_chunk_idx = 0
        i = 0

        while True:
            # words in the chunk and the overlapping portion
            wrds_len = wrds[(length * i):(length * (i + 1))]
            wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
            wrds_split = wrds_len + wrds_ovlp

            # Break loop if no more words
            if not wrds_split:
                break

            wrds_str = " ".join(wrds_split)
            nxt_chunk_start_idx = len(" ".join(wrds_len))
            lst_char_idx = len(" ".join(wrds_split))

            resp_obj = {
                "text": wrds_str,
                "start_idx": lst_chunk_idx,
                "end_idx": lst_char_idx + lst_chunk_idx,
            }

            resp.append(resp_obj)
            # +1 accounts for the space that separates this chunk from the next
            lst_chunk_idx += nxt_chunk_start_idx + 1
            i += 1
        logging.info(f"Sliced transcript into {len(resp)} slices.")
        return resp

    @staticmethod
    def combine_results(full_text: str, text_slices):
        """
        Given a full text and predictions of each slice combines predictions into a single text again.
        Performs validation whether text was combined correctly

        Args:
            - full_text (str): The text that was sliced.
            - text_slices (list): One list per slice; every element is a
              single-entry dict mapping a word to its predicted label.

        Returns:
            A list of `(word, label)` tuples covering the words of `full_text`.
        """
        split_full_text = full_text.replace('\n', ' ').split(" ")
        split_full_text = [i for i in split_full_text if i]
        split_full_text_len = len(split_full_text)
        output_text = []
        index = 0

        # a very short trailing slice is dropped when there is an earlier slice
        if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
            text_slices = text_slices[:-1]

        for _slice in text_slices:
            slice_wrds = len(_slice)
            for ix, wrd in enumerate(_slice):
                if index == split_full_text_len:
                    break

                # For every slice but the last, the final two predictions are
                # skipped; the last slice contributes every matching word.
                if split_full_text[index] == str(list(wrd.keys())[0]) and \
                        ix <= slice_wrds - 3 and text_slices[-1] != _slice:
                    index += 1
                    pred_item_tuple = list(wrd.items())[0]
                    output_text.append(pred_item_tuple)
                elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
                    index += 1
                    pred_item_tuple = list(wrd.items())[0]
                    output_text.append(pred_item_tuple)
        assert [i[0] for i in output_text] == split_full_text
        return output_text

    @staticmethod
    def punctuate_texts(full_pred: list):
        """
        Given a list of Predictions from the model, applies the predictions to text,
        thus punctuating it.

        Args:
            - full_pred (list): `(word, label)` tuples as returned by `combine_results`.

        Returns:
            The punctuated text; an empty string when `full_pred` is empty.
        """
        punct_wrds = []
        for word, label in full_pred:
            # second label character: 'U' -> capitalise the word
            punct_wrd = word.capitalize() if label[-1] == "U" else word
            # first label character: punctuation mark to append, 'O' for none
            if label[0] != "O":
                punct_wrd += label[0]
            punct_wrds.append(punct_wrd)

        punct_resp = " ".join(punct_wrds).strip()
        # Append trailing period if doesnt exist (nothing to append to if empty).
        if punct_resp and punct_resp[-1].isalnum():
            punct_resp += "."
        return punct_resp
165
+
166
+
167
if __name__ == "__main__":
    # Manual smoke test: punctuate the sample text file and print the result.
    restorer = RestorePuncts()
    with open('../tests/sample_text.txt', 'r') as sample_file:
        sample_text = sample_file.read()
    print(restorer.punctuate(sample_text))
myrpunct/utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # 💾⚙️🔮
3
+
4
+ __author__ = "Daulet N."
5
+ __email__ = "[email protected]"
6
+
7
def prepare_unpunct_text(text):
    """
    Given a text, normalizes it to subsequently restore punctuation

    The text is lower-cased, split on spaces, and every word is stripped of
    leading and trailing non-alphanumeric characters; words that end up empty
    are dropped.
    """
    # A line break separates two words, so it becomes a space; removing it
    # outright would glue the last word of a line to the first word of the next.
    formatted_txt = text.replace('\n', ' ').strip().lower()
    punct_strp_txt = (strip_punct(wrd) for wrd in formatted_txt.split(" "))
    return " ".join(wrd for wrd in punct_strp_txt if wrd)


def strip_punct(wrd):
    """
    Given a word, strips non alphanumeric characters that precede and follow it
    """
    if not wrd:
        return wrd

    start, end = 0, len(wrd)
    # move the right edge left past trailing non-alphanumeric characters
    while end > start and not wrd[end - 1].isalnum():
        end -= 1
    # move the left edge right past leading non-alphanumeric characters
    while start < end and not wrd[start].isalnum():
        start += 1
    return wrd[start:end]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ langdetect
4
+ simpletransformers
5
+ youtube_transcript_api
sample.srt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ in 2018 cornell researchers built a
2
+ high-powered detector that in combination
3
+ with an algorithm-driven process called
4
+ ptychography set a world record by tripling
5
+ the resolution of a state-of-the-art electron
6
+ microscope as successful as it was that approach
7
+ had a weakness it only worked with ultrathin
8
+ samples that were a few atoms thick anything
9
+ thicker would cause the electrons to scatter
10
+ in ways that could not be disentangled now a
11
+ team again led by
12
+ david muller
13
+ the samuel beckert professor of engineering
14
+ has bested its own
15
+ record by a factor of two with an electron
16
+ microscope pixel array detector empad that
17
+ incorporates even more sophisticated 3d
18
+ reconstruction algorithms the resolution is so
19
+ fine-tuned the only blurring that remains is
20
+ the thermal jiggling of the atoms themselves