# AZAA / main.py
# rkihacker's picture
# Update main.py
# db3a9bd verified
import json
import logging
import os
from typing import List, Optional, Union

import requests
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
app = FastAPI()

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("proxy")

# Configuration
# Each setting may be overridden through the environment variable named below.
# The literals are the previous hard-coded values, kept as fallbacks so an
# existing deployment keeps working without any extra configuration. An empty
# environment value is treated the same as an unset one.
API_URL = (
    os.environ.get("PROXY_API_URL")
    or "https://api.typegpt.net/v1/chat/completions"
)
# SECURITY: the fallback key below is committed to source control and should be
# treated as leaked. Rotate it, provide the new key via PROXY_API_KEY, and then
# delete this literal.
API_KEY = (
    os.environ.get("PROXY_API_KEY")
    or "sk-XzS5hhsa3vpIcRLz3prQirBQXOx2hPydPzSpzdRcE1YddnNm"
)
BACKEND_MODEL = os.environ.get("PROXY_BACKEND_MODEL") or "gpt-4o-mini"

# Load system prompt mappings (client-facing model name -> system prompt).
# The path is relative to the process working directory.
with open("model_map.json", "r", encoding="utf-8") as f:
    MODEL_PROMPTS = json.load(f)
# Request schema
# Define ContentType for vision
class ContentImage(BaseModel):
    """Image part of a multimodal message."""

    # Part discriminator. Declared as a plain str, so the model only checks
    # that it is a string, not that it equals "image_url".
    type: str  # must be "image_url"
    # build_payload forwards this dict to the backend unchanged.
    image_url: dict  # {"url": "https://..." or "data:image/...;base64,..."}
class ContentText(BaseModel):
    """Text part of a multimodal message."""

    # Part discriminator. Declared as a plain str, so the model only checks
    # that it is a string, not that it equals "text".
    type: str  # must be "text"
    # The text that build_payload copies into the backend message.
    text: str
# One element of a multimodal message's content list: a text part or an
# image part.
ContentType = Union[ContentText, ContentImage]
# Message model allows BOTH old and new formats
class Message(BaseModel):
    """A single chat message in either the legacy or the multimodal format."""

    # Speaker role. build_payload drops messages whose role is "system".
    role: str
    # Plain string (legacy format) or a list of text/image parts (multimodal).
    content: Union[str, List[ContentType]]  # str (legacy) or list of ContentType
# ChatRequest model
class ChatRequest(BaseModel):
    """Request body parsed by the /v1/chat/completions endpoint."""

    # Client-facing model name. build_payload uses it to look up the system
    # prompt in MODEL_PROMPTS; the backend itself is always sent BACKEND_MODEL.
    model: str
    messages: List[Message]
    # When truthy, the endpoint replies with a text/event-stream response.
    stream: Optional[bool] = False
    # The remaining options are copied into the backend payload unchanged.
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
# Build request to backend with injected system prompt
def build_payload(chat: ChatRequest):
    """Build the JSON body sent to the backend for a client chat request.

    The system prompt is chosen server-side from ``MODEL_PROMPTS`` using the
    requested model name, falling back to a generic prompt. System messages
    supplied by the client are dropped, so the injected prompt is the only
    one. The backend is always asked for ``BACKEND_MODEL``.

    Args:
        chat: Parsed client request.

    Returns:
        dict: The rewritten ``messages`` list together with the sampling
        options copied from ``chat`` unchanged.
    """
    system_prompt = MODEL_PROMPTS.get(chat.model, "You are a helpful assistant.")
    payload_messages = [{"role": "system", "content": system_prompt}]

    for msg in chat.messages:
        if msg.role == "system":
            continue

        # Legacy format: plain string content.
        if isinstance(msg.content, str):
            payload_messages.append({"role": msg.role, "content": msg.content})
            continue

        if not isinstance(msg.content, list):
            logger.warning("Unknown message content format: %s", msg.content)
            continue

        # Multimodal format: a list of text / image parts.
        content_payload = []
        for content_item in msg.content:
            item_type = content_item.type
            if item_type not in ("text", "image_url"):
                logger.warning("Unknown content type: %s, skipping.", item_type)
                continue

            # For both supported types the payload key has the same name as
            # the type. The declared type is a free-form string, so it may not
            # match the fields the parsed item actually carries; getattr turns
            # that mismatch into a skipped item instead of an AttributeError.
            value = getattr(content_item, item_type, None)
            if value is None:
                logger.warning(
                    "Content item of type %s has no %s field, skipping.",
                    item_type,
                    item_type,
                )
                continue

            content_payload.append({"type": item_type, item_type: value})

        payload_messages.append({"role": msg.role, "content": content_payload})

    return {
        "model": BACKEND_MODEL,
        "messages": payload_messages,
        "stream": chat.stream,
        "temperature": chat.temperature,
        "top_p": chat.top_p,
        "n": chat.n,
        "stop": chat.stop,
        "presence_penalty": chat.presence_penalty,
        "frequency_penalty": chat.frequency_penalty,
    }
# Stream generator without forcing UTF-8
def stream_generator(requested_model: str, payload: dict, headers: dict):
    """Relay the backend's event stream, masking the backend model name.

    Args:
        requested_model: Model name the client asked for. It replaces the
            ``model`` field of every chunk that reports ``BACKEND_MODEL``.
        payload: JSON body to POST to ``API_URL``.
        headers: HTTP headers for the backend request.

    Yields:
        bytes: One ``data: ...`` event per backend data line, each terminated
        by a blank line. Lines that are not data lines, and data lines whose
        payload is not valid UTF-8 JSON, are logged and skipped.
    """
    data_prefix = b"data:"
    with requests.post(API_URL, headers=headers, json=payload, stream=True) as r:
        if not r.ok:
            # Previously an upstream error status left no trace in the logs.
            # The body is still read below exactly as before.
            logger.error(
                "Backend returned HTTP %s for streaming request", r.status_code
            )

        for line in r.iter_lines(decode_unicode=False):  # Keep as bytes
            if not line:
                continue
            if not line.startswith(data_prefix):
                logger.debug("Non-data stream line skipped: %s", line)
                continue

            # The space after "data:" is optional in the event-stream format,
            # so cut exactly the prefix and let strip() remove the space.
            content = line[len(data_prefix):].strip()
            if content == b"[DONE]":
                yield b"data: [DONE]\n\n"
                continue

            try:
                json_obj = json.loads(content.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError):
                logger.warning("Invalid JSON in stream chunk: %s", content)
                continue

            # Only JSON objects can carry a "model" field; other JSON values
            # are passed through re-serialized.
            if isinstance(json_obj, dict) and json_obj.get("model") == BACKEND_MODEL:
                json_obj["model"] = requested_model
            yield f"data: {json.dumps(json_obj)}\n\n".encode("utf-8")
# Main endpoint
@app.post("/v1/chat/completions")
async def proxy_chat(request: Request):
    """Proxy a chat-completions request to the backend.

    The request body is parsed into a ``ChatRequest`` and rewritten by
    ``build_payload``. Streaming requests are answered with a
    ``text/event-stream`` response fed by ``stream_generator``. Otherwise the
    backend's JSON reply is returned, with its ``model`` field set back to the
    model name the client requested.

    Returns:
        A ``StreamingResponse`` or ``JSONResponse``. A body that is not valid
        JSON or does not match ``ChatRequest`` yields HTTP 400; any other
        failure yields HTTP 500.
    """
    try:
        body = await request.json()
        chat_request = ChatRequest(**body)
    except (ValueError, TypeError) as exc:
        # ValueError covers malformed JSON and pydantic validation errors;
        # TypeError covers a JSON body that is not an object. These are
        # client mistakes, not server faults.
        logger.warning("Invalid request to /v1/chat/completions: %s", exc)
        return JSONResponse(
            content={"error": "Invalid request body."},
            status_code=400,
        )

    try:
        payload = build_payload(chat_request)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }

        if chat_request.stream:
            return StreamingResponse(
                stream_generator(chat_request.model, payload, headers),
                media_type="text/event-stream",
            )

        # requests is blocking; run it in the thread pool so the event loop
        # keeps serving other requests while the backend responds.
        response = await run_in_threadpool(
            requests.post, API_URL, headers=headers, json=payload
        )
        response.raise_for_status()  # Raise error for bad responses
        data = response.json()
        if isinstance(data, dict) and data.get("model") == BACKEND_MODEL:
            data["model"] = chat_request.model
        return JSONResponse(content=data)
    except Exception as e:
        # Top-level boundary: log with traceback and hide details from the client.
        logger.exception("Error in /v1/chat/completions: %s", e)
        return JSONResponse(
            content={"error": "Internal server error."},
            status_code=500,
        )