# video-qa/rag.py
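"""RAG utilities for video question answering: retrieve video segments by timestamp,
text embedding, or image embedding; grade their relevance; generate answers; and
self-reflect against a video summary."""
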
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from PIL import Image
import os
from utils import load_json_file, str2time
from openai import OpenAI
import base64


def get_smallest_timestamp(timestamps):
    assert len(timestamps) > 0
    timestamps_in_ms = [str2time(elem) for elem in timestamps]
    smallest_timestamp_in_ms = timestamps_in_ms[0]
    smallest_timestamp = timestamps[0]
    for i, elem in enumerate(timestamps_in_ms):
        if elem < smallest_timestamp_in_ms:
            smallest_timestamp_in_ms = elem
            smallest_timestamp = timestamps[i]
    return smallest_timestamp


def generate(query, context, relevant_timestamps=None):
    prompt = PromptTemplate(
        input_variables=["question", "context"],
        template="You're a helpful LLM assistant answering questions about a video. The given context consists of video segments relevant to the question; please answer the question. Do not refer to the segments. Context: {context}, question: {question} \nA:",
    )
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    if relevant_timestamps is not None and len(relevant_timestamps) > 0:
        # get smallest timestamp = earliest mention
        smallest_timestamp = get_smallest_timestamp(relevant_timestamps)
        response += f' {smallest_timestamp}'
    return response


def check_relevance(query, relevant_metadatas):
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    timestamps = [frame['start_time'] for frame in relevant_metadatas]
    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
    prompt = PromptTemplate(
        input_variables=["question", "context"],
        template="""
You are a grader assessing the relevance of retrieved video segments to a user question. \n
If a video segment contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
Give a binary 'yes' or 'no' score for each video segment to indicate whether it is relevant to the question. \n
Answer in a single string, separated by commas. For example, if four segments are provided, answer: yes,no,no,yes. \n
Question: {question} Context: {context}\n A:""",
    )
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    relevance_response = response.split(',')
    actual_relevant_context = ""
    relevant_timestamps = []
    for i, relevance_check in enumerate(relevance_response):
        if relevance_check.strip() == 'yes':
            # keep only the segments the grader marked as relevant
            actual_relevant_context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
            relevant_timestamps.append(timestamps[i])
    return actual_relevant_context, relevant_timestamps


def retrieve_segments_from_timestamp(metadatas, timestamps):
    relevant_segments = []
    for timestamp in timestamps:
        time_to_find_ms = str2time(timestamp)
        buffer = 5000  # 5 seconds before and after
        for segment in metadatas:
            start = str2time(segment['start_time'])
            end = str2time(segment['end_time'])
            if start <= time_to_find_ms + buffer and end >= time_to_find_ms - buffer:
                relevant_segments.append(segment)
    return relevant_segments


def check_timestamps(query):
    prompt = PromptTemplate(
        input_variables=["question"],
        template="You're a helpful LLM assistant. You're good at detecting any timestamps provided in a query. Please detect the question and any timestamps in the following question and separate them by commas, such as question,timestamp1,timestamp2 if timestamps are provided, otherwise return just the question. Question: {question} \nA:",
    )
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query)
    timestamps = []
    if len(response.split(',')) > 1:
        query = response.split(',')[0].strip()
        # pad detected timestamps into the start_time format used in the metadata (assumed "HH:MM:SS.ms")
        timestamps = [f"00:{elem.strip()}.00" for elem in response.split(',')[1:]]
    return query, timestamps


def retrieve_by_embedding(index, video_path, query, text_model):
    print(query)
    query_embedding = text_model.encode(query)
    res = index.query(
        vector=query_embedding.tolist(),
        top_k=5,
        filter={"video_path": {"$eq": video_path}},
    )
    metadatas = []
    for match_ in res['matches']:
        result = index.fetch(ids=[match_['id']])
        # Extract the vector data
        vector_data = result.vectors.get(match_['id'], {})
        # Extract metadata
        metadata = vector_data.metadata
        metadatas.append(metadata)
    return metadatas
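

# A possible simplification (sketch only, not used elsewhere in this module): Pinecone's
# query() can return each match's stored metadata directly when include_metadata=True,
# avoiding the per-id fetch loop above. Assumes metadata was attached at upsert time.
def retrieve_by_embedding_single_call(index, video_path, query, text_model):
    query_embedding = text_model.encode(query)
    res = index.query(
        vector=query_embedding.tolist(),
        top_k=5,
        filter={"video_path": {"$eq": video_path}},
        include_metadata=True,
    )
    return [match_['metadata'] for match_ in res['matches']]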


def self_reflection(query, answer, summary):
    prompt = PromptTemplate(input_variables=["summary", "question", "answer"], template="You're a helpful LLM assistant. You're good at determining whether a provided answer is satisfactory for a question about a video. You have access to the video summary as follows: {summary}. Given a pair of question and answer, give the answer a satisfactory score of either yes or no. Question: {question}, Answer: {answer} \nA:")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(summary=summary, question=query, answer=answer)
    return response


def get_full_transcript(metadatas):
    transcripts = [frame['transcript'] for frame in metadatas]
    full_text = ''
    for transcript in transcripts:
        # collapse doubled spaces in the transcript text (assumed normalization)
        text = transcript.strip().replace("  ", " ")
        full_text += f"{text} "
    full_text = full_text.strip()
    return full_text


def summarize_video(metadatas_path: str):
    metadatas = load_json_file(metadatas_path)
    # get full transcript
    transcript = get_full_transcript(metadatas)
    prompt = PromptTemplate(input_variables=["transcript"], template="You're a helpful LLM assistant. Please provide a summary for the video given its full transcript: {transcript} \nA:")
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(transcript=transcript)
    return response


def answer_wrt_timestamp(query, context):
    prompt = PromptTemplate(
        input_variables=["question", "context"],
        template="""
You're a helpful LLM assistant. Given a question and a timestamp, I have retrieved the relevant context as follows. Please answer the question using the information provided in the context. Question: {question}, context: {context} \n
For example: Question="What happens at 4:20?" Caption="a person is standing up" Transcript="I have to go" Appropriate Answer="At 4:20, a person is standing up and saying he has to go."
A:""",
    )
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    chain = LLMChain(llm=llm, prompt=prompt)
    response = chain.run(question=query, context=context)
    return response


def answer_question(index, model_stack, metadatas_path, video_summary: str, video_path: str, query: str, image_input_path: str = None):
    metadatas = load_json_file(metadatas_path)
    if image_input_path is not None:
        return answer_image_question(index, model_stack, metadatas, video_summary, video_path, query, image_input_path)
    # check if timestamp provided
    query, timestamps = check_timestamps(query)
    if len(timestamps) > 0:
        # retrieve by timestamps
        relevant_segments_metadatas = retrieve_segments_from_timestamp(metadatas, timestamps)
        transcripts = [frame['transcript'] for frame in relevant_segments_metadatas]
        captions = [frame['caption'] for frame in relevant_segments_metadatas]
        context = ""
        for i in range(len(transcripts)):
            context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
        return answer_wrt_timestamp(query, context)
    else:
        # retrieve by embedding
        relevant_segments_metadatas = retrieve_by_embedding(index, video_path, query, model_stack[0])
        # check relevance
        actual_relevant_context, relevant_timestamps = check_relevance(query, relevant_segments_metadatas)
        # generate
        answer = generate(query, actual_relevant_context, relevant_timestamps)
        # self-reflection: if the answer is judged unsatisfactory, retry with the video summary added to the context
        reflect = self_reflection(query, answer, video_summary)
        if reflect.strip().lower() == 'no':
            answer = generate(query, f"{actual_relevant_context}\nSummary={video_summary}")
        return answer


def retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path):
    image_query = Image.open(image_query_path)
    _, vision_model, vision_model_processor, _, _ = model_stack
    inputs = vision_model_processor(images=image_query, return_tensors="pt")
    outputs = vision_model(**inputs)
    # pooler_output has shape (1, hidden_dim); take the single row so the index receives a flat vector
    image_query_embeds = outputs.pooler_output[0]
    res = index.query(
        vector=image_query_embeds.tolist(),
        top_k=5,
        filter={"video_path": {"$eq": video_path}},
    )
    metadatas = []
    for match_ in res['matches']:
        result = index.fetch(ids=[match_['id']])
        # Extract the vector data
        vector_data = result.vectors.get(match_['id'], {})
        # Extract metadata
        metadata = vector_data.metadata
        metadatas.append(metadata)
    return metadatas


def answer_image_question(index, model_stack, metadatas, video_summary: str, video_path: str, query: str, image_query_path: str = None):
    # search segment by image
    relevant_segments = retrieve_segments_by_image_embedding(index, video_path, model_stack, image_query_path)
    # generate answer using those segments
    return generate_w_image(query, image_query_path, relevant_segments)


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def generate_w_image(query: str, image_query_path: str, relevant_metadatas):
    base64_image = encode_image(image_query_path)
    transcripts = [frame['transcript'] for frame in relevant_metadatas]
    captions = [frame['caption'] for frame in relevant_metadatas]
    context = ""
    for i in range(len(transcripts)):
        context += f"Segment {i}: transcript={transcripts[i]} caption={captions[i]}\n"
    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "user", "content": [
                {"type": "text", "text": f"Here is some context about the image: {context}"},
                # f-string so the user's question is actually interpolated into the prompt
                {"type": "text", "text": f"You are a helpful LLM assistant. You are good at answering questions about a video given an image. Given the image and the context surrounding the frames most correlated with it, please answer the question. Question: {query}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/png;base64,{base64_image}"
                }},
            ]},
        ],
        temperature=0.0,
        max_tokens=100,
    )
    return response.choices[0].message.content
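

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how these helpers are assumed to be wired together. The index
# name, model choices, and file paths below are assumptions, not part of this module;
# the model_stack layout is only known to be (text_model, vision_model,
# vision_model_processor, _, _), and OPENAI_API_KEY / PINECONE_API_KEY must be set.
if __name__ == "__main__":
    from pinecone import Pinecone
    from sentence_transformers import SentenceTransformer
    from transformers import CLIPImageProcessor, CLIPVisionModel

    pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
    index = pc.Index("video-qa")  # assumed index name

    # assumed encoders; both must produce vectors matching the index's dimension
    text_model = SentenceTransformer("all-MiniLM-L6-v2")
    vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model_stack = (text_model, vision_model, vision_processor, None, None)

    metadatas_path = "metadata/video1.json"  # assumed path to the segment metadata JSON
    video_path = "videos/video1.mp4"         # assumed video_path value stored in the index metadata
    video_summary = summarize_video(metadatas_path)

    print(answer_question(index, model_stack, metadatas_path, video_summary, video_path,
                          "What is this video about?"))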