Spaces:
Runtime error
Runtime error
from datetime import datetime | |
from typing import Any, Dict, List, Literal, Optional | |
from pydantic import BaseModel, Field, validator | |
from reworkd_platform.web.api.agent.analysis import Analysis | |
LLM_Model = Literal[ | |
"gpt-3.5-turbo", | |
"gpt-3.5-turbo-16k", | |
"gpt-4", | |
] | |
Loop_Step = Literal[ | |
"start", | |
"analyze", | |
"execute", | |
"create", | |
"summarize", | |
"chat", | |
] | |
LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = { | |
"gpt-3.5-turbo": 4000, | |
"gpt-3.5-turbo-16k": 16000, | |
"gpt-4": 8000, | |
} | |
class ModelSettings(BaseModel): | |
model: LLM_Model = Field(default="gpt-3.5-turbo") | |
custom_api_key: Optional[str] = Field(default=None) | |
temperature: float = Field(default=0.9, ge=0.0, le=1.0) | |
max_tokens: int = Field(default=500, ge=0) | |
language: str = Field(default="English") | |
def validate_max_tokens(cls, v: float, values: Dict[str, Any]) -> float: | |
model = values["model"] | |
if v > (max_tokens := LLM_MODEL_MAX_TOKENS[model]): | |
raise ValueError(f"Model {model} only supports {max_tokens} tokens") | |
return v | |
class AgentRunCreate(BaseModel): | |
goal: str | |
model_settings: ModelSettings = Field(default=ModelSettings()) | |
class AgentRun(AgentRunCreate): | |
run_id: str | |
class AgentTaskAnalyze(AgentRun): | |
task: str | |
tool_names: List[str] = Field(default=[]) | |
model_settings: ModelSettings = Field(default=ModelSettings()) | |
class AgentTaskExecute(AgentRun): | |
task: str | |
analysis: Analysis | |
class AgentTaskCreate(AgentRun): | |
tasks: List[str] = Field(default=[]) | |
last_task: Optional[str] = Field(default=None) | |
result: Optional[str] = Field(default=None) | |
completed_tasks: List[str] = Field(default=[]) | |
class AgentSummarize(AgentRun): | |
results: List[str] = Field(default=[]) | |
class AgentChat(AgentRun): | |
message: str | |
results: List[str] = Field(default=[]) | |
class NewTasksResponse(BaseModel): | |
run_id: str | |
new_tasks: List[str] = Field(alias="newTasks") | |
class RunCount(BaseModel): | |
count: int | |
first_run: Optional[datetime] | |
last_run: Optional[datetime] | |