lalalalalalalalalala commited on
Commit
ee6d0d7
·
verified ·
1 Parent(s): 962d907

Update run.py

Browse files

Support loading videos from a Hugging Face dataset

Files changed (1) hide show
  1. run.py +38 -11
run.py CHANGED
@@ -2,20 +2,47 @@
2
  import gradio as gr
3
  from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI
4
  from constraint import SYS_PROMPT, USER_PROMPT
 
5
 
6
- def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video, frame_format, frame_limit):
7
- processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
8
- frames = processor._decode(video)
9
 
10
- base64_list = processor.to_base64_list(frames)
11
- debug_image = processor.concatenate(frames)
 
12
 
13
- if not key or not endpoint:
14
- return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
 
 
 
 
 
 
15
 
16
- api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
17
- caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
18
- return f"{caption}", f"Using model '{model}' with {len(frames)} frames extracted.", debug_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  with gr.Blocks() as Core:
21
  with gr.Row(variant="panel"):
@@ -82,7 +109,7 @@ with gr.Blocks() as Core:
82
  caption_button = gr.Button("Caption", variant="primary", size="lg")
83
  caption_button.click(
84
  fast_caption,
85
- inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, frame_format, frame_limit],
86
  outputs=[result, info, frame]
87
  )
88
 
 
2
  import gradio as gr
3
  from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI
4
  from constraint import SYS_PROMPT, USER_PROMPT
5
+ from datasets import load_dataset
6
 
7
def load_hf_dataset(dataset_path, auth_token):
    """Load a dataset from the Hugging Face Hub.

    Parameters:
        dataset_path: Hub repo id (e.g. ``"user/dataset"``) or any
            identifier accepted by ``datasets.load_dataset``.
        auth_token: Hugging Face access token; required for private or
            gated datasets.

    Returns:
        The loaded dataset object.
        NOTE(review): the caller iterates this return value and calls
        ``.endswith('.mp4')`` on each item, which assumes iteration yields
        path strings. ``load_dataset`` without a ``split=`` argument
        returns a ``DatasetDict`` whose iteration yields split *names* —
        confirm the expected dataset layout (or pass a split here).
    """
    # `use_auth_token` is deprecated in `datasets` >= 2.14 in favor of
    # `token`; same semantics, supported going forward.
    return load_dataset(dataset_path, token=auth_token)
13
 
14
def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption video(s) with the configured Azure OpenAI deployment.

    Sources are checked in priority order: a directly uploaded video
    (``video_src``), then a Hugging Face dataset (``video_hf`` plus
    ``video_hf_auth``). The OneDrive (``video_od*``) and Google Drive
    (``video_gd*``) parameters are accepted for interface stability but
    are not handled yet.

    Returns:
        A 3-tuple ``(caption, info, debug_image)`` matching the three
        Gradio outputs; ``debug_image`` is only produced for the
        single-video path.
    """

    def _decode(video):
        # Frame extraction is identical for every source: return the
        # processor (still needed for to_base64_list/concatenate) and
        # the decoded frames.
        processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
        return processor, processor._decode(video)

    def _caption_frames(base64_list):
        # One API client per captioned video, matching the original
        # per-video construction.
        api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
        return api.get_caption(sys_prompt, usr_prompt, base64_list)

    if video_src:
        processor, frames = _decode(video_src)
        debug_image = processor.concatenate(frames)
        if not key or not endpoint:
            return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
        caption = _caption_frames(processor.to_base64_list(frames))
        return f"{caption}", f"Using model '{model}' with {len(frames)} frames extracted.", debug_image

    if video_hf and video_hf_auth:
        # Robustness fix: the original HF branch called the API without
        # checking credentials, unlike the single-video branch.
        if not key or not endpoint:
            return "", "API key or endpoint is missing.", None
        video_paths = load_hf_dataset(video_hf, video_hf_auth)
        all_captions = []
        for video_path in video_paths:
            # Assumes iteration yields path strings; only .mp4 files are
            # processed, everything else is skipped silently.
            if video_path.endswith('.mp4'):
                processor, frames = _decode(video_path)
                all_captions.append(_caption_frames(processor.to_base64_list(frames)))
        return "\n".join(all_captions), f"Processed {len(video_paths)} videos.", None

    # TODO: handle OneDrive (video_od/video_od_auth) and Google Drive
    # (video_gd/video_gd_auth) sources.
    return "", "No video source selected.", None
46
 
47
  with gr.Blocks() as Core:
48
  with gr.Row(variant="panel"):
 
109
  caption_button = gr.Button("Caption", variant="primary", size="lg")
110
  caption_button.click(
111
  fast_caption,
112
+ inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
113
  outputs=[result, info, frame]
114
  )
115