|
import json
import logging
import os
import secrets

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
|
|
|
|
|
# Root logging configuration. Level is overridable via LOG_LEVEL; the default
# stays DEBUG to preserve the original behavior.
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "DEBUG"))

logger = logging.getLogger(__name__)

app = FastAPI()

# Allowed CORS origins. NOTE(review): this list is empty, so browsers are
# granted no cross-origin access — populate it before deploying a web client.
origins = [

]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared-secret bearer token required by /api/generate.
API_KEY = os.environ.get("API_KEY", "change_me")
if API_KEY == "change_me":
    # Make it obvious when the insecure fallback is in use.
    logger.warning("API_KEY environment variable not set; using insecure default")
# Security fix: never write the secret itself to the log.
logger.debug("API key loaded (value redacted)")

# Ollama generate endpoint; overridable via environment so non-local
# deployments work without a code change (default unchanged).
OLLAMA_SERVER_URL = os.environ.get("OLLAMA_SERVER_URL", "http://localhost:11434/api/generate")
logger.debug(f"Ollama server URL: {OLLAMA_SERVER_URL}")
|
|
|
|
|
@app.post("/api/generate")
async def generate(request: Request):
    """Proxy a text-generation request to the local Ollama server.

    Expects a JSON body with:
        model  (str, optional): Ollama model name; defaults to a GGUF model.
        prompt (str, required): the prompt text to generate from.

    Requires an ``Authorization: Bearer <API_KEY>`` header.

    Returns:
        StreamingResponse relaying Ollama's chunked output; stream-level
        failures are emitted as JSON ``{"error": ...}`` chunks.

    Raises:
        HTTPException: 400 for a bad body or missing prompt, 401 for
        missing/invalid credentials, 500 for unexpected server errors.
    """
    try:
        # Authenticate before doing any work on the request body.
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            logger.error("Missing or invalid Authorization header")
            raise HTTPException(status_code=401, detail="Missing or invalid Authorization header")

        token = auth_header.split(" ", 1)[1]
        # Constant-time comparison defeats timing attacks; never log the
        # presented token (it may be the real secret or a near-miss of it).
        if not secrets.compare_digest(token, API_KEY):
            logger.error("Invalid API key provided")
            raise HTTPException(status_code=401, detail="Invalid API key")

        # A malformed JSON body is a client error (400), not a server 500.
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            logger.error("Request body is not valid JSON")
            raise HTTPException(status_code=400, detail="Request body must be valid JSON")

        model = body.get("model", "hf.co/abanm/Dubs-Q8_0-GGUF:latest")
        prompt_text = body.get("prompt", "")

        if not prompt_text:
            logger.error("No prompt provided in the request")
            raise HTTPException(status_code=400, detail="No prompt provided")

        # Log metadata, not the full body, to keep prompts out of the logs.
        logger.debug("Request model=%s prompt_length=%d", model, len(prompt_text))

        payload = {"model": model, "prompt": prompt_text}

        async def stream_response():
            """Relay Ollama's chunked response; convert failures to JSON chunks."""
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
                    async with client.stream(
                        "POST", OLLAMA_SERVER_URL, json=payload, headers={"Content-Type": "application/json"}
                    ) as response:
                        logger.info(f"Response status code from Ollama: {response.status_code}")

                        if response.status_code != 200:
                            # Bug fix: httpx's Response.text is a property, and a
                            # streamed body must be read via aread() before the
                            # text is available; the old `await response.text()`
                            # raised TypeError.
                            error_body = (await response.aread()).decode("utf-8", errors="replace")
                            logger.error(f"HTTP error: {response.status_code} - {error_body}")
                            yield json.dumps({"error": f"HTTP error: {response.status_code}"})
                            return

                        async for chunk in response.aiter_text():
                            yield chunk

            except httpx.ReadTimeout:
                logger.error("ReadTimeout while waiting for response chunks")
                yield json.dumps({"error": "Server response timeout. Try again later."})
            except httpx.RequestError as exc:
                logger.error(f"Request error while communicating with Ollama: {str(exc)}")
                yield json.dumps({"error": "Network error occurred while communicating with Ollama"})
            except Exception as exc:
                logger.exception(f"Unexpected error during streaming: {str(exc)}")
                yield json.dumps({"error": "An unexpected error occurred during streaming."})

        return StreamingResponse(stream_response(), media_type="application/json")

    except HTTPException:
        # Bug fix: the broad handler below used to swallow the deliberate
        # 400/401 responses raised above and turn them into 500s.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe — confirms the API process is up and responsive."""
    logger.info("Health check endpoint accessed")
    status_payload = {"status": "OK"}
    return status_payload
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Dev entry point: bind on every interface at port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    logger.info("Starting FastAPI application")
    uvicorn.run(app, host=bind_host, port=bind_port)
|
|