# hebrew-dentsit / rag_agent.py
from anthropic import Anthropic
from typing import List
import os
from retriever import Retriever
from reranker import Reranker
retriever = Retriever()
reranker = Reranker()


class RAGAgent:
    def __init__(
        self,
        retriever=retriever,
        reranker=reranker,
        anthropic_api_key: str = os.environ["anthropic_api_key"],
        model_name: str = "claude-3-5-sonnet-20241022",
        max_tokens: int = 1024,
        temperature: float = 0.0,
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.client = Anthropic(api_key=anthropic_api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.conversation_summary = ""
        self.messages = []

    def get_context(self, query: str) -> List[str]:
        # Get initial candidates from retriever
        retrieved_docs = self.retriever.search_similar(query)
        # Rerank the candidates
        context = self.reranker.rerank(query, retrieved_docs)
        return context

    def generate_prompt(self, context: List[str], conversation_summary: str = "") -> str:
        context = "\n".join(context)
        summary_context = f"\nืกื™ื›ื•ื ื”ืฉื™ื—ื” ืขื“ ื›ื”:\n{conversation_summary}" if conversation_summary else ""
        # Hebrew system prompt: "You are a dentist who speaks only Hebrew, called 'the first
        # Hebrew electronic dentist'. Answer the patient's question based on the following
        # context, with as many details as possible and correct, well-formed phrasing. Stop
        # when you feel you have said everything, do not make things up, and do not answer
        # in languages other than Hebrew."
        prompt = f"""
ืืชื” ืจื•ืคื ืฉื™ื ื™ื™ื, ื“ื•ื‘ืจ ืขื‘ืจื™ืช ื‘ืœื‘ื“. ืงื•ืจืื™ื ืœืš 'ืจื•ืคื ื”ืฉื™ื ื™ื™ื ื”ืืœืงื˜ืจื•ื ื™ ื”ืขื‘ืจื™ ื”ืจืืฉื•ืŸ'.{summary_context}
ืขื ื” ืœืžื˜ื•ืคืœ ืขืœ ื”ืฉืืœื” ืฉืœื• ืขืœ ืกืžืš ื”ืงื•ื ื˜ืงืกื˜ ื”ื‘ื: {context}.
ื”ื•ืกืฃ ื›ืžื” ืฉื™ื•ืชืจ ืคืจื˜ื™ื, ื•ื“ืื’ ืฉื”ืชื—ื‘ื™ืจ ื™ื”ื™ื” ืชืงื™ืŸ ื•ื™ืคื”.
ืชืขืฆื•ืจ ื›ืฉืืชื” ืžืจื’ื™ืฉ ืฉืžื™ืฆื™ืช ืืช ืขืฆืžืš. ืืœ ืชืžืฆื™ื ื“ื‘ืจื™ื.
ื•ืืœ ืชืขื ื” ื‘ืฉืคื•ืช ืฉื”ืŸ ืœื ืขื‘ืจื™ืช.
"""
        return prompt

    def update_summary(self, question: str, answer: str) -> str:
        """Update the conversation summary with the new interaction"""
        # Hebrew summarization request: given the summary so far and the new question/answer
        # pair, return an updated summary (concise, up to 100 words) that keeps the medical
        # information from the previous summary, emphasizes the new interaction, and drops
        # irrelevant details from earlier summaries.
        summary_prompt = {
            "model": self.model_name,
            "max_tokens": 500,
            "temperature": 0.0,
            "messages": [
                {
                    "role": "user",
                    "content": f"""ืกื›ื ืืช ื”ืฉื™ื—ื” ื‘ืขื‘ืจื™ืช, ื”ื ื” ืกื™ื›ื•ื ื”ืฉื™ื—ื” ืขื“ ื›ื”:
{self.conversation_summary if self.conversation_summary else "ืื™ืŸ ืฉื™ื—ื” ืงื•ื“ืžืช."}
ืื™ื ื˜ืจืืงืฆื™ื” ื—ื“ืฉื”:
ืฉืืœืช ื”ืžื˜ื•ืคืœ: {question}
ืชืฉื•ื‘ืช ื”ืจื•ืคื: {answer}
ืื ื ืกืคืง ืกื™ื›ื•ื ืžืขื•ื“ื›ืŸ ืฉื›ื•ืœืœ ืืช ื”ืžื™ื“ืข ื”ืจืคื•ืื™ ืžื”ืกื™ื›ื•ื ื”ืงื•ื“ื ื‘ื ื•ืกืฃ ืœื“ื’ืฉ ืขืœ ื”ืื™ื ื˜ืจืืงืฆื™ื” ื”ื—ื“ืฉื”. ื”ืกื™ื›ื•ื ืฆืจื™ืš ืœื”ื™ื•ืช ืชืžืฆื™ืชื™, ืขื“ 100 ืžื™ืœื™ื.
ื•ืชืจ ืขืœ ืžื™ื“ืข ืœื ืจืœื•ื•ื ื˜ื™ ืžื”ืกื™ื›ื•ืžื™ื ื”ืงื•ื“ืžื™ื."""
                }
            ]
        }
        try:
            response = self.client.messages.create(**summary_prompt)
            self.conversation_summary = response.content[0].text
            return self.conversation_summary
        except Exception as e:
            print(f"Error updating summary: {e}")
            return self.get_basic_summary()

    def get_basic_summary(self) -> str:
        """Fallback method for basic summary"""
        summary = []
        for i in range(0, len(self.messages), 2):
            if i + 1 < len(self.messages):
                summary.append(f"ืฉืืœืช ื”ืžื˜ื•ืคืœ: {self.messages[i]['content']}")
                summary.append(f"ืชืฉื•ื‘ืช ื”ืจื•ืคื ืฉื™ื ื™ื™ื: {self.messages[i + 1]['content']}\n")
        return "\n".join(summary)

    def get_response(self, question: str) -> str:
        # Get relevant context (the running summary is appended so retrieval also sees earlier topics)
        context = self.get_context(question + " " + self.conversation_summary)
        # Generate prompt with context and current conversation summary
        prompt = self.generate_prompt(context, self.conversation_summary)
        # Get response from Claude
        # The Hebrew instructions are passed as the system prompt; the Messages API expects
        # the conversation itself to start with a user turn
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            system=prompt,
            messages=[
                {"role": "user", "content": question}
            ]
        )
        answer = response.content[0].text
        # Store messages for history
        self.messages.extend([
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ])
        # Update conversation summary
        self.update_summary(question, answer)
        return answer
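

# A minimal usage sketch, assuming Retriever() and Reranker() can be constructed with their
# defaults and that the `anthropic_api_key` environment variable is set; the example question
# below is hypothetical ("Is teeth whitening worthwhile?").
if __name__ == "__main__":
    agent = RAGAgent()
    reply = agent.get_response("ื”ืื ื›ื“ืื™ ืœืขืฉื•ืช ื”ืœื‘ื ืช ืฉื™ื ื™ื™ื?")
    print(reply)
    # The running Hebrew summary is refreshed after every exchange
    print(agent.conversation_summary)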