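"""ETL pipeline: fetch YouTube transcripts for a list of videos, chunk them
into overlapping segments, and load the segments into a Chroma collection."""
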
import json
import os
from datetime import datetime

import chromadb
from youtube_transcript_api import YouTubeTranscriptApi

from utils.general_utils import timeit
from utils.embedding_utils import MyEmbeddingFunction


@timeit
def run_etl(json_path="data/videos.json", db=None, batch_size=None, overlap=None):
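    """Fetch transcripts for every video listed in json_path, chunk them into
    overlapping segments, and, if db is given, load them into Chroma."""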
    with open(json_path) as f:
        video_info = json.load(f)
        
    videos = []
    for video in video_info:
        video_id = video["id"]
        video_title = video["title"]
        transcript = get_video_transcript(video_id)
        if transcript:
            print(f"Transcript for video {video_id} fetched.")
            formatted_transcript = format_transcript(
                transcript, video_id, video_title,
                batch_size=batch_size, overlap=overlap
            )
            videos.extend(formatted_transcript)
    
    if db:
        initialize_db(db)
        load_data_to_db(db, videos)
        log_data_load(json_path, db, batch_size, overlap)
    else:
        print("No database specified. Skipping database load.")
        print(videos)
        

@timeit
def get_video_transcript(video_id):
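    """Return the English transcript for video_id, or None if unavailable."""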
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en', 'en-US'])
        return transcript
    except Exception as e:
        print(f"Error fetching transcript for video {video_id}: {str(e)}")
        return None
    
    
def format_transcript(transcript, video_id, video_title, batch_size=None, overlap=None):
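    """Group consecutive transcript entries into overlapping batches, each
    with its text, a timestamped YouTube URL, and segment metadata."""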
    formatted_data = []
    base_url = f"https://www.youtube.com/watch?v={video_id}"
    query_params = "&t={start}s"
    
    if not batch_size:
        batch_size = 1
    if not overlap:
        overlap = 0

    # Clamp the step so overlap >= batch_size cannot make range() raise
    # (step 0) or produce no segments (negative step).
    step = max(batch_size - overlap, 1)

    for i in range(0, len(transcript), step):
        batch = list(transcript[i:i+batch_size])
        
        start_time = batch[0]["start"]
        
        text = " ".join(entry["text"] for entry in batch)
        
        # YouTube's t parameter expects whole seconds, so truncate the float
        url = base_url + query_params.format(start=int(start_time))
        
        metadata = {
            "video_id": video_id,
            "segment_id": video_id + "__" + str(i),
            "title": video_title,
            "source": url
        }
        
        segment = {"text": text, "metadata": metadata}
        
        formatted_data.append(segment)

    return formatted_data


# Single shared embedding function instance; pass the same one when the
# collection is created and when it is fetched again, since older Chroma
# versions do not persist Python embedding functions.
embed_text = MyEmbeddingFunction()

def initialize_db(db_path, distance_metric="cosine"):
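    """Create (or reuse) the huberman_videos collection in a persistent
    Chroma database at db_path."""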
    client = chromadb.PersistentClient(path=db_path)
    
    # Clear existing data
    # client.reset()
    
    # get_or_create_collection is idempotent: re-running the ETL will not
    # raise if the collection already exists (create_collection would).
    client.get_or_create_collection(
        name="huberman_videos",
        embedding_function=embed_text,
        metadata={"hnsw:space": distance_metric}
    )
    
    print(f"Database initialized at {db_path}")
    

def load_data_to_db(db_path, data):
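    """Add segment texts, metadata, and ids to the huberman_videos collection."""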
    client = chromadb.PersistentClient(path=db_path)
    
    # Pass the embedding function again so documents added here are embedded
    # consistently with the collection, rather than with Chroma's default.
    collection = client.get_collection(
        name="huberman_videos",
        embedding_function=embed_text
    )

    documents = [segment['text'] for segment in data]
    metadata = [segment['metadata'] for segment in data]
    ids = [segment['metadata']['segment_id'] for segment in data]

    collection.add(
        documents=documents,
        metadatas=metadata,
        ids=ids
    )

    print(f"Data loaded to database at {db_path}.")
    
    
def log_data_load(json_path, db_path, batch_size, overlap):
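    """Write a JSON log recording what was loaded into db_path, and when."""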
    log_json = json.dumps({
        "videos_info_path": json_path,
        "db_path": db_path,
        "batch_size": batch_size,
        "overlap": overlap,
        "load_time": str(datetime.now())
    })
    
    # basename/normpath handle both file-style and directory-style db paths,
    # with or without a trailing slash.
    db_name = os.path.basename(os.path.normpath(db_path)).split(".")[0]
    log_path = f"data/logs/{db_name}_load_log.json"

    os.makedirs("data/logs", exist_ok=True)  # ensure the log directory exists
    with open(log_path, "w") as f:
        f.write(log_json)
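

# Minimal usage sketch (assumed entry point; the database path and chunking
# parameters below are illustrative, not part of the original module):
if __name__ == "__main__":
    # Chunk transcripts into 10-entry batches with a 2-entry overlap and
    # load them into a persistent Chroma database.
    run_etl(
        json_path="data/videos.json",
        db="data/huberman_videos_db",
        batch_size=10,
        overlap=2,
    )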