Charan5775 commited on
Commit
1537a05
·
verified ·
1 Parent(s): 6300dd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -69
app.py CHANGED
@@ -6,8 +6,10 @@ from pydantic import BaseModel, ConfigDict
6
  import os
7
  from base64 import b64encode
8
  from io import BytesIO
9
- from PIL import Image # Add this import
10
  import logging
 
 
11
 
12
  app = FastAPI()
13
 
@@ -15,22 +17,18 @@ app = FastAPI()
15
  logging.basicConfig(level=logging.DEBUG)
16
  logger = logging.getLogger(__name__)
17
 
18
- # Get HuggingFace token from environment variable
19
 
20
  # Default model
21
  DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
22
 
23
- class QueryRequest(BaseModel):
24
  model_config = ConfigDict(protected_namespaces=())
25
-
26
  query: str
27
- image_data: Optional[str] = None # Base64 encoded image data
28
  stream: bool = False
29
  model_name: Optional[str] = None
30
 
31
- class ChatForm(BaseModel):
32
  model_config = ConfigDict(protected_namespaces=())
33
-
34
  query: str
35
  stream: bool = False
36
  model_name: Optional[str] = None
@@ -41,7 +39,7 @@ class ChatForm(BaseModel):
41
  query: str = Form(...),
42
  stream: bool = Form(False),
43
  model_name: Optional[str] = Form(None),
44
- image: Optional[UploadFile] = File(None)
45
  ):
46
  return cls(
47
  query=query,
@@ -52,9 +50,7 @@ class ChatForm(BaseModel):
52
  def get_client(model_name: Optional[str] = None):
53
  """Get inference client for specified model or default model"""
54
  try:
55
- # Use provided model_name if it exists and is not empty, otherwise use DEFAULT_MODEL
56
  model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
57
-
58
  return InferenceClient(
59
  model=model_path
60
  )
@@ -64,26 +60,11 @@ def get_client(model_name: Optional[str] = None):
64
  detail=f"Error initializing model {model_path}: {str(e)}"
65
  )
66
 
67
- def generate_response(query: str, image_data: Optional[str] = None, model_name: Optional[str] = None):
68
- messages = []
69
-
70
- # Create the system and user message
71
- user_content = f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
72
-
73
- # If there's an image, add it to the message
74
- if image_data:
75
- messages.append({
76
- "role": "user",
77
- "content": [
78
- {"type": "text", "text": user_content},
79
- {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
80
- ]
81
- })
82
- else:
83
- messages.append({
84
- "role": "user",
85
- "content": user_content
86
- })
87
 
88
  try:
89
  client = get_client(model_name)
@@ -97,61 +78,137 @@ def generate_response(query: str, image_data: Optional[str] = None, model_name:
97
  except Exception as e:
98
  yield f"Error generating response: {str(e)}"
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @app.get("/")
101
  async def root():
102
  return {"message": "Welcome to FastAPI server!"}
103
 
104
- @app.post("/chat")
105
- async def chat(form_data: tuple[ChatForm, Optional[UploadFile]] = Depends(ChatForm.as_form)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  form, image = form_data
107
  try:
108
- image_data = None
109
- if image:
110
- logger.debug("Image received")
111
- # Read the image
112
- contents = await image.read()
 
 
113
 
114
- # Convert image to appropriate format if needed
115
- try:
116
- logger.debug("Attempting to open image")
117
- img = Image.open(BytesIO(contents))
118
- logger.debug(f"Image format before conversion: {img.format}, mode: {img.mode}")
119
- # Convert to RGB if needed
120
- if img.mode != 'RGB':
121
- img = img.convert('RGB')
122
- logger.debug(f"Image format after conversion: {img.format}, mode: {img.mode}")
123
-
124
- # Save as JPEG in memory
125
- buffer = BytesIO()
126
- img.save(buffer, format="JPEG")
127
- image_data = b64encode(buffer.getvalue()).decode('utf-8')
128
- logger.debug("Image processed and encoded to base64")
129
- except Exception as img_error:
130
- logger.error(f"Error processing image: {str(img_error)}")
131
- raise HTTPException(
132
- status_code=422,
133
- detail=f"Error processing image: {str(img_error)}"
134
- )
135
 
136
  if form.stream:
137
  return StreamingResponse(
138
- generate_response(form.query, image_data, form.model_name),
139
  media_type="text/event-stream"
140
  )
141
  else:
142
  response = ""
143
- for chunk in generate_response(form.query, image_data, form.model_name):
144
  response += chunk
145
  return {"response": response}
146
  except Exception as e:
147
- logger.error(f"Error in /chat endpoint: {str(e)}")
148
  raise HTTPException(status_code=500, detail=str(e))
149
 
150
- if __name__ == "__main__":
151
- import uvicorn
152
- uvicorn.run(
153
- "main:app",
154
- port=8000,
155
- reload=True, # Enable auto-reload
156
- reload_dirs=["./"] # Watch the current directory for changes
157
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import os
7
  from base64 import b64encode
8
  from io import BytesIO
9
+ from PIL import Image, ImageEnhance
10
  import logging
11
+ import pytesseract
12
+ import time
13
 
14
  app = FastAPI()
15
 
 
17
  logging.basicConfig(level=logging.DEBUG)
18
  logger = logging.getLogger(__name__)
19
 
 
20
 
21
  # Default model
22
  DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
23
 
24
+ class TextRequest(BaseModel):
25
  model_config = ConfigDict(protected_namespaces=())
 
26
  query: str
 
27
  stream: bool = False
28
  model_name: Optional[str] = None
29
 
30
+ class ImageTextRequest(BaseModel):
31
  model_config = ConfigDict(protected_namespaces=())
 
32
  query: str
33
  stream: bool = False
34
  model_name: Optional[str] = None
 
39
  query: str = Form(...),
40
  stream: bool = Form(False),
41
  model_name: Optional[str] = Form(None),
42
+ image: UploadFile = File(...) # Make image required for i2t2t
43
  ):
44
  return cls(
45
  query=query,
 
50
  def get_client(model_name: Optional[str] = None):
51
  """Get inference client for specified model or default model"""
52
  try:
 
53
  model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
 
54
  return InferenceClient(
55
  model=model_path
56
  )
 
60
  detail=f"Error initializing model {model_path}: {str(e)}"
61
  )
62
 
63
+ def generate_text_response(query: str, model_name: Optional[str] = None):
64
+ messages = [{
65
+ "role": "user",
66
+ "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
67
+ }]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  try:
70
  client = get_client(model_name)
 
78
  except Exception as e:
79
  yield f"Error generating response: {str(e)}"
80
 
81
+ def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
82
+ messages = [
83
+ {
84
+ "role": "user",
85
+ "content": [
86
+ {"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
87
+ {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
88
+ ]
89
+ }
90
+ ]
91
+
92
+ logger.debug(f"Messages sent to API: {messages}")
93
+
94
+ try:
95
+ client = get_client(model_name)
96
+ for message in client.chat_completion(messages, max_tokens=2048, stream=True):
97
+ logger.debug(f"Received message chunk: {message}")
98
+ token = message.choices[0].delta.content
99
+ yield token
100
+ except Exception as e:
101
+ logger.error(f"Error in generate_image_text_response: {str(e)}")
102
+ yield f"Error generating response: {str(e)}"
103
+
104
+ def preprocess_image(img):
105
+ """Enhance image for better OCR results"""
106
+ # Convert to grayscale
107
+ img = img.convert('L')
108
+
109
+ # Enhance contrast
110
+ enhancer = ImageEnhance.Contrast(img)
111
+ img = enhancer.enhance(2.0)
112
+
113
+ # Enhance sharpness
114
+ enhancer = ImageEnhance.Sharpness(img)
115
+ img = enhancer.enhance(1.5)
116
+
117
+ return img
118
+
119
  @app.get("/")
120
  async def root():
121
  return {"message": "Welcome to FastAPI server!"}
122
 
123
+ @app.post("/t2t")
124
+ async def text_to_text(request: TextRequest):
125
+ try:
126
+ if request.stream:
127
+ return StreamingResponse(
128
+ generate_text_response(request.query, request.model_name),
129
+ media_type="text/event-stream"
130
+ )
131
+ else:
132
+ response = ""
133
+ for chunk in generate_text_response(request.query, request.model_name):
134
+ response += chunk
135
+ return {"response": response}
136
+ except Exception as e:
137
+ logger.error(f"Error in /t2t endpoint: {str(e)}")
138
+ raise HTTPException(status_code=500, detail=str(e))
139
+
140
+ @app.post("/i2t2t")
141
+ async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
142
  form, image = form_data
143
  try:
144
+ # Process image
145
+ contents = await image.read()
146
+ try:
147
+ logger.debug("Attempting to open image")
148
+ img = Image.open(BytesIO(contents))
149
+ if img.mode != 'RGB':
150
+ img = img.convert('RGB')
151
 
152
+ buffer = BytesIO()
153
+ img.save(buffer, format="PNG")
154
+ image_data = b64encode(buffer.getvalue()).decode('utf-8')
155
+ logger.debug("Image processed and encoded to base64")
156
+ except Exception as img_error:
157
+ logger.error(f"Error processing image: {str(img_error)}")
158
+ raise HTTPException(
159
+ status_code=422,
160
+ detail=f"Error processing image: {str(img_error)}"
161
+ )
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  if form.stream:
164
  return StreamingResponse(
165
+ generate_image_text_response(form.query, image_data, form.model_name),
166
  media_type="text/event-stream"
167
  )
168
  else:
169
  response = ""
170
+ for chunk in generate_image_text_response(form.query, image_data, form.model_name):
171
  response += chunk
172
  return {"response": response}
173
  except Exception as e:
174
+ logger.error(f"Error in /i2t2t endpoint: {str(e)}")
175
  raise HTTPException(status_code=500, detail=str(e))
176
 
177
+ @app.post("/tes")
178
+ async def ocr_endpoint(image: UploadFile = File(...)):
179
+ try:
180
+ # Read and process the image
181
+ contents = await image.read()
182
+ img = Image.open(BytesIO(contents))
183
+
184
+ # Preprocess the image
185
+ img = preprocess_image(img)
186
+
187
+ # Perform OCR with timeout and retries
188
+ max_retries = 3
189
+ text = ""
190
+
191
+ for attempt in range(max_retries):
192
+ try:
193
+ text = pytesseract.image_to_string(
194
+ img,
195
+ timeout=30, # 30 second timeout
196
+ config='--oem 3 --psm 6'
197
+ )
198
+ break
199
+ except Exception as e:
200
+ if attempt == max_retries - 1:
201
+ raise HTTPException(
202
+ status_code=500,
203
+ detail=f"Error extracting text: {str(e)}"
204
+ )
205
+ time.sleep(1) # Wait before retry
206
+
207
+ return {"text": text}
208
+
209
+ except Exception as e:
210
+ raise HTTPException(
211
+ status_code=500,
212
+ detail=f"Error processing image: {str(e)}"
213
+ )
214
+