import os
import asyncio
import base64
from io import BytesIO

import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
# NOTE: this import path assumes moviepy >= 2.0; on moviepy 1.x it would be
# `from moviepy.editor import VideoFileClip` instead.
from moviepy import VideoFileClip

from TStar.TStarFramework import run_tstar
def img2base64(image_path):
    """Read an image file and return its contents as a base64-encoded string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")
def create_timeline(frame_times, duration):
    """
    Creates a timeline visualization for the sampled frames.
    """
    fig, ax = plt.subplots(figsize=(10, 2))
    ax.set_xlim(0, duration)
    ax.hlines(0.5, 0, duration, colors="gray", linestyles="dotted")
    ax.plot(frame_times, [0.5] * len(frame_times), 'ro')
    ax.set_xlabel("Time (s)")
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)
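
# A quick usage sketch (hypothetical values): create_timeline([3.2, 17.8, 42.0], duration=60.0)
# renders a 60-second horizontal axis with a red marker at each sampled timestamp.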
def analyze_and_sample_frames(
    video_file,
    question,
    openai_api_key,
    num_frames=8,
    batch=1,
    total_batches=3
):
    """
    Calls the backend run_tstar function to search the video for keyframes and
    produces everything the frontend needs for visualization:
    - metadata: backend results plus the question, answer, and related info
    - frames: list of PIL images for display in the Gradio Gallery
    - frame_times: list of keyframe timestamps (seconds)
    - timeline_image: timeline image annotated with the keyframes
    """
    if not os.path.exists(video_file):
        print("video_file does not exist:", video_file)
        return None, None, None, None
    if not question:
        question = "No question provided"
    options = "Freeform Question"

    # The search parameters could also vary with batch / total_batches
    # (e.g. a larger search_budget for later batches); for this demo they
    # are simply kept fixed.
    results = run_tstar(
        video_path=video_file,
        question=question,
        options=options,
        grounder="gpt-4o",
        heuristic="owl-vit",
        device="cuda:0",
        search_nframes=num_frames,
        grid_rows=4,
        grid_cols=4,
        confidence_threshold=0.6,
        search_budget=0.5,
        output_dir='./output',
        openai_api_key=openai_api_key
    )
    # Extract the key fields from the backend results.
    frame_times = results.get("Frame Timestamps", [])
    answer = results.get("Answer", "No answer")
    grounding_objects = results.get("Grounding Objects", [])

    # Grab the keyframe images.
    frames = []
    clip = VideoFileClip(video_file)
    video_duration = clip.duration
    for t in frame_times:
        # Clamp timestamps so they never exceed the video length.
        t = min(t, video_duration)
        frame_img = clip.get_frame(t)  # frame at second t, as an (H, W, 3) numpy array
        frame_pil = Image.fromarray(frame_img.astype(np.uint8))
        frames.append(frame_pil)
    clip.close()

    # Build the timeline image.
    timeline_image = create_timeline(frame_times, duration=video_duration)

    # Assemble the metadata (fields can be added or removed as needed).
    metadata = {
        "batch": batch,
        "total_batches": total_batches,
        "question": question,
        "answer": answer,
        "grounding_objects": grounding_objects,
        "frame_times": frame_times
    }
    return metadata, frames, frame_times, timeline_image
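
# The parsing above assumes run_tstar returns a dict keyed as sketched below;
# the values shown are purely illustrative, not actual TStar output:
#   {
#       "Answer": "B. On top of the refrigerator",
#       "Grounding Objects": ["microwave", "refrigerator"],
#       "Frame Timestamps": [12.4, 31.0, 58.7],
#   }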
def switch_batch(state_batches, selected_batch):
    """
    Switches the display to the selected batch.
    """
    if not selected_batch or selected_batch == "":
        return None, None, None, None
    batch_index = int(selected_batch.split()[-1]) - 1
    timeline_image, frames, metadata = state_batches[batch_index]
    return (
        gr.update(value=timeline_image, visible=True),
        gr.update(value=frames, visible=True),
        gr.update(value=metadata, visible=True),
        selected_batch,
    )
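
# Each entry in state_batches is the (timeline_image, frames, metadata) tuple
# appended by process_video_iteratively_with_state below, so the unpacking
# order here must stay in sync with that function.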
async def process_video_iteratively_with_state(
    video_file,
    question_input,
    openai_api_key_input,
    state_batches,
    current_display_batch,
    total_batches=1,
    num_frames=8,
):
    """
    Processes the video batch by batch, yielding UI updates after each batch.
    """
    if not video_file:
        yield None, None, None, "No video uploaded!", None, state_batches, current_display_batch
        return
    metadata = None
    for batch in range(1, total_batches + 1):
        metadata, frames, frame_times, timeline_image = analyze_and_sample_frames(
            video_file,
            question=question_input,
            openai_api_key=openai_api_key_input,
            num_frames=num_frames,
            batch=batch,
            total_batches=total_batches,
        )
        if metadata is None:
            continue
        state_batches.append((timeline_image, frames, metadata))
        batch_choices = [f"Batch {i + 1}" for i in range(len(state_batches))]
        if current_display_batch is None or current_display_batch == f"Batch {batch - 1}":
            current_display_batch = f"Batch {batch}"
        yield (
            gr.update(value=timeline_image, visible=True),
            gr.update(value=frames, visible=True),
            gr.update(value=metadata, visible=True),
            f"Processing Batch: {batch} / Total Batches: {total_batches}",
            gr.update(choices=batch_choices, value=f"Batch {batch}", visible=True),
            state_batches,
            current_display_batch,
        )
        await asyncio.sleep(0.5)
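
# Because this handler is an async generator, Gradio streams each yielded
# tuple to the outputs as it arrives; the short sleep simply gives the
# frontend time to render one batch before the next search begins.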
def generate_header(base64_logo, title="T⭐ - Efficient Long Video QA Tool"):
    """
    Generates the header section for the app.
    """
    return f"""
    <h1 style="text-align: center; font-size: 3em; color: #4CAF50; font-family: 'Open Sans', sans-serif; margin-bottom: 20px;">{title}</h1>
    <div style="display: flex; justify-content: center; align-items: center; height: 333px;">
        <img src="data:image/png;base64,{base64_logo}" alt="Logo" style="width: auto; height: 300px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
    </div>
    <div style="display: flex; justify-content: center; align-items: center; margin-top: 20px;">
        <h2 style="text-align: center; font-size: 2em; color: #333; margin-bottom: 30px;">📖 How to Use?</h2>
    </div>
    """
def generate_instruction(step, title, description):
    """
    Generates a single instruction card.
    """
    return f"""
    <div style="background-color: #F9F9F9; padding: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); height: 150px; display: flex; flex-direction: column; justify-content: flex-start;">
        <h3 style="font-size: 1.5em; color: #4CAF50; font-family: 'Open Sans', sans-serif; margin-bottom: 10px;">Step {step}: {title}</h3>
        <p style="font-size: 1em; color: #666; line-height: 1.5; margin: 0;">
            {description}
        </p>
    </div>
    """
def create_ui_components(default_video_path):
    """
    Creates the UI components for the Gradio application.
    """
    # Layout in two columns. scale must be an integer in recent Gradio
    # releases, so the original 1 : 1.5 ratio is expressed here as 2 : 3.
    with gr.Row(equal_height=True, elem_id="video-container"):
        with gr.Column(scale=2, min_width=300):
            # Left Column: Video Upload
            gr.Markdown("""
                <br>
                <h2 style="color: #333;">Upload Your Video</h2>
                <p style="color: #666; font-size: 0.9em;">You can upload a sample video or provide your own video for analysis.</p>
            """)
            video_input = gr.File(
                label="Select Video",
                type="filepath",
                value=default_video_path,
                interactive=True
            )
        with gr.Column(scale=3, min_width=400):
            # Right Column: Video Preview
            gr.Markdown("""
                <br>
                <h2 style="color: #333;">Video Preview</h2>
                <p style="color: #666; font-size: 0.9em;">View your video here before starting the analysis.</p>
            """)
            video_preview = gr.Video(
                label="Preview",
                value=default_video_path,
                visible=True,
                autoplay=True,
                loop=True,
            )

    # Text inputs for the OpenAI API key and the question
    openai_api_key_input = gr.Textbox(
        label="Provide your OpenAI API Key",
        placeholder="sk-...",
        value="",
        type="password",
        elem_id="openai-api-key-input",
    )
    question_input = gr.Textbox(
        label="Ask a Question",
        placeholder="",
        value="Where's the microwave? A. Under the cabinet B. On top of the refrigerator C. Next to the stove D. Beside the sink E. In the pantry",
        type="text",
        elem_id="question-input",
    )
    submit_button = gr.Button(
        "Analyze!",
        elem_id="analyze-button",
    )

    # State and output components for the batch results
    state_batches = gr.State([])  # Stores all generated batch data
    current_display_batch = gr.State(None)  # Tracks the currently displayed batch
    output_timeline = gr.Image(label="Video Timeline", type="pil", visible=False)
    output_frames = gr.Gallery(label="Sampled Frames", columns=8, visible=False, height=200)
    batch_status = gr.Text(label="Batch Status", value="No Batch Processed Yet", visible=True)
    batch_selector = gr.Dropdown(choices=[], label="Select Batch", visible=False)
    output_metadata = gr.JSON(label="Video Metadata", visible=False)

    return (
        openai_api_key_input,
        video_input,
        question_input,
        submit_button,
        video_preview,
        state_batches,
        current_display_batch,
        output_timeline,
        output_frames,
        batch_status,
        batch_selector,
        output_metadata,
    )
def update_video_preview(video_file, default_video_path):
    # gr.File with type="filepath" passes a plain string path, not an object
    # with a .name attribute.
    return gr.update(
        value=(video_file if video_file else default_video_path),
        visible=True,
        autoplay=True,
        loop=True,
    )
if __name__ == "__main__":
    # Default sample video and logo paths
    sample_video_path = "data/sample.mp4"
    logo_path = "data/logo.png"
    base64_logo = img2base64(logo_path)

    with gr.Blocks() as demo:
        # Add header
        gr.Markdown(generate_header(base64_logo))

        # Add instructions
        steps = [
            ("Upload", "Sample video is provided. You can also upload your own!<br>Click <strong>Video Preview</strong> to preview it."),
            ("Analyze", "Ask a question and click <strong>'Analyze'</strong>.<br>The system will track keyframes to answer your question."),
            ("Visualize", "View keyframes with their sample distribution.<br>Explore keyframe tracking dynamics visually!"),
        ]
        with gr.Row(equal_height=True, elem_id="instructions-container"):
            for i, (title, description) in enumerate(steps, start=1):
                with gr.Column(scale=1, min_width=100):
                    gr.Markdown(generate_instruction(i, title, description))

        (
            openai_api_key_input,
            video_input,
            question_input,
            submit_button,
            video_preview,
            state_batches,
            current_display_batch,
            output_timeline,
            output_frames,
            batch_status,
            batch_selector,
            output_metadata,
        ) = create_ui_components(sample_video_path)

        video_input.change(
            fn=update_video_preview,
            inputs=[video_input, gr.State(sample_video_path)],
            outputs=video_preview,
        )
        submit_button.click(
            fn=process_video_iteratively_with_state,
            inputs=[video_input, question_input, openai_api_key_input, state_batches, current_display_batch],
            outputs=[
                output_timeline,
                output_frames,
                output_metadata,
                batch_status,
                batch_selector,
                state_batches,
                current_display_batch,
            ],
        )
        batch_selector.change(
            fn=switch_batch,
            inputs=[state_batches, batch_selector],
            outputs=[output_timeline, output_frames, output_metadata, current_display_batch],
        )

    # Launch the Gradio application
    demo.launch(share=True, server_name="0.0.0.0", server_port=8088)
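
# Assumed runtime dependencies, inferred from the imports above (versions not
# pinned here): gradio, moviepy>=2.0, matplotlib, numpy, Pillow, and the TStar
# package that provides run_tstar.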