# app.py
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from typing import Optional
import uvicorn
import huggingface_hub
import os
from PIL import Image
import io
import base64
import logging

# Configure logging: root logger at INFO, plus a module-level logger that the
# rest of this file uses for all diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; title/description/version are the metadata
# handed to FastAPI for this service.
app = FastAPI(
    title="OmniVLM API",
    description="API for text and image processing using OmniVLM model",
    version="1.0.0"
)

# Add CORS middleware: every origin, method and header is allowed.
# NOTE(review): the FastAPI CORS documentation states that wildcard
# origins/methods/headers must not be combined with allow_credentials=True;
# confirm whether credentialed cross-origin requests are actually needed and,
# if so, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Fetch the GGUF weights from the Hugging Face Hub. The value returned by
# hf_hub_download is used below as the local model path. A failure is logged
# and re-raised, so importing this module fails when the model is unavailable.
try:
    model_path = huggingface_hub.hf_hub_download(
        filename="omnivision-text-optimized-llm-Q8_0.gguf",
        repo_id="NexaAIDev/OmniVLM-968M",
    )
except Exception as download_error:
    logger.error("Error downloading model: %s", download_error)
    raise
else:
    logger.info("Model downloaded successfully to %s", model_path)

# Build the single, module-wide Llama instance shared by every endpoint.
# As above, an initialization failure is logged and re-raised.
try:
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=4,
        n_batch=512,
        verbose=True,
    )
except Exception as init_error:
    logger.error("Error initializing model: %s", init_error)
    raise
else:
    logger.info("Model initialized successfully")

class GenerationRequest(BaseModel):
    """Request body for the ``/generate`` endpoint.

    Only ``prompt`` is required; the remaining fields fall back to the
    defaults declared below and are forwarded unchanged to the model call.
    """

    # Text handed to the model as the prompt.
    prompt: str
    # Forwarded as the model call's ``max_tokens`` argument.
    max_tokens: Optional[int] = 100
    # Forwarded as the model call's ``temperature`` argument.
    temperature: Optional[float] = 0.7
    # Forwarded as the model call's ``top_p`` argument.
    top_p: Optional[float] = 0.9

class GenerationResponse(BaseModel):
    """Response body shared by the ``/generate`` and ``/process-image`` endpoints."""

    # Text produced by the model; the endpoints set it to "" when they report
    # a failure through ``error``.
    generated_text: str
    # Stringified exception when the endpoint caught a failure, otherwise None.
    error: Optional[str] = None

# Lower-case file extensions (without the dot) accepted for image uploads.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB


def allowed_file(filename: Optional[str]) -> bool:
    """Return True when *filename* ends in an allowed image extension.

    The extension is the text after the last dot, compared case-insensitively
    against ``ALLOWED_EXTENSIONS``.

    Args:
        filename: Client-supplied file name. May be ``None`` or empty, because
            an upload is not required to carry a name; both are rejected
            instead of raising ``TypeError``.

    Returns:
        ``True`` if the name contains a dot and its extension is allowed,
        ``False`` otherwise.
    """
    if not filename:
        return False
    # rpartition yields an empty separator when the name has no dot at all.
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """Run a plain-text completion for ``request.prompt``.

    The sampling fields of the request are passed straight to the model.
    Failures are not raised: they are logged and reported through the
    ``error`` field of the response, with ``generated_text`` left empty.
    """
    try:
        sampling_options = {
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }
        completion = llm(request.prompt, **sampling_options)
        first_choice = completion["choices"][0]
        return GenerationResponse(generated_text=first_choice["text"])
    except Exception as exc:
        logger.error("Error in text generation: %s", exc)
        return GenerationResponse(generated_text="", error=str(exc))

def _encode_image_as_jpeg_base64(file_content: bytes) -> str:
    """Normalize raw image bytes and return them as a base64 JPEG string.

    The image is converted to RGB when needed, shrunk in place so that
    neither side exceeds 1024 pixels, re-encoded as JPEG (quality 85) and
    base64-encoded.

    Raises:
        OSError: if Pillow cannot identify or decode the data
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    image = Image.open(io.BytesIO(file_content))

    # JPEG has no alpha/palette support, so force RGB before saving.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    max_size = (1024, 1024)
    if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
        image.thumbnail(max_size, Image.Resampling.LANCZOS)

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG", quality=85)
    return base64.b64encode(buffered.getvalue()).decode()


@app.post("/process-image", response_model=GenerationResponse)
async def process_image(
    file: UploadFile = File(...),
    prompt: str = Form("Describe this image in detail"),
    max_tokens: int = Form(200),
    temperature: float = Form(0.7)
):
    """Generate text for an uploaded image and an accompanying prompt.

    Args:
        file: Uploaded image; at most ``MAX_IMAGE_SIZE`` bytes and with an
            extension accepted by ``allowed_file``.
        prompt: Instruction appended after the embedded image.
        max_tokens: Forwarded to the model call.
        temperature: Forwarded to the model call.

    Returns:
        A ``GenerationResponse`` holding the generated text. If reading or
        validating the upload fails unexpectedly, the response carries the
        error message instead (HTTP 200).

    Raises:
        HTTPException: 400 when the file is too large, has a disallowed or
            missing extension, or is not a decodable image; 500 when image
            processing or generation fails for any other reason.
    """
    try:
        # Read one byte past the limit: enough to detect an oversized upload
        # without copying an arbitrarily large body into memory.
        file_content = await file.read(MAX_IMAGE_SIZE + 1)

        if len(file_content) > MAX_IMAGE_SIZE:
            raise HTTPException(status_code=400, detail="File too large")

        # The client may omit the filename entirely; treat that as a
        # disallowed type instead of letting the check raise.
        if not file.filename or not allowed_file(file.filename):
            raise HTTPException(status_code=400, detail="File type not allowed")

        try:
            try:
                img_str = _encode_image_as_jpeg_base64(file_content)
            except OSError as exc:
                # Corrupt or non-image bytes are a client error, not a
                # server fault.
                logger.warning("Rejected undecodable image: %s", exc)
                raise HTTPException(
                    status_code=400,
                    detail="Invalid or corrupted image file",
                ) from exc

            # The prompt layout (including its indentation) is kept exactly
            # as before, since it is part of what the model receives.
            # NOTE(review): the image is embedded as base64 text in an
            # ordinary completion prompt; presumably this can exceed the
            # model's context window for typical images -- confirm against
            # the llama-cpp-python multimodal documentation.
            full_prompt = (
                "\n"
                f"            <image>data:image/jpeg;base64,{img_str}</image>\n"
                f"            {prompt}\n"
                "            "
            )

            logger.info("Processing image with prompt")
            output = llm(
                full_prompt,
                max_tokens=max_tokens,
                temperature=temperature
            )

            return GenerationResponse(generated_text=output["choices"][0]["text"])

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error processing image: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Error processing image: {str(e)}",
            ) from e

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return GenerationResponse(generated_text="", error=str(e))

@app.get("/health")
async def health_check():
    """Report service liveness and whether the module-level model object is set."""
    model_ready = llm is not None
    return {"status": "healthy", "model_loaded": model_ready}

if __name__ == "__main__":
    # Serve on all interfaces; the port comes from $PORT, defaulting to 7860.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        app,
        port=port,
        host="0.0.0.0",
        log_level="info",
    )