# NOTE: this file was recovered from a Hugging Face Space page scrape
# (commit 0685af6); the page chrome ("Spaces:", runtime status, file size,
# line-number gutter) has been converted to this comment so the file parses.
import json
from datetime import datetime
from pathlib import Path

import chromadb
from youtube_transcript_api import YouTubeTranscriptApi

from utils.embedding_utils import MyEmbeddingFunction
from utils.general_utils import timeit
@timeit
def run_etl(json_path="data/videos.json", db=None, batch_size=None, overlap=None):
    """Fetch transcripts for every video listed in *json_path*, chunk them
    into segments, and optionally load the segments into a Chroma database.

    Args:
        json_path: JSON file holding a list of {"id", "title"} records.
        db: Path to the Chroma database. When falsy, the DB load is skipped
            and the formatted segments are printed instead.
        batch_size: Transcript entries per segment (see format_transcript).
        overlap: Entries shared between consecutive segments.
    """
    with open(json_path) as f:
        catalog = json.load(f)

    segments = []
    for record in catalog:
        vid = record["id"]
        title = record["title"]
        transcript = get_video_transcript(vid)
        print(f"Transcript for video {vid} fetched.")
        # A failed fetch returns None; skip those videos entirely.
        if transcript:
            segments.extend(
                format_transcript(
                    transcript, vid, title,
                    batch_size=batch_size, overlap=overlap,
                )
            )

    if db:
        initialize_db(db)
        load_data_to_db(db, segments)
        log_data_load(json_path, db, batch_size, overlap)
    else:
        print("No database specified. Skipping database load.")
        print(segments)
@timeit
def get_video_transcript(video_id):
    """Return the English transcript for *video_id*, or None on any failure.

    Uses youtube_transcript_api, preferring 'en' then 'en-US' tracks.
    Errors are reported to stdout rather than raised, so callers must
    handle a None result.
    """
    try:
        return YouTubeTranscriptApi.get_transcript(video_id, languages=['en', 'en-US'])
    except Exception as exc:
        print(f"Error fetching transcript for video {video_id}: {str(exc)}")
        return None
def format_transcript(transcript, video_id, video_title, batch_size=None, overlap=None):
    """Group raw transcript entries into (optionally overlapping) segments.

    Args:
        transcript: List of dicts with at least "text" and "start" keys
            (the shape returned by YouTubeTranscriptApi).
        video_id: YouTube video id; used in segment ids and source URLs.
        video_title: Title stored in each segment's metadata.
        batch_size: Transcript entries per segment. Defaults to 1 (one
            segment per entry); a falsy value also forces overlap to 0.
        overlap: Entries shared between consecutive segments. Defaults to 0.

    Returns:
        List of {"text": str, "metadata": dict} segments, where metadata
        carries video_id, a unique segment_id, title, and a timestamped
        YouTube source URL.

    Raises:
        ValueError: If overlap >= batch_size (the window would never advance).
    """
    if not batch_size:
        batch_size = 1
        overlap = 0  # overlap is meaningless without an explicit batch size
    # Fix: the original computed `batch_size - overlap` with overlap=None
    # (TypeError) whenever batch_size was given but overlap was not.
    overlap = overlap or 0

    step = batch_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than batch_size")

    base_url = f"https://www.youtube.com/watch?v={video_id}"
    query_params = "&t={start}s"

    formatted_data = []
    for i in range(0, len(transcript), step):
        batch = list(transcript[i:i + batch_size])
        start_time = batch[0]["start"]
        text = " ".join(entry["text"] for entry in batch)
        url = base_url + query_params.format(start=start_time)
        metadata = {
            "video_id": video_id,
            "segment_id": video_id + "__" + str(i),
            "title": video_title,
            "source": url,
        }
        formatted_data.append({"text": text, "metadata": metadata})
    return formatted_data
# Module-level embedding function shared by the Chroma collection below;
# MyEmbeddingFunction is project-local (utils.embedding_utils) — presumably
# a chromadb-compatible EmbeddingFunction; confirm against that module.
embed_text = MyEmbeddingFunction()
def initialize_db(db_path, distance_metric="cosine"):
    """Ensure the 'huberman_videos' collection exists in the persistent
    Chroma database at *db_path*.

    Args:
        db_path: Filesystem path for chromadb.PersistentClient.
        distance_metric: HNSW space passed via collection metadata
            ("cosine", "l2", or "ip" per chromadb docs).
    """
    client = chromadb.PersistentClient(path=db_path)
    # Fix: create_collection raises if the collection already exists, so a
    # second ETL run crashed here; get_or_create_collection is idempotent.
    client.get_or_create_collection(
        name="huberman_videos",
        embedding_function=embed_text,
        metadata={"hnsw:space": distance_metric}
    )
    print(f"Database created at {db_path}")
def load_data_to_db(db_path, data):
    """Insert formatted transcript segments into the 'huberman_videos'
    collection of the Chroma database at *db_path*.

    Args:
        db_path: Filesystem path of the persistent Chroma client.
        data: List of {"text", "metadata"} segments as produced by
            format_transcript; each metadata dict must carry a unique
            "segment_id" used as the document id.
    """
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_collection("huberman_videos")

    documents = []
    metadata = []
    ids = []
    for segment in data:
        documents.append(segment['text'])
        metadata.append(segment['metadata'])
        ids.append(segment['metadata']['segment_id'])

    collection.add(
        documents=documents,
        metadatas=metadata,
        ids=ids
    )
    print(f"Data loaded to database at {db_path}.")
def log_data_load(json_path, db_path, batch_size, overlap):
    """Write a JSON log recording the parameters of a completed data load.

    The log lands at data/logs/<db name>_load_log.json and is overwritten
    on each run, so only the most recent load is recorded.

    Args:
        json_path: Path of the video-info JSON used for the load.
        db_path: Path of the database that was loaded.
        batch_size: Segment batch size used (may be None).
        overlap: Segment overlap used (may be None).
    """
    log_json = json.dumps({
        "videos_info_path": json_path,
        "db_path": db_path,
        "batch_size": batch_size,
        "overlap": overlap,
        "load_time": str(datetime.now())
    })
    # pathlib instead of manual split("/")/split("."); .stem drops both the
    # directory and the final extension.
    db_name = Path(db_path).stem
    log_path = Path("data/logs") / f"{db_name}_load_log.json"
    # Fix: the original crashed with FileNotFoundError when data/logs/
    # did not exist yet.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(log_json)