# Hugging Face Space status banner captured during export ("Spaces: Sleeping") — not code.
import logging
import os
import time
from base64 import b64encode
from io import BytesIO
from typing import Optional

import pytesseract
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance
from pydantic import BaseModel, ConfigDict
# Application instance shared by all route handlers.
app = FastAPI()

# Verbose logging aids debugging of streamed inference responses.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Model used whenever a request does not name one explicitly.
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
class TextRequest(BaseModel):
    """JSON payload for the plain text-to-text endpoint."""

    # Permit field names starting with "model_" (pydantic v2 reserves that prefix).
    model_config = ConfigDict(protected_namespaces=())

    query: str                        # user question forwarded to the model
    stream: bool = False              # stream tokens as SSE when True
    model_name: Optional[str] = None  # overrides DEFAULT_MODEL when set
class ImageTextRequest(BaseModel):
    """Multipart form payload for the image+text (i2t2t) endpoint."""

    # Permit field names starting with "model_" (pydantic v2 reserves that prefix).
    model_config = ConfigDict(protected_namespaces=())

    query: str                        # user question about the image
    stream: bool = False              # stream tokens as SSE when True
    model_name: Optional[str] = None  # overrides DEFAULT_MODEL when set

    @classmethod
    def as_form(
        cls,
        query: str = Form(...),
        stream: bool = Form(False),
        model_name: Optional[str] = Form(None),
        image: UploadFile = File(...)  # Make image required for i2t2t
    ):
        """Build a (request, upload) pair from multipart form fields.

        Used as ``Depends(ImageTextRequest.as_form)``. Without ``@classmethod``
        the bare function would be handed to FastAPI with ``cls`` as an
        unannotated request parameter, breaking dependency resolution.
        """
        request = cls(
            query=query,
            stream=stream,
            model_name=model_name
        )
        return request, image
def get_client(model_name: Optional[str] = None):
    """Get inference client for specified model or default model.

    Blank/whitespace-only names fall back to DEFAULT_MODEL.

    Raises:
        HTTPException: 400 when the client cannot be constructed.
    """
    # Resolve the model outside the try: the except block interpolates
    # model_path, which would be unbound (NameError) if this line raised inside it.
    model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
    try:
        return InferenceClient(
            model=model_path
        )
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {str(e)}"
        )
def generate_text_response(query: str, model_name: Optional[str] = None):
    """Yield response tokens for a text-only chat query.

    On failure yields a single error string instead of raising, so a
    streaming response degrades gracefully rather than aborting mid-stream.
    """
    messages = [{
        "role": "user",
        "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
    }]
    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True
        ):
            token = message.choices[0].delta.content
            # Some stream chunks (e.g. the terminal one) carry delta.content=None;
            # yielding None would break callers doing `response += chunk`.
            if token is not None:
                yield token
    except Exception as e:
        yield f"Error generating response: {str(e)}"
def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
    """Yield response tokens for a vision query (base64 PNG + text prompt).

    On failure yields a single error string instead of raising, so a
    streaming response degrades gracefully rather than aborting mid-stream.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
                {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
            ]
        }
    ]
    logger.debug(f"Messages sent to API: {messages}")
    try:
        client = get_client(model_name)
        for message in client.chat_completion(messages, max_tokens=2048, stream=True):
            logger.debug(f"Received message chunk: {message}")
            token = message.choices[0].delta.content
            # Some stream chunks (e.g. the terminal one) carry delta.content=None;
            # yielding None would break callers doing `response += chunk`.
            if token is not None:
                yield token
    except Exception as e:
        logger.error(f"Error in generate_image_text_response: {str(e)}")
        yield f"Error generating response: {str(e)}"
def preprocess_image(img):
    """Enhance image for better OCR results."""
    # Grayscale first: OCR operates on luminance, not color.
    gray = img.convert('L')
    # Boost contrast, then sharpness, to make glyph edges crisper for Tesseract.
    contrasted = ImageEnhance.Contrast(gray).enhance(2.0)
    sharpened = ImageEnhance.Sharpness(contrasted).enhance(1.5)
    return sharpened
# NOTE(review): the route decorator appears to have been lost in extraction;
# "/" is the conventional landing path — confirm against the original source.
@app.get("/")
async def root():
    """Landing endpoint confirming the server is up."""
    return {"message": "Welcome to FastAPI server!"}
# Route decorator restored: the handler's own log message names the /t2t endpoint.
@app.post("/t2t")
async def text_to_text(request: TextRequest):
    """Text-to-text endpoint: full answer, or an SSE token stream when requested.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        if request.stream:
            return StreamingResponse(
                generate_text_response(request.query, request.model_name),
                media_type="text/event-stream"
            )
        # Non-streaming: drain the token generator into one response string.
        response = "".join(generate_text_response(request.query, request.model_name))
        return {"response": response}
    except Exception as e:
        logger.error(f"Error in /t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Route decorator restored: the handler's own log message names the /i2t2t endpoint.
@app.post("/i2t2t")
async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
    """Image+text endpoint: base64-encodes the upload and queries a vision model.

    Raises:
        HTTPException: 422 when the upload cannot be decoded as an image,
            500 on any other failure.
    """
    form, image = form_data
    try:
        # Read and normalize the uploaded image.
        contents = await image.read()
        try:
            logger.debug("Attempting to open image")
            img = Image.open(BytesIO(contents))
            # Normalize palette/alpha images so PNG re-encoding is consistent.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            image_data = b64encode(buffer.getvalue()).decode('utf-8')
            logger.debug("Image processed and encoded to base64")
        except Exception as img_error:
            logger.error(f"Error processing image: {str(img_error)}")
            raise HTTPException(
                status_code=422,
                detail=f"Error processing image: {str(img_error)}"
            )
        if form.stream:
            return StreamingResponse(
                generate_image_text_response(form.query, image_data, form.model_name),
                media_type="text/event-stream"
            )
        # Non-streaming: drain the token generator into one response string.
        response = "".join(generate_image_text_response(form.query, image_data, form.model_name))
        return {"response": response}
    except HTTPException:
        # Preserve the 422 from image decoding: the generic handler below
        # would otherwise re-wrap it as an uninformative 500.
        raise
    except Exception as e:
        logger.error(f"Error in /i2t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): route decorator lost in extraction; "/ocr" assumed from the
# function name — confirm against the original source.
@app.post("/ocr")
async def ocr_endpoint(image: UploadFile = File(...)):
    """Extract text from an uploaded image with Tesseract, retrying transient failures.

    Raises:
        HTTPException: 500 when OCR fails after all retries or the image
            cannot be processed.
    """
    try:
        # Read and process the image
        contents = await image.read()
        img = Image.open(BytesIO(contents))
        img = preprocess_image(img)
        # Perform OCR with timeout and retries
        max_retries = 3
        text = ""
        for attempt in range(max_retries):
            try:
                # --oem 3: default engine mode; --psm 6: assume a uniform text block.
                text = pytesseract.image_to_string(
                    img,
                    timeout=30,  # 30 second timeout
                    config='--oem 3 --psm 6'
                )
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Error extracting text: {str(e)}"
                    )
                time.sleep(1)  # Wait before retry
        return {"text": text}
    except HTTPException:
        # Keep the specific "Error extracting text" detail: the generic handler
        # below would otherwise re-wrap it as "Error processing image".
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        )