# testingtodeploy / app_old.py
# Uploaded by Charan5775 — "Rename app.py to app_old.py" (commit 72fa916, verified)
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends
from typing import Optional
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel, ConfigDict
import os
from base64 import b64encode
from io import BytesIO
from PIL import Image, ImageEnhance
import logging
import pytesseract
import time
app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Default model
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
class TextRequest(BaseModel):
model_config = ConfigDict(protected_namespaces=())
query: str
stream: bool = False
model_name: Optional[str] = None
class ImageTextRequest(BaseModel):
model_config = ConfigDict(protected_namespaces=())
query: str
stream: bool = False
model_name: Optional[str] = None
@classmethod
def as_form(
cls,
query: str = Form(...),
stream: bool = Form(False),
model_name: Optional[str] = Form(None),
image: UploadFile = File(...) # Make image required for i2t2t
):
return cls(
query=query,
stream=stream,
model_name=model_name
), image
def get_client(model_name: Optional[str] = None):
"""Get inference client for specified model or default model"""
try:
model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
return InferenceClient(
model=model_path
)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Error initializing model {model_path}: {str(e)}"
)
def generate_text_response(query: str, model_name: Optional[str] = None):
messages = [{
"role": "user",
"content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
}]
try:
client = get_client(model_name)
for message in client.chat_completion(
messages,
max_tokens=2048,
stream=True
):
token = message.choices[0].delta.content
yield token
except Exception as e:
yield f"Error generating response: {str(e)}"
def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
{"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
]
}
]
logger.debug(f"Messages sent to API: {messages}")
try:
client = get_client(model_name)
for message in client.chat_completion(messages, max_tokens=2048, stream=True):
logger.debug(f"Received message chunk: {message}")
token = message.choices[0].delta.content
yield token
except Exception as e:
logger.error(f"Error in generate_image_text_response: {str(e)}")
yield f"Error generating response: {str(e)}"
def preprocess_image(img):
"""Enhance image for better OCR results"""
# Convert to grayscale
img = img.convert('L')
# Enhance contrast
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(2.0)
# Enhance sharpness
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(1.5)
return img
@app.get("/")
async def root():
return {"message": "Welcome to FastAPI server!"}
@app.post("/t2t")
async def text_to_text(request: TextRequest):
try:
if request.stream:
return StreamingResponse(
generate_text_response(request.query, request.model_name),
media_type="text/event-stream"
)
else:
response = ""
for chunk in generate_text_response(request.query, request.model_name):
response += chunk
return {"response": response}
except Exception as e:
logger.error(f"Error in /t2t endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/i2t2t")
async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
form, image = form_data
try:
# Process image
contents = await image.read()
try:
logger.debug("Attempting to open image")
img = Image.open(BytesIO(contents))
if img.mode != 'RGB':
img = img.convert('RGB')
buffer = BytesIO()
img.save(buffer, format="PNG")
image_data = b64encode(buffer.getvalue()).decode('utf-8')
logger.debug("Image processed and encoded to base64")
except Exception as img_error:
logger.error(f"Error processing image: {str(img_error)}")
raise HTTPException(
status_code=422,
detail=f"Error processing image: {str(img_error)}"
)
if form.stream:
return StreamingResponse(
generate_image_text_response(form.query, image_data, form.model_name),
media_type="text/event-stream"
)
else:
response = ""
for chunk in generate_image_text_response(form.query, image_data, form.model_name):
response += chunk
return {"response": response}
except Exception as e:
logger.error(f"Error in /i2t2t endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/tes")
async def ocr_endpoint(image: UploadFile = File(...)):
try:
# Read and process the image
contents = await image.read()
img = Image.open(BytesIO(contents))
# Preprocess the image
img = preprocess_image(img)
# Perform OCR with timeout and retries
max_retries = 3
text = ""
for attempt in range(max_retries):
try:
text = pytesseract.image_to_string(
img,
timeout=30, # 30 second timeout
config='--oem 3 --psm 6'
)
break
except Exception as e:
if attempt == max_retries - 1:
raise HTTPException(
status_code=500,
detail=f"Error extracting text: {str(e)}"
)
time.sleep(1) # Wait before retry
return {"text": text}
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error processing image: {str(e)}"
)