"""Extract structured data from unstructured text with an LLM."""

import json
from typing import Any, TypeVar

from litellm import completion
from pydantic import BaseModel, ValidationError

from raglite._config import RAGLiteConfig

T = TypeVar("T", bound=BaseModel)


def extract_with_llm(
    return_type: type[T],
    user_prompt: str | list[str],
    config: RAGLiteConfig | None = None,
    **kwargs: Any,
) -> T:
    """Extract structured data from unstructured text with an LLM.

    This function expects `return_type` to define a `system_prompt: ClassVar[str]` attribute
    that contains the system prompt to use.

    Example:
        from typing import ClassVar

        from pydantic import BaseModel, Field

        class MyNameResponse(BaseModel):
            my_name: str = Field(..., description="The user's name.")
            system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)."

        my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.")
    """
    # Load the default config if not provided.
    config = config or RAGLiteConfig()
    # Append the JSON schema of the return type to the system prompt to help the LLM.
    # Join the parts into a single string (a bare parenthesized sequence would be a tuple).
    system_prompt = "\n".join(
        (
            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
            "Format your response according to this JSON schema:",
            json.dumps(return_type.model_json_schema()),
        )
    )
    # Concatenate the user prompt if it is a list of strings.
    if isinstance(user_prompt, list):
        user_prompt = "\n\n".join(f"\n{chunk.strip()}\n" for chunk in user_prompt)
    # Extract structured data from the unstructured input, retrying on malformed responses.
    # Initialize the last exception so that the error path works even if llm_max_tries <= 0.
    last_exception: Exception | None = None
    for _ in range(config.llm_max_tries):
        response = completion(
            model=config.llm,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object", "schema": return_type.model_json_schema()},
            **kwargs,
        )
        try:
            instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
        except (KeyError, ValueError, ValidationError) as e:
            # Malformed response, not a JSON string, or not a valid instance of the return type.
            last_exception = e
            continue
        else:
            break
    else:
        # No attempt produced a valid instance of the return type.
        error_message = f"Failed to extract {return_type} from input {user_prompt}."
        raise ValueError(error_message) from last_exception
    return instance
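

# Illustrative sketch (hypothetical, not part of RAGLite's API): the validate-and-retry
# loop in `extract_with_llm`, reduced to a stdlib-only helper. It assumes the LLM call can
# be replaced by any zero-argument callable `generate` that returns raw text, and that
# `parse` raises `ValueError` on unusable output (pydantic's `ValidationError` and
# `json.JSONDecodeError` are both `ValueError` subclasses). `_retry_until_valid`,
# `generate`, and `parse` are illustrative names only. The imports are repeated so that
# the sketch runs on its own.
from collections.abc import Callable  # noqa: E402
from typing import TypeVar  # noqa: E402

R = TypeVar("R")


def _retry_until_valid(generate: Callable[[], str], parse: Callable[[str], R], max_tries: int) -> R:
    """Return the first output of `generate` that `parse` accepts, trying at most `max_tries` times."""
    last_exception: Exception | None = None
    for _ in range(max_tries):
        try:
            return parse(generate())
        except ValueError as e:
            # Unusable output: remember why it failed and try again.
            last_exception = e
    # Reached when every attempt failed, or immediately when max_tries <= 0.
    error_message = f"No valid output after {max_tries} tries."
    raise ValueError(error_message) from last_exception


# Usage, with a stub "LLM" that fails twice before producing valid output:
#     responses = iter(["not a number", "4x", "42"])
#     _retry_until_valid(lambda: next(responses), int, max_tries=3)  # returns 42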