Charan5775 committed on
Commit
e7d2278
·
verified ·
1 Parent(s): 72fa916

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +516 -0
app.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends
2
+ from typing import Optional
3
+ from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi.templating import Jinja2Templates
6
+ from huggingface_hub import InferenceClient
7
+ from pydantic import BaseModel, ConfigDict
8
+ import os
9
+ from base64 import b64encode
10
+ from io import BytesIO
11
+ from PIL import Image, ImageEnhance
12
+ import logging
13
+ import pytesseract
14
+ import time
15
+
16
# Tesseract binary selection.
# The original code unconditionally pointed pytesseract at a Windows-only
# install path, which breaks OCR on every other host. Now an explicit path is
# applied only when one is actually available:
#   1. the TESSERACT_CMD environment variable, if set;
#   2. otherwise the original Windows path, if that file exists.
# When neither applies, pytesseract keeps its own default command.
_WINDOWS_TESSERACT_CMD = r"F:\Python-files\tesseract\tesseract.exe"
_tesseract_cmd = os.environ.get("TESSERACT_CMD")
if not _tesseract_cmd and os.path.isfile(_WINDOWS_TESSERACT_CMD):
    _tesseract_cmd = _WINDOWS_TESSERACT_CMD
if _tesseract_cmd:
    pytesseract.pytesseract.tesseract_cmd = _tesseract_cmd

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


# Model used whenever a request does not name one
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

# Initialize Jinja2 templates
templates = Jinja2Templates(directory="templates")
31
+
32
class TextRequest(BaseModel):
    """Request body for a text-only query."""

    # An empty protected_namespaces tuple lets the ``model_name`` field use the
    # "model_" prefix that pydantic otherwise reserves for itself.
    model_config = ConfigDict(protected_namespaces=())

    # The user's question.
    query: str
    # When true the caller wants the answer streamed rather than returned whole.
    stream: bool = False
    # Optional model override; None means "use the default model".
    model_name: Optional[str] = None
37
+
38
class ImageTextRequest(BaseModel):
    """Text fields of an image-plus-text query submitted as a form."""

    # An empty protected_namespaces tuple lets the ``model_name`` field use the
    # "model_" prefix that pydantic otherwise reserves for itself.
    model_config = ConfigDict(protected_namespaces=())

    # The user's question about the image.
    query: str
    # When true the caller wants the answer streamed rather than returned whole.
    stream: bool = False
    # Optional model override; None means "use the default model".
    model_name: Optional[str] = None

    @classmethod
    def as_form(
        cls,
        query: str = Form(...),
        stream: bool = Form(False),
        model_name: Optional[str] = Form(None),
        image: UploadFile = File(...)  # Make image required for i2t2t
    ):
        """Build the request from multipart form fields.

        Intended to be used as a FastAPI dependency. Returns a 2-tuple of
        ``(ImageTextRequest, UploadFile)``: the validated text fields and the
        uploaded image, which is passed through untouched.
        """
        parsed_fields = cls(query=query, stream=stream, model_name=model_name)
        return parsed_fields, image
57
+
58
def get_client(model_name: Optional[str] = None):
    """Get an inference client for the specified model or the default model.

    Args:
        model_name: Model identifier. ``None``, an empty string, or a
            whitespace-only string selects ``DEFAULT_MODEL``.

    Returns:
        An ``InferenceClient`` constructed for the selected model.

    Raises:
        HTTPException: status 400 if the client cannot be constructed.
    """
    # Resolve the model *before* the try block: the error message below reads
    # ``model_path``, so it must be bound no matter where the failure happens.
    model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
    try:
        return InferenceClient(model=model_path)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {str(e)}"
        ) from e
70
+
71
def generate_text_response(query: str, model_name: Optional[str] = None):
    """Stream the model's answer to a text query, one token string at a time.

    Args:
        query: The user's question.
        model_name: Optional model override, resolved by ``get_client``.

    Yields:
        Non-empty ``str`` fragments of the answer. If anything fails, a single
        ``"Error generating response: ..."`` string is yielded instead of
        raising, so callers always receive text.
    """
    messages = [{
        "role": "user",
        "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
    }]

    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True
        ):
            # Some stream chunks carry no choices or a delta whose content is
            # None. Yielding None would break both string concatenation and
            # StreamingResponse encoding downstream, so skip them.
            if not message.choices:
                continue
            token = message.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        yield f"Error generating response: {str(e)}"
88
+
89
def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
    """Stream the model's answer to a question about an image.

    Args:
        query: The user's question.
        image_data: The image as a base64-encoded string; it is embedded in a
            ``data:`` URL inside the chat message.
        model_name: Optional model override, resolved by ``get_client``.

    Yields:
        Non-empty ``str`` fragments of the answer. If anything fails, a single
        ``"Error generating response: ..."`` string is yielded instead of
        raising, so callers always receive text.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
                {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
            ]
        }
    ]

    # Log a summary only: dumping ``messages`` would write the entire base64
    # image into the log on every request.
    logger.debug(
        "Sending image query to API (model=%r, query=%r, image payload=%d base64 chars)",
        model_name, query, len(image_data),
    )

    try:
        client = get_client(model_name)
        for message in client.chat_completion(messages, max_tokens=2048, stream=True):
            logger.debug("Received message chunk: %s", message)
            # Some stream chunks carry no choices or a delta whose content is
            # None. Yielding None would break both string concatenation and
            # StreamingResponse encoding downstream, so skip them.
            if not message.choices:
                continue
            token = message.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        logger.error(f"Error in generate_image_text_response: {str(e)}")
        yield f"Error generating response: {str(e)}"
111
+
112
def preprocess_image(img):
    """Return a grayscale, contrast-boosted, sharpened copy of ``img`` for OCR.

    The steps run in a fixed order: convert to mode ``'L'``, scale contrast by
    a factor of 2.0, then scale sharpness by a factor of 1.5.
    """
    grayscale = img.convert('L')
    high_contrast = ImageEnhance.Contrast(grayscale).enhance(2.0)
    return ImageEnhance.Sharpness(high_contrast).enhance(1.5)
126
+
127
@app.get("/")
async def root():
    """Return a fixed welcome message for the index route."""
    welcome = {"message": "Welcome to FastAPI server!"}
    return welcome
130
+
131
@app.post("/t2t")
async def text_to_text(request: TextRequest):
    """Answer a text query, either streamed or as one JSON document.

    Args:
        request: Parsed JSON body (``query``, ``stream``, ``model_name``).

    Returns:
        A ``StreamingResponse`` emitting raw token strings when
        ``request.stream`` is true; otherwise ``{"response": <full text>}``.

    Raises:
        HTTPException: status 500 if building the response fails.
    """
    # Function-scope import keeps this block self-contained (stdlib only).
    import asyncio

    try:
        token_stream = generate_text_response(request.query, request.model_name)

        if request.stream:
            return StreamingResponse(token_stream, media_type="text/event-stream")

        def _collect() -> str:
            # Drop None/empty chunks so a content-less stream chunk cannot
            # break the concatenation.
            return "".join(chunk for chunk in token_stream if chunk)

        # The generator performs blocking network I/O. Draining it directly
        # inside this coroutine would stall the event loop for every other
        # request, so do it on a worker thread.
        response = await asyncio.to_thread(_collect)
        return {"response": response}
    except Exception as e:
        logger.error(f"Error in /t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
147
+
148
@app.post("/i2t2t")
async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
    """Answer a question about an uploaded image, streamed or as one JSON document.

    The upload is decoded, converted to RGB if needed, re-encoded as PNG and
    base64-encoded before being handed to the model.

    Args:
        form_data: ``(ImageTextRequest, UploadFile)`` pair produced by
            ``ImageTextRequest.as_form``.

    Returns:
        A ``StreamingResponse`` emitting raw token strings when ``stream`` is
        true; otherwise ``{"response": <full text>}``.

    Raises:
        HTTPException: status 422 if the upload cannot be processed as an
            image; status 500 for any other failure.
    """
    # Function-scope import keeps this block self-contained (stdlib only).
    import asyncio

    form, image = form_data
    try:
        # Process image
        contents = await image.read()
        try:
            logger.debug("Attempting to open image")
            img = Image.open(BytesIO(contents))
            if img.mode != 'RGB':
                img = img.convert('RGB')

            buffer = BytesIO()
            img.save(buffer, format="PNG")
            image_data = b64encode(buffer.getvalue()).decode('utf-8')
            logger.debug("Image processed and encoded to base64")
        except Exception as img_error:
            logger.error(f"Error processing image: {str(img_error)}")
            raise HTTPException(
                status_code=422,
                detail=f"Error processing image: {str(img_error)}"
            ) from img_error

        token_stream = generate_image_text_response(form.query, image_data, form.model_name)

        if form.stream:
            return StreamingResponse(token_stream, media_type="text/event-stream")

        def _collect() -> str:
            # Drop None/empty chunks so a content-less stream chunk cannot
            # break the concatenation.
            return "".join(chunk for chunk in token_stream if chunk)

        # The generator performs blocking network I/O; drain it on a worker
        # thread so the event loop stays responsive.
        response = await asyncio.to_thread(_collect)
        return {"response": response}
    except HTTPException:
        # Let the deliberate 422 above through unchanged. Without this clause
        # the generic handler below would turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error in /i2t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
184
+
185
@app.post("/tes")
async def ocr_endpoint(image: UploadFile = File(...)):
    """Extract text from an uploaded image with Tesseract.

    The image is passed through ``preprocess_image`` and then OCR'd, with up
    to three attempts and a one-second pause between attempts.

    Args:
        image: The uploaded image file.

    Returns:
        ``{"text": <extracted text>}``.

    Raises:
        HTTPException: status 500 with ``"Error extracting text: ..."`` when
            every OCR attempt fails, or ``"Error processing image: ..."`` for
            any other failure (e.g. the upload cannot be opened as an image).
    """
    # Function-scope import keeps this block self-contained (stdlib only).
    import asyncio

    try:
        # Read and process the image
        contents = await image.read()
        img = preprocess_image(Image.open(BytesIO(contents)))

        # Perform OCR with timeout and retries
        max_retries = 3
        text = ""

        for attempt in range(max_retries):
            try:
                # OCR can block for up to the 30 second timeout; run it on a
                # worker thread so the event loop keeps serving other requests.
                text = await asyncio.to_thread(
                    pytesseract.image_to_string,
                    img,
                    timeout=30,  # 30 second timeout
                    config='--oem 3 --psm 6',
                )
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Error extracting text: {str(e)}"
                    ) from e
                # Non-blocking pause before the next attempt.
                await asyncio.sleep(1)

        return {"text": text}

    except HTTPException:
        # Already a well-formed API error. Re-wrapping it below would produce
        # a doubled "Error processing image: 500: Error extracting text" detail.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        ) from e
222
+
223
@app.get("/docs/guide", response_class=HTMLResponse)
async def api_guide():
    """Serve a static HTML guide with usage examples for the three POST endpoints.

    The page shows one card per endpoint (``/t2t``, ``/i2t2t``, ``/tes``), each
    with cURL / Python / JavaScript / Node.js tabs and a copy button. All
    content is a fixed string; nothing from the request is interpolated.

    Returns:
        An ``HTMLResponse`` containing the guide page.
    """
    # Raw string on purpose. In a normal triple-quoted string a backslash
    # followed by a newline is a line continuation and disappears, which
    # stripped the "\" continuations from the curl examples in the HTML.
    # In the JavaScript template literals the "\\" must reach the browser
    # as two backslashes so that JS renders a single literal backslash.
    #
    # NOTE(review): the Tailwind stylesheet URL in the source contained
    # e-mail-obfuscation residue in place of "package@version". It is restored
    # here as tailwindcss@2.2.19 -- confirm the intended version.
    #
    # Code samples inside <pre> and inside the template literals start at
    # column 0 so that copied snippets carry no stray leading indentation.
    html_content = r'''
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>API Documentation</title>
        <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/themes/prism-tomorrow.min.css">
        <style>
            .copy-button {
                position: absolute;
                top: 8px;
                right: 8px;
                padding: 4px 8px;
                background: #2d3748;
                border: 1px solid #4a5568;
                border-radius: 4px;
                color: #cbd5e0;
                font-size: 12px;
                cursor: pointer;
                transition: all 0.2s;
            }
            .copy-button:hover {
                background: #4a5568;
            }
            .code-block {
                position: relative;
                margin: 1rem 0;
            }
            .endpoint-card {
                background: #1a202c;
                border-radius: 8px;
                margin-bottom: 2rem;
                padding: 1.5rem;
            }
            .language-tab {
                cursor: pointer;
                padding: 0.5rem 1rem;
                border-radius: 4px 4px 0 0;
            }
            .language-tab.active {
                background: #2d3748;
                color: #fff;
            }
        </style>
    </head>
    <body class="bg-gray-900 text-gray-100 min-h-screen p-8">
        <div class="max-w-6xl mx-auto">
            <h1 class="text-4xl font-bold mb-8">API Documentation</h1>

            <!-- T2T Endpoint -->
            <div class="endpoint-card">
                <h2 class="text-2xl font-semibold mb-4">Text-to-Text Endpoint</h2>
                <p class="mb-4 text-gray-400">Endpoint for general text queries</p>
                <p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /t2t</span></p>

                <div class="code-block">
                    <div class="flex mb-2">
                        <div class="language-tab active" data-lang="curl">cURL</div>
                        <div class="language-tab" data-lang="python">Python</div>
                        <div class="language-tab" data-lang="javascript">JavaScript</div>
                        <div class="language-tab" data-lang="node">Node.js</div>
                    </div>
                    <pre><code class="language-bash">curl -X POST "http://localhost:8000/t2t" \
     -H "Content-Type: application/json" \
     -d '{"query": "What is FastAPI?", "stream": false}'</code></pre>
                    <button class="copy-button">Copy</button>
                </div>
            </div>

            <!-- I2T2T Endpoint -->
            <div class="endpoint-card">
                <h2 class="text-2xl font-semibold mb-4">Image and Text to Text Endpoint</h2>
                <p class="mb-4 text-gray-400">Endpoint for queries about images</p>
                <p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /i2t2t</span></p>

                <div class="code-block">
                    <div class="flex mb-2">
                        <div class="language-tab active" data-lang="curl">cURL</div>
                        <div class="language-tab" data-lang="python">Python</div>
                        <div class="language-tab" data-lang="javascript">JavaScript</div>
                        <div class="language-tab" data-lang="node">Node.js</div>
                    </div>
                    <pre><code class="language-bash">curl -X POST "http://localhost:8000/i2t2t" \
     -F "query=Describe this image" \
     -F "stream=false" \
     -F "image=@/path/to/your/image.jpg"</code></pre>
                    <button class="copy-button">Copy</button>
                </div>
            </div>

            <!-- TES Endpoint -->
            <div class="endpoint-card">
                <h2 class="text-2xl font-semibold mb-4">OCR Endpoint</h2>
                <p class="mb-4 text-gray-400">Extract text from images using OCR</p>
                <p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /tes</span></p>

                <div class="code-block">
                    <div class="flex mb-2">
                        <div class="language-tab active" data-lang="curl">cURL</div>
                        <div class="language-tab" data-lang="python">Python</div>
                        <div class="language-tab" data-lang="javascript">JavaScript</div>
                        <div class="language-tab" data-lang="node">Node.js</div>
                    </div>
                    <pre><code class="language-bash">curl -X POST "http://localhost:8000/tes" \
     -F "image=@/path/to/your/image.jpg"</code></pre>
                    <button class="copy-button">Copy</button>
                </div>
            </div>
        </div>

        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/prism.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-python.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-javascript.min.js"></script>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-bash.min.js"></script>
        <script>
            const codeExamples = {
                't2t': {
                    'curl': `curl -X POST "http://localhost:8000/t2t" \\
     -H "Content-Type: application/json" \\
     -d '{"query": "What is FastAPI?", "stream": false}'`,
                    'python': `import requests

url = "http://localhost:8000/t2t"
payload = {
    "query": "What is FastAPI?",
    "stream": False
}
response = requests.post(url, json=payload)
print(response.json())`,
                    'javascript': `// Using fetch
fetch("http://localhost:8000/t2t", {
    method: "POST",
    headers: {
        "Content-Type": "application/json",
    },
    body: JSON.stringify({
        query: "What is FastAPI?",
        stream: false
    })
})
.then(response => response.json())
.then(data => console.log(data));`,
                    'node': `const axios = require('axios');

async function makeRequest() {
    try {
        const response = await axios.post('http://localhost:8000/t2t', {
            query: "What is FastAPI?",
            stream: false
        });
        console.log(response.data);
    } catch (error) {
        console.error(error);
    }
}

makeRequest();`
                },
                'i2t2t': {
                    'curl': `curl -X POST "http://localhost:8000/i2t2t" \\
     -F "query=Describe this image" \\
     -F "stream=false" \\
     -F "image=@/path/to/your/image.jpg"`,
                    'python': `import requests

url = "http://localhost:8000/i2t2t"
files = {
    'image': ('image.jpg', open('path/to/image.jpg', 'rb')),
}
data = {
    'query': 'Describe this image',
    'stream': 'false'
}
response = requests.post(url, files=files, data=data)
print(response.json())`,
                    'javascript': `const formData = new FormData();
formData.append('image', imageFile);
formData.append('query', 'Describe this image');
formData.append('stream', 'false');

fetch("http://localhost:8000/i2t2t", {
    method: "POST",
    body: formData
})
.then(response => response.json())
.then(data => console.log(data));`,
                    'node': `const axios = require('axios');
const FormData = require('form-data');
const fs = require('fs');

async function makeRequest() {
    try {
        const formData = new FormData();
        formData.append('image', fs.createReadStream('path/to/image.jpg'));
        formData.append('query', 'Describe this image');
        formData.append('stream', 'false');

        const response = await axios.post('http://localhost:8000/i2t2t', formData, {
            headers: formData.getHeaders()
        });
        console.log(response.data);
    } catch (error) {
        console.error(error);
    }
}

makeRequest();`
                },
                'tes': {
                    'curl': `curl -X POST "http://localhost:8000/tes" \\
     -F "image=@/path/to/your/image.jpg"`,
                    'python': `import requests

url = "http://localhost:8000/tes"
files = {
    'image': ('image.jpg', open('path/to/image.jpg', 'rb'))
}
response = requests.post(url, files=files)
print(response.json())`,
                    'javascript': `const formData = new FormData();
formData.append('image', imageFile);

fetch("http://localhost:8000/tes", {
    method: "POST",
    body: formData
})
.then(response => response.json())
.then(data => console.log(data));`,
                    'node': `const axios = require('axios');
const FormData = require('form-data');
const fs = require('fs');

async function makeRequest() {
    try {
        const formData = new FormData();
        formData.append('image', fs.createReadStream('path/to/image.jpg'));

        const response = await axios.post('http://localhost:8000/tes', formData, {
            headers: formData.getHeaders()
        });
        console.log(response.data);
    } catch (error) {
        console.error(error);
    }
}

makeRequest();`
                }
            };

            // Prism has no "curl" or "node" grammar; map the tab names onto
            // the grammars that are actually loaded above.
            const prismLanguages = { curl: 'bash', node: 'javascript' };

            // Handle language tab switching
            document.querySelectorAll('.language-tab').forEach(tab => {
                tab.addEventListener('click', () => {
                    const lang = tab.dataset.lang;
                    const codeBlock = tab.closest('.endpoint-card');
                    const heading = codeBlock.querySelector('h2').textContent.toLowerCase();
                    const endpoint = heading.includes('ocr') ? 'tes' :
                        heading.includes('image') ? 'i2t2t' : 't2t';

                    // Update active tab
                    codeBlock.querySelectorAll('.language-tab').forEach(t => t.classList.remove('active'));
                    tab.classList.add('active');

                    // Update code content
                    const code = codeBlock.querySelector('code');
                    code.textContent = codeExamples[endpoint][lang];
                    code.className = `language-${prismLanguages[lang] || lang}`;
                    Prism.highlightElement(code);
                });
            });

            // Handle copy buttons
            document.querySelectorAll('.copy-button').forEach(button => {
                button.addEventListener('click', () => {
                    const code = button.previousElementSibling.textContent;
                    navigator.clipboard.writeText(code);

                    // Show feedback
                    const originalText = button.textContent;
                    button.textContent = 'Copied!';
                    setTimeout(() => {
                        button.textContent = originalText;
                    }, 2000);
                });
            });
        </script>
    </body>
    </html>
    '''
    return HTMLResponse(content=html_content)
516
+