lalalalalalalalalala committed on
Commit
d05a27e
·
verified ·
1 Parent(s): 11ec732

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +46 -46
run.py CHANGED
@@ -18,54 +18,54 @@ def load_hf_dataset(dataset_path, auth_token):
18
 
19
def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption a single uploaded video or a Hugging Face parquet shard.

    Returns a 3-tuple ``(result, progress_text, debug_image)`` where
    ``result`` is the caption text (single-video path), the CSV path
    (dataset path), or ``""`` when nothing could be processed;
    ``debug_image`` is a concatenated frame preview or ``None``.
    The ``video_od*`` / ``video_gd*`` parameters are accepted but not
    referenced in this body.
    """
    progress_info = []
    # with tempfile.TemporaryDirectory() as temp_dir:
    # NOTE(review): hard-coded working directory; /opt/run must already exist.
    temp_dir = '/opt/run'
    csv_filename = os.path.join(temp_dir, 'caption.csv')
    print(csv_filename)
    with open(csv_filename, mode='w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=['md5', 'caption'])
        writer.writeheader()

        if video_src:
            # Single-video path: decode frames, then ask the API for one caption.
            frame_proc = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
            frames = frame_proc._decode(video_src)
            base64_list = frame_proc.to_base64_list(frames)
            debug_image = frame_proc.concatenate(frames)
            if not key or not endpoint:
                return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
            azure = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
            result = azure.get_caption(sys_prompt, usr_prompt, base64_list)
            progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
            writer.writerow({'md5': 'single_video', 'caption': result})
            return f"{result}", "\n".join(progress_info), debug_image
        elif video_hf and video_hf_auth:
            # Dataset path: fetch one parquet shard and caption every row.
            progress_info.append('Begin processing Hugging Face dataset.')
            temp_parquet_file = hf_hub_download(
                repo_id=video_hf,
                filename=f"data/{str(parquet_index).zfill(6)}.parquet",
                repo_type="dataset",
                token=video_hf_auth,
            )
            parquet_file = pq.ParquetFile(temp_parquet_file)
            for batch in parquet_file.iter_batches(batch_size=1):
                df = batch.to_pandas()
                video = df['video'][0]
                md5 = hashlib.md5(video).hexdigest()
                # Spill the raw bytes to disk so the decoder can open them by name.
                with tempfile.NamedTemporaryFile(dir=temp_dir) as temp_file:
                    temp_file.write(video)
                    video_path = temp_file.name
                    frame_proc = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
                    frames = frame_proc._decode(video_path)
                    base64_list = frame_proc.to_base64_list(frames)
                    azure = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
                    caption = azure.get_caption(sys_prompt, usr_prompt, base64_list)
                    writer.writerow({'md5': md5, 'caption': caption})
                    progress_info.append(f"Processed video with MD5: {md5}")
            return csv_filename, "\n".join(progress_info), None
        else:
            return "", "No video source selected.", None
69
 
70
  with gr.Blocks() as Core:
71
  with gr.Row(variant="panel"):
 
18
 
19
def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_src, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit):
    """Caption a single uploaded video or a Hugging Face parquet shard.

    Returns a 3-tuple ``(result, progress_text, debug_image)`` where
    ``result`` is the caption text (single-video path), a path to the
    caption CSV (dataset path), or ``""`` when nothing could be processed;
    ``debug_image`` is a concatenated frame preview or ``None``.
    The ``video_od*`` / ``video_gd*`` parameters are accepted but not
    referenced in this body.
    """
    progress_info = []
    with tempfile.TemporaryDirectory() as temp_dir:
        # temp_dir = '/opt/run'
        # BUG FIX: the CSV must live *outside* temp_dir — temp_dir is deleted
        # the moment this function returns, so a path inside it would be
        # dangling by the time the caller reads it. Use a persistent temp file
        # (delete=False) instead.
        csv_handle = tempfile.NamedTemporaryFile(
            mode='w', newline='', prefix='caption_', suffix='.csv', delete=False)
        csv_filename = csv_handle.name
        print(csv_filename)
        with csv_handle as csv_file:
            fieldnames = ['md5', 'caption']
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()

            if video_src:
                video = video_src
                processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
                frames = processor._decode(video)
                base64_list = processor.to_base64_list(frames)
                debug_image = processor.concatenate(frames)
                if not key or not endpoint:
                    return "", f"API key or endpoint is missing. Processed {len(frames)} frames.", debug_image
                api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
                caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
                progress_info.append(f"Using model '{model}' with {len(frames)} frames extracted.")
                writer.writerow({'md5': 'single_video', 'caption': caption})
                return f"{caption}", "\n".join(progress_info), debug_image
            elif video_hf and video_hf_auth:
                # Same credential guard as the single-video branch, and checked
                # *before* the (potentially large) dataset download.
                if not key or not endpoint:
                    return "", "API key or endpoint is missing.", None
                progress_info.append('Begin processing Hugging Face dataset.')
                temp_parquet_file = hf_hub_download(
                    repo_id=video_hf,
                    filename='data/' + str(parquet_index).zfill(6) + '.parquet',
                    repo_type="dataset",
                    token=video_hf_auth,
                )
                # Loop-invariant: build the processor and API client once, not
                # once per row.
                processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
                api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
                parquet_file = pq.ParquetFile(temp_parquet_file)
                for batch in parquet_file.iter_batches(batch_size=1):
                    df = batch.to_pandas()
                    video = df['video'][0]
                    md5 = hashlib.md5(video).hexdigest()
                    with tempfile.NamedTemporaryFile(dir=temp_dir) as temp_file:
                        temp_file.write(video)
                        # BUG FIX: flush so the decoder, which reopens the file
                        # by name, sees all the bytes rather than a partially
                        # buffered file.
                        temp_file.flush()
                        video_path = temp_file.name
                        frames = processor._decode(video_path)
                        base64_list = processor.to_base64_list(frames)
                        caption = api.get_caption(sys_prompt, usr_prompt, base64_list)
                        writer.writerow({'md5': md5, 'caption': caption})
                        progress_info.append(f"Processed video with MD5: {md5}")
                return csv_filename, "\n".join(progress_info), None
            else:
                return "", "No video source selected.", None
69
 
70
  with gr.Blocks() as Core:
71
  with gr.Row(variant="panel"):