Spaces:
Running
Running
from llama_index.core.workflow import Workflow, Context, StartEvent, StopEvent, step | |
import json | |
from prompts import PRECEDENT_ANALYSIS_TEMPLATE | |
from enum import Enum | |
from anthropic import Anthropic | |
# from llama_index.llms.openai import OpenAI | |
from openai import OpenAI | |
from llama_index.core.llms import ChatMessage | |
from config import embed_model, Settings, openai_api_key, anthropic_api_key | |
class ModelProvider(str, Enum): | |
OPENAI = "openai" | |
ANTHROPIC = "anthropic" | |
class ModelName(str, Enum): | |
# OpenAI models | |
GPT4o = "gpt-4o" | |
GPT4o_MINI = "gpt-4o-mini" | |
# Anthropic models | |
CLAUDE3_5_SONNET = "claude-3-5-sonnet-latest" | |
CLAUDE3_5_HAIKU = "claude-3-5-haiku-latest" | |
class LLMAnalyzer: | |
def __init__(self, provider: ModelProvider, model_name: ModelName): | |
self.provider = provider | |
self.model_name = model_name | |
if provider == ModelProvider.OPENAI: | |
self.client = OpenAI() | |
elif provider == ModelProvider.ANTHROPIC: | |
# Додаємо API ключ при ініціалізації | |
self.client = Anthropic(api_key=anthropic_api_key) | |
else: | |
raise ValueError(f"Unsupported provider: {provider}") | |
async def analyze(self, prompt: str, response_schema: dict) -> str: | |
if self.provider == ModelProvider.OPENAI: | |
return await self._analyze_with_openai(prompt, response_schema) | |
else: | |
return await self._analyze_with_anthropic(prompt, response_schema) | |
async def _analyze_with_openai(self, prompt: str, response_schema: dict) -> str: | |
# Правильний формат для response_format | |
response_format = { | |
"type": "json_schema", | |
"json_schema": { | |
"name": "relevant_positions_schema", # Додаємо обов'язкове поле name | |
"schema": response_schema | |
} | |
} | |
response = self.client.chat.completions.create( | |
model=self.model_name, | |
messages=[ | |
{ | |
"role": "system", | |
"content": [ | |
{ | |
"type": "text", | |
"text": "Ти - кваліфікований юрист-аналітик, експерт з правових позицій Верховного Суду." | |
} | |
] | |
}, | |
{ | |
"role": "user", | |
"content": [ | |
{ | |
"type": "text", | |
"text": prompt | |
} | |
] | |
} | |
], | |
response_format=response_format, | |
temperature=0, | |
max_tokens=4096 | |
) | |
return response.choices[0].message.content | |
async def _analyze_with_anthropic(self, prompt: str, response_schema: dict) -> str: | |
response = self.client.messages.create( # Прибрали await | |
model=self.model_name, | |
max_tokens=4096, | |
messages=[ | |
{ | |
"role": "assistant", | |
"content": "Ти - кваліфікований юрист-аналітик, експерт з правових позицій Верховного Суду." | |
}, | |
{ | |
"role": "user", | |
"content": prompt | |
} | |
] | |
) | |
return response.content[0].text | |
class PrecedentAnalysisWorkflow(Workflow): | |
def __init__(self, provider: ModelProvider = ModelProvider.OPENAI, | |
model_name: ModelName = ModelName.GPT4o_MINI): | |
super().__init__() | |
self.analyzer = LLMAnalyzer(provider, model_name) | |
async def analyze(self, ctx: Context, ev: StartEvent) -> StopEvent: | |
try: | |
# Отримуємо параметри з події з дефолтними значеннями | |
query = ev.get("query", "") | |
question = ev.get("question", "") | |
nodes = ev.get("nodes", []) | |
# Перевірка на пусті значення | |
if not query: | |
return StopEvent(result="Помилка: Не надано текст нового рішення (query)") | |
if not nodes: | |
return StopEvent(result="Помилка: Не надано правові позиції для аналізу (nodes)") | |
# Підготовка контексту | |
context_parts = [] | |
for i, node in enumerate(nodes, 1): | |
node_text = node.node.text if hasattr(node, 'node') else node.text | |
metadata = node.node.metadata if hasattr(node, 'node') else node.metadata | |
lp_id = metadata.get('lp_id', f'unknown_{i}') | |
context_parts.append(f"Source {i} (ID: {lp_id}):\n{node_text}") | |
context_str = "\n\n".join(context_parts) | |
# Схема відповіді | |
response_schema = { | |
"type": "object", | |
"properties": { | |
"relevant_positions": { | |
"type": "array", | |
"items": { | |
"type": "object", | |
"properties": { | |
"lp_id": {"type": "string"}, | |
"source_index": {"type": "string"}, | |
"description": {"type": "string"} | |
}, | |
"required": ["lp_id", "source_index", "description"] | |
} | |
} | |
}, | |
"required": ["relevant_positions"] | |
} | |
# Формування промпту | |
prompt = PRECEDENT_ANALYSIS_TEMPLATE.format( | |
query=query, | |
question=question if question else "Загальний аналіз релевантності", | |
context_str=context_str | |
) | |
# Отримання відповіді від моделі | |
response_content = await self.analyzer.analyze(prompt, response_schema) | |
try: | |
parsed_response = json.loads(response_content) | |
if "relevant_positions" in parsed_response: | |
response_lines = [] | |
for position in parsed_response["relevant_positions"]: | |
position_text = ( | |
f"* [{position['source_index']}] {position['description']} " | |
) | |
response_lines.append(position_text) | |
response_text = "\n".join(response_lines) | |
return StopEvent(result=response_text) | |
else: | |
return StopEvent(result="Не знайдено релевантних правових позицій") | |
except json.JSONDecodeError: | |
return StopEvent(result=f"**Аналіз ШІ (модель: {self.analyzer.model_name}):** {response_content}") | |
except Exception as e: | |
return StopEvent(result=f"Error during analysis: {str(e)}") |