import os
from typing import List

from anthropic import Anthropic

from retriever import Retriever
from reranker import Reranker

retriever = Retriever()
reranker = Reranker()
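# Retriever and Reranker live in separate modules that are not shown here.
# Judging only by how they are called below, the minimal interface this agent
# relies on looks roughly like this (a sketch of the assumed contract, not the
# actual implementations):
#
#     class Retriever:
#         def search_similar(self, query: str) -> List[str]: ...
#
#     class Reranker:
#         def rerank(self, query: str, docs: List[str]) -> List[str]: ...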
class RAGAgent:
    def __init__(
        self,
        retriever=retriever,
        reranker=reranker,
        # Read the key from the conventional uppercase environment variable;
        # .get() avoids a KeyError at import time when the variable is unset
        anthropic_api_key: str = os.environ.get("ANTHROPIC_API_KEY"),
        model_name: str = "claude-3-5-sonnet-20241022",
        max_tokens: int = 1024,
        temperature: float = 0.0,
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.client = Anthropic(api_key=anthropic_api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.conversation_summary = ""  # rolling summary of the dialogue so far
        self.messages = []  # full question/answer history
    def get_context(self, query: str) -> List[str]:
        # Get initial candidates from the retriever
        retrieved_docs = self.retriever.search_similar(query)
        # Rerank the candidates so the most relevant ones come first
        context = self.reranker.rerank(query, retrieved_docs)
        return context
    def generate_prompt(self, context: List[str], conversation_summary: str = "") -> str:
        # Avoid shadowing the `context` parameter with the joined string
        context_text = "\n".join(context)
        # Hebrew: "Summary of the conversation so far:"
        summary_context = f"\nืกืืืื ืืฉืืื ืขื ืื:\n{conversation_summary}" if conversation_summary else ""
        # The system prompt is written in Hebrew. Rough translation: "You are a
        # dietitian and a native Hebrew speaker. You are called 'the online
        # dietitian from the Jerusalem area'. Answer the patient's question
        # based on the following context. Add as much detail as possible, but
        # make sure to use correct, polished language. Stop when you feel you
        # have exhausted the topic. Do not make things up, and do not answer in
        # any language other than Hebrew."
        prompt = f"""
ืืชื ืจืืคื ืฉืื ืืื, ืืืืจ ืขืืจืืช ืืืื. ืงืืจืืื ืื 'ืจืืคื ืืฉืื ืืื ืืืืงืืจืื ื ืืขืืจื ืืจืืฉืื'.{summary_context}
ืขื ื ืืืืืคื ืขื ืืฉืืื ืฉืื ืขื ืกืื ืืงืื ืืงืก ืืื: {context_text}.
ืืืกืฃ ืืื ืฉืืืชืจ ืคืจืืื, ืืืื ืฉืืชืืืืจ ืืืื ืชืงืื ืืืคื.
ืชืขืฆืืจ ืืฉืืชื ืืจืืืฉ ืฉืืืฆืืช ืืช ืขืฆืื. ืื ืชืืฆืื ืืืจืื.
ืืื ืชืขื ื ืืฉืคืืช ืฉืื ืื ืขืืจืืช.
"""
        return prompt
    def update_summary(self, question: str, answer: str) -> str:
        """Update the conversation summary with the new interaction."""
        # The summarization prompt is written in Hebrew. Rough translation:
        # "Summarize the conversation in Hebrew. Here is the summary so far:
        # <previous summary, or 'there is no previous conversation'>.
        # New interaction: patient's question / dietitian's answer.
        # Please provide an updated summary that combines the medical
        # information from the previous summary with emphasis on the new
        # interaction. The summary should be concise, up to 100 words.
        # Drop information from the previous summaries that is no longer
        # relevant."
        summary_prompt = {
            "model": self.model_name,
            "max_tokens": 500,
            "temperature": 0.0,
            "messages": [
                {
                    "role": "user",
                    "content": f"""ืกืื ืืช ืืฉืืื ืืขืืจืืช, ืื ื ืกืืืื ืืฉืืื ืขื ืื:
{self.conversation_summary if self.conversation_summary else "ืืื ืฉืืื ืงืืืืช."}
ืืื ืืจืืงืฆืื ืืืฉื:
ืฉืืืช ืืืืืคื: {question}
ืชืฉืืืช ืืจืืคื: {answer}
ืื ื ืกืคืง ืกืืืื ืืขืืืื ืฉืืืื ืืช ืืืืืข ืืจืคืืื ืืืกืืืื ืืงืืื ืื ืืกืฃ ืืืืฉ ืขื ืืืื ืืจืงืฆืื ืืืืฉื. ืืกืืืื ืฆืจืื ืืืืืช ืชืืฆืืชื ืขื 100 ืืืื.
ืืชืจ ืขื ืืืืข ืื ืจืืืื ืื ืืืกืืืืืื ืืงืืืืื"""
                }
            ]
        }
        try:
            response = self.client.messages.create(**summary_prompt)
            self.conversation_summary = response.content[0].text
            return self.conversation_summary
        except Exception as e:
            # Fall back to a locally built transcript summary on API errors
            print(f"Error updating summary: {e}")
            return self.get_basic_summary()
    def get_basic_summary(self) -> str:
        """Fallback method: build a plain transcript-style summary from stored messages."""
        summary = []
        # self.messages stores alternating (user question, assistant answer) pairs
        for i in range(0, len(self.messages), 2):
            if i + 1 < len(self.messages):
                # Hebrew labels: "Patient's question:" / "The dietitian's answer:"
                summary.append(f"ืฉืืืช ืืืืืคื: {self.messages[i]['content']}")
                summary.append(f"ืชืฉืืืช ืืจืืคื ืฉืื ืืื: {self.messages[i + 1]['content']}\n")
        return "\n".join(summary)
    def get_response(self, question: str) -> str:
        # Get relevant context; include the running summary so retrieval can
        # pick up on earlier turns
        context = self.get_context(f"{question}\n{self.conversation_summary}")
        # Generate the system prompt with context and the current conversation summary
        prompt = self.generate_prompt(context, self.conversation_summary)
        # Get a response from Claude. The instructions belong in the `system`
        # parameter: the Messages API requires the first message to have the
        # "user" role, so the prompt cannot be passed as an "assistant" turn.
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            system=prompt,
            messages=[
                {"role": "user", "content": question}
            ]
        )
        answer = response.content[0].text
        # Store messages for history
        self.messages.extend([
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer}
        ])
        # Update the conversation summary
        self.update_summary(question, answer)
        return answer
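
# A minimal usage sketch. It assumes ANTHROPIC_API_KEY is set, the retriever's
# index is already built, and the questions below are illustrative
# placeholders, not part of the original code.
if __name__ == "__main__":
    agent = RAGAgent()
    # "Which breakfast is rich in protein?"
    print(agent.get_response("ืืืื ืืจืืืช ืืืงืจ ืขืฉืืจื ืืืืืื?"))
    # A follow-up turn reuses the running conversation summary for retrieval
    # "And what about a dairy-free option?"
    print(agent.get_response("ืืื ืืืื ืืคืฉืจืืช ืืื ืืืืจื ืืื?"))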