File size: 5,274 Bytes
766eb70
 
 
 
 
 
 
 
 
 
1bf6bb5
 
 
 
 
 
 
 
 
 
 
 
766eb70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bf6bb5
766eb70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
from youtube_transcript_api import YouTubeTranscriptApi
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import re
import os
import torch

# import dotenv
# dotenv.load_dotenv()

# Pick the compute device once at import time; `device` is read further down
# when the model is loaded.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
    import subprocess
    import sys

    # Install flash-attn into the interpreter that is running this script
    # (sys.executable -m pip) rather than whichever `pip` a shell would find.
    # The child environment EXTENDS os.environ: passing a bare dict as `env`
    # would replace the whole environment and drop PATH, HOME and any
    # virtualenv settings. FLASH_ATTENTION_SKIP_CUDA_BUILD is forwarded to the
    # flash-attn build so it does not try to compile CUDA kernels locally.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "flash_attn",
            "--no-build-isolation",
            "--break-system-packages",
        ],
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=False,  # best effort: a failed install must not stop the app
    )
else:
    device = torch.device("cpu")
    print("Using CPU")

# Optional Hugging Face access token, read from the HF_TOKEN environment
# variable. It is None when the variable is not set, so the script still
# starts when no token has been configured.
token = os.environ.get("HF_TOKEN")

# Configure 4-bit NF4 quantization for model loading.
# NOTE: attn_implementation is a model-loading option, not a quantization
# option, so it is passed to from_pretrained() below rather than here.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

import importlib.util

# Only request FlashAttention-2 when it can actually be used: a CUDA device,
# the flash_attn package importable, and a GPU of compute capability 8.0+
# (the architectures the FlashAttention-2 kernels are built for).
use_flash_attention = (
    device.type == "cuda"
    and importlib.util.find_spec("flash_attn") is not None
    and torch.cuda.get_device_capability(device)[0] >= 8
)

# Load the Phi-3 model and tokenizer
print("Loading model and tokenizer...")
model_id = "microsoft/Phi-3-mini-128k-instruct"

model_kwargs = {
    "quantization_config": bnb_config,
    "trust_remote_code": True,
    # Place the weights while loading. Calling .to(device) afterwards on a
    # bitsandbytes-quantized model can be rejected by transformers, depending
    # on the installed transformers/bitsandbytes versions.
    "device_map": device.type,
}
if use_flash_attention:
    model_kwargs["attn_implementation"] = "flash_attention_2"

model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# System instruction that opens every conversation sent to the model.
system_prompt = "Summarize this YouTube video. Give a brief summary of the video content with the key points and main takeaways."
messages = [dict(role="system", content=system_prompt)]

# Text-generation pipeline built from the model and tokenizer loaded above.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Keyword arguments forwarded to the pipeline on each call: return only the
# newly generated text, sampled at a low temperature.
generation_args = dict(
    max_new_tokens=32767,
    return_full_text=False,
    do_sample=True,
    temperature=0.2,
)

# Function to extract the video ID from a YouTube URL
def extract_video_id(url):
    """Return the 11-character YouTube video ID found in *url*, or None.

    The ID is taken from the first place in the string where ``v=`` or ``/``
    is immediately followed by 11 characters from ``[0-9A-Za-z_-]``. This
    covers ``watch?v=<id>``, ``youtu.be/<id>`` and ``/embed/<id>`` URLs.

    Args:
        url: Text to search; may be empty or None.

    Returns:
        The matched ID string, or None when *url* is not a string or holds
        no such sequence.
    """
    if not isinstance(url, str):
        # re.search would raise TypeError on None or other non-string values.
        return None
    match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11})", url)
    if match is None:
        return None
    video_id = match.group(1)
    print(f"Extracted video ID: {video_id}")
    return video_id

# Function to get the transcript of a YouTube video
def get_transcript(video_id):
    """Fetch the transcript for *video_id* and return it as one string.

    The ``'text'`` field of every transcript entry is collected, the list is
    printed, and the pieces are joined with single spaces. Any exception
    raised along the way is converted into a string that starts with
    ``"Error fetching transcript: "`` and is returned instead of raised.
    """
    try:
        entries = YouTubeTranscriptApi.get_transcript(video_id)
        pieces = []
        for entry in entries:
            pieces.append(entry['text'])
        print(f"Transcript: {pieces}")
        joined = " ".join(pieces)
        return joined
    except Exception as err:
        failure = f"Error fetching transcript: {str(err)}"
        return failure

# Function to summarize the text using the model
def summarize_text(text):
    """Generate a summary of *text* with the module-level chat pipeline.

    Each call sends a fresh conversation: the shared system message(s) from
    ``messages`` followed by *text* as the single user turn. The module-level
    ``messages`` list is not mutated, so transcripts from earlier requests
    are never re-sent to the model or mixed into a later request's summary.

    Args:
        text: Transcript text to summarize.

    Returns:
        The generated text with surrounding whitespace stripped.
    """
    # Build a per-call copy instead of appending to the shared list.
    conversation = [*messages, {"role": "user", "content": text}]
    result = pipe(conversation, **generation_args)
    summary = result[0]['generated_text'].strip()  # type: ignore
    print(f"Summary: {summary}")
    return summary

# Main function to process the video URL
def process_video(url):
    """Summarize the YouTube video at *url*.

    Extracts the video ID, fetches the transcript, and asks the model for a
    summary.

    Args:
        url: YouTube URL entered by the user.

    Returns:
        A ``(summary, transcript)`` tuple. On failure the first element holds
        the error message and the second is an empty string, so every code
        path returns two values. The callback is wired to two output
        components, which a lone string cannot fill.
    """
    video_id = extract_video_id(url)
    if not video_id:
        print("Invalid YouTube URL")
        return "Invalid YouTube URL", ""

    transcript = get_transcript(video_id)
    # get_transcript reports failures as a string with this exact prefix;
    # matching the full prefix avoids misreading a real transcript that merely
    # begins with the word "Error".
    if transcript.startswith("Error fetching transcript:"):
        return transcript, ""

    summary = summarize_text(transcript)
    return summary, transcript

# Function to update the embedded video player
def update_embed(url):
    """Return the HTML snippet for the embedded player matching *url*.

    When a video ID can be extracted from *url*, the iframe points at that
    video's embed address; otherwise the iframe ``src`` is left empty.
    """
    video_id = extract_video_id(url)
    embed_url = f"https://www.youtube.com/embed/{video_id}" if video_id else ""
    return f"<div class='gradio-embed-container'><iframe class='gradio-embed' src='{embed_url}' frameborder='0' allowfullscreen></iframe></div>"

# Gradio UI setup
# Build the page: the custom CSS keeps the embedded player at a 16:9 aspect
# ratio (padding-bottom: 56.25%) and shrinks the text in the result boxes.
with gr.Blocks(css="""
    .gradio-embed-container { position: relative; width: 100%; padding-bottom: 56.25%; height: 0; }
    .gradio-embed { position: absolute; top: 0; left: 0; width: 100%; height: 100%; }
    .small-font { font-size: 0.6em; }
    """) as demo:
    
    # Page heading and short description.
    gr.Markdown("""
    # YouTube Video Summarizer using Phi-3-mini-128k-instruct
    Summarize any YouTube video using the Phi-3-mini-128k-instruct model.
    """)
    
    with gr.Row():
        # Left column: URL input, the two result boxes, and the trigger button.
        with gr.Column(scale=1):
            url = gr.Textbox(
                label="YouTube URL", 
                placeholder="https://www.youtube.com/watch?v=dQw4w9WgXcQ", 
                max_lines=1
            )
            summary = gr.Textbox(
                label="Summary", 
                placeholder="Summary will appear here...", 
                lines=10, 
                show_label=True, 
                show_copy_button=True, 
                elem_classes="small-font"
            )
            transcript = gr.Textbox(
                label="Transcript", 
                placeholder="Transcript will appear here...", 
                lines=1, 
                show_label=True, 
                show_copy_button=True, 
                elem_classes="small-font"
            )
            btn = gr.Button("Summarize")
            # process_video's return value is routed to the two boxes above,
            # in the order [summary, transcript].
            btn.click(fn=process_video, inputs=url, outputs=[summary, transcript])
        
        # Right column: embedded player, initially an iframe with an empty src.
        with gr.Column(scale=1):
            video_embed = gr.HTML("<div class='gradio-embed-container'><iframe class='gradio-embed' src='' frameborder='0' allowfullscreen></iframe></div>")
    
    # Rebuild the embedded player HTML whenever the URL textbox value changes.
    url.change(fn=update_embed, inputs=url, outputs=video_embed)

# Start the Gradio app; runs at import time because there is no __main__ guard.
demo.launch()