# rpunct-gr-app / app.py
# (repository page header residue: "wldmr's picture", "new", "42a2568")
from myrpunct import RestorePuncts
from youtube_transcript_api import YouTubeTranscriptApi
import gradio as gr
import re
def get_srt(input_link):
    """Fetch the transcript of a YouTube video as newline-separated text.

    Args:
        input_link: A string carrying the video id in a ``v=<id>`` parameter,
            e.g. ``https://www.youtube.com/watch?v=6MI0f6YjJIk``.

    Returns:
        The ``text`` field of every transcript entry, joined with newlines,
        or a message starting with ``"Error:"`` when no video id can be
        extracted from ``input_link``.
    """
    if "v=" not in input_link:
        return "Error: Invalid Link, it does not have the pattern 'v=' in it."

    # Everything after "v=" is not necessarily the id: links frequently carry
    # further parameters ("&t=42s", "&list=...") or a "#fragment". Cut them off
    # so that only the id itself is handed to the transcript API.
    video_id = input_link.split("v=")[1].split("&")[0].split("#")[0].strip()
    if not video_id:
        return "Error: Invalid Link, the 'v=' parameter does not contain a video id."

    print("video_id: ", video_id)
    transcript_raw = YouTubeTranscriptApi.get_transcript(video_id)
    return "\n".join(entry["text"] for entry in transcript_raw)
def predict(input_text, input_file, input_link, input_checkbox):
    """Punctuate the input selected by ``input_checkbox``.

    Args:
        input_text: Raw text typed by the user (used when the type is "Text").
        input_file: Uploaded file object exposing a ``name`` path attribute, or
            ``None`` (used when the type is "File").
        input_link: YouTube link (used when the type is "Link").
        input_checkbox: One of "Text", "File" or "Link".

    Returns:
        Whatever ``run_predict`` returns for the selected input, or a message
        starting with ``"Error:"`` when the selected input is missing or the
        link could not be resolved.
    """
    if input_checkbox == "File" and input_file is not None:
        print("Input File ...")
        # Explicit encoding so the result does not depend on the host locale.
        # NOTE(review): assumes uploads are UTF-8 encoded -- confirm.
        with open(input_file.name, encoding="utf-8") as file:
            input_file_read = file.read()
        return run_predict(input_file_read)

    # Truthiness instead of len(): an unset field may arrive as None.
    if input_checkbox == "Text" and input_text:
        print("Input Text ...")
        return run_predict(input_text)

    if input_checkbox == "Link" and input_link:
        print("Input Link ...", input_link)
        input_link_text = get_srt(input_link)
        # Only a leading "Error:" marks a failure. A plain substring test would
        # also reject every transcript that merely contains the word "Error".
        if input_link_text.startswith("Error:"):
            return input_link_text
        return run_predict(input_link_text)

    return "Error: Please provide either an input text or file and select an option accordingly."
def run_predict(input_text, rpunct=None):
    """Restore punctuation in ``input_text`` while keeping its line breaks.

    The punctuated text is split on single spaces and re-wrapped so that every
    output line holds as many words as the matching input line. This keeps the
    number of lines equal to the input (each transcript line corresponds to a
    frame of the video).

    Args:
        input_text: Text to punctuate; lines are separated by newlines.
        rpunct: Optional object with a ``punctuate(text)`` method. When
            omitted, a new ``RestorePuncts()`` is created for this call.
            Passing an existing instance lets callers reuse it across calls.

    Returns:
        The punctuated text with the original line structure, or a message
        starting with ``"AssertError:"`` when the punctuated text does not
        have the same number of words as the input.
    """
    if rpunct is None:
        rpunct = RestorePuncts()
    punctuated = rpunct.punctuate(input_text)
    print("Punctuation finished...")

    # Restore the carriage returns. Record how many space-separated words each
    # input line holds instead of tagging line ends with a marker character:
    # a marker such as '#' collides with text that already contains it ("c#").
    source_lines = re.split(r"\s*\n\s*", input_text.strip())
    words_per_line = [len(line.split(" ")) for line in source_lines]
    punctuated_words = punctuated.split(" ")

    expected_words = sum(words_per_line)
    if expected_words != len(punctuated_words):
        # One string (not a tuple) so the message fits a text output.
        return (
            "AssertError: The length of the transcript and the punctuated file "
            f"should be the same: {expected_words} {len(punctuated_words)}"
        )

    # Deal the punctuated words back out, line by line.
    restored_lines = []
    start = 0
    for count in words_per_line:
        restored_lines.append(" ".join(punctuated_words[start:start + count]))
        start += count
    return "\n".join(restored_lines)
if __name__ == "__main__":
    # Static texts rendered on the Gradio page.
    title = "Rpunct Gradio App"
    article = "Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation)"
    description = """
<b>Description</b>: <br>
Model restores punctuation and case i.e. of the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words. <br>
<b>Usage</b>: <br>
There are three input types any text, a file that can be uploaded or a YouTube video. <br>
Because all three options can be provided by the user (that is you) at the same time <br>
the user has to decisde which input type has to be processed.
"""

    # One example row: a value for each of the four inputs of predict().
    sample_link = "https://www.youtube.com/watch?v=6MI0f6YjJIk"
    examples = [
        [
            "my name is clara and i live in berkeley california",
            "sample.srt",
            sample_link,
            "Text",
        ],
    ]

    # Radio button whose value tells predict() which input to process.
    input_type_selector = gr.Radio(
        ["Text", "File", "Link"],
        type="value",
        label='Input Type',
    )

    interface = gr.Interface(
        fn=predict,
        inputs=["text", "file", "text", input_type_selector],
        outputs=["text"],
        title=title,
        description=description,
        article=article,
        examples=examples,
        allow_flagging="never",
    )
    interface.launch()
# Saving flagged samples to a Hugging Face dataset:
# https://github.com/gradio-app/gradio/issues/914
# The best option here is to use a Hugging Face dataset as the storage for flagged data.
# To do that, check out the HuggingFaceDatasetSaver() flagging handler, which makes this easy.
# Here is an example Space that uses this: https://huggingface.co/spaces/abidlabs/crowd-speech