DocUA's picture
edit analysis & prompts
060f57e
raw
history blame
7.32 kB
from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step
import json
from prompts import PRECEDENT_ANALYSIS_TEMPLATE
from enum import Enum
from anthropic import Anthropic
# from llama_index.llms.openai import OpenAI
from openai import OpenAI
from llama_index.core.llms import ChatMessage
from config import embed_model, Settings, openai_api_key, anthropic_api_key
class ModelProvider(str, Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
class ModelName(str, Enum):
# OpenAI models
GPT4o = "gpt-4o"
GPT4o_MINI = "gpt-4o-mini"
# Anthropic models
CLAUDE3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE3_5_HAIKU = "claude-3-5-haiku-latest"
class LLMAnalyzer:
def __init__(self, provider: ModelProvider, model_name: ModelName):
self.provider = provider
self.model_name = model_name
if provider == ModelProvider.OPENAI:
self.client = OpenAI()
elif provider == ModelProvider.ANTHROPIC:
# Додаємо API ключ при ініціалізації
self.client = Anthropic(api_key=anthropic_api_key)
else:
raise ValueError(f"Unsupported provider: {provider}")
async def analyze(self, prompt: str, response_schema: dict) -> str:
if self.provider == ModelProvider.OPENAI:
return await self._analyze_with_openai(prompt, response_schema)
else:
return await self._analyze_with_anthropic(prompt, response_schema)
async def _analyze_with_openai(self, prompt: str, response_schema: dict) -> str:
# Правильний формат для response_format
response_format = {
"type": "json_schema",
"json_schema": {
"name": "relevant_positions_schema", # Додаємо обов'язкове поле name
"schema": response_schema
}
}
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{
"role": "system",
"content": [
{
"type": "text",
"text": "Ти - кваліфікований юрист-аналітик, експерт з правових позицій Верховного Суду."
}
]
},
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
],
response_format=response_format,
temperature=0,
max_tokens=4096
)
return response.choices[0].message.content
async def _analyze_with_anthropic(self, prompt: str, response_schema: dict) -> str:
response = self.client.messages.create( # Прибрали await
model=self.model_name,
max_tokens=4096,
messages=[
{
"role": "assistant",
"content": "Ти - кваліфікований юрист-аналітик, експерт з правових позицій Верховного Суду."
},
{
"role": "user",
"content": prompt
}
]
)
return response.content[0].text
class PrecedentAnalysisWorkflow(Workflow):
def __init__(self, provider: ModelProvider = ModelProvider.OPENAI,
model_name: ModelName = ModelName.GPT4o_MINI):
super().__init__()
self.analyzer = LLMAnalyzer(provider, model_name)
@step
async def analyze(self, ctx: Context, ev: StartEvent) -> StopEvent:
try:
# Отримуємо параметри з події з дефолтними значеннями
query = ev.get("query", "")
question = ev.get("question", "")
nodes = ev.get("nodes", [])
# Перевірка на пусті значення
if not query:
return StopEvent(result="Помилка: Не надано текст нового рішення (query)")
if not nodes:
return StopEvent(result="Помилка: Не надано правові позиції для аналізу (nodes)")
# Підготовка контексту
context_parts = []
for i, node in enumerate(nodes, 1):
node_text = node.node.text if hasattr(node, 'node') else node.text
metadata = node.node.metadata if hasattr(node, 'node') else node.metadata
lp_id = metadata.get('lp_id', f'unknown_{i}')
context_parts.append(f"Source {i} (ID: {lp_id}):\n{node_text}")
context_str = "\n\n".join(context_parts)
# Схема відповіді
response_schema = {
"type": "object",
"properties": {
"relevant_positions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"lp_id": {"type": "string"},
"source_index": {"type": "string"},
"description": {"type": "string"}
},
"required": ["lp_id", "source_index", "description"]
}
}
},
"required": ["relevant_positions"]
}
# Формування промпту
prompt = PRECEDENT_ANALYSIS_TEMPLATE.format(
query=query,
question=question if question else "Загальний аналіз релевантності",
context_str=context_str
)
# Отримання відповіді від моделі
response_content = await self.analyzer.analyze(prompt, response_schema)
try:
parsed_response = json.loads(response_content)
if "relevant_positions" in parsed_response:
response_lines = []
for position in parsed_response["relevant_positions"]:
position_text = (
f"* [{position['source_index']}] {position['description']} "
)
response_lines.append(position_text)
response_text = "\n".join(response_lines)
return StopEvent(result=response_text)
else:
return StopEvent(result="Не знайдено релевантних правових позицій")
except json.JSONDecodeError:
return StopEvent(result=f"**Аналіз ШІ (модель: {self.analyzer.model_name}):** {response_content}")
except Exception as e:
return StopEvent(result=f"Error during analysis: {str(e)}")