from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from typing import Dict import os import shutil import logging import torch from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification from dotenv import load_dotenv import os from utils import doc_processing # Load .env file load_dotenv() # Access variables dummy_key = os.getenv("dummy_key") HUGGINGFACE_AUTH_TOKEN = dummy_key # Hugging Face model and token aadhar_model = "AuditEdge/doc_ocr_a" # Replace with your fine-tuned model if applicable device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Load the processor (tokenizer + image processor) processor_aadhar = LayoutLMv3Processor.from_pretrained( aadhar_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) aadhar_model = LayoutLMv3ForTokenClassification.from_pretrained( aadhar_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) aadhar_model = aadhar_model.to(device) # pan model pan_model = "AuditEdge/doc_ocr_p" # Replace with your fine-tuned model if applicable device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Load the processor (tokenizer + image processor) processor_pan = LayoutLMv3Processor.from_pretrained( pan_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) pan_model = LayoutLMv3ForTokenClassification.from_pretrained( pan_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) pan_model = pan_model.to(device) # # gst model gst_model = "AuditEdge/doc_ocr_new_g" # Replace with your fine-tuned model if applicable device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Load the processor (tokenizer + image processor) processor_gst = LayoutLMv3Processor.from_pretrained( gst_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) gst_model = LayoutLMv3ForTokenClassification.from_pretrained( gst_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) gst_model = gst_model.to(device) #cheque model cheque_model = "AuditEdge/doc_ocr_new_c" # Replace with your fine-tuned model if applicable device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Load the processor (tokenizer + image processor) processor_cheque = LayoutLMv3Processor.from_pretrained( cheque_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) cheque_model = LayoutLMv3ForTokenClassification.from_pretrained( cheque_model, use_auth_token=HUGGINGFACE_AUTH_TOKEN ) cheque_model = cheque_model.to(device) # Verify model and processor are loaded print("Model and processor loaded successfully!") print(f"Model is on device: {next(aadhar_model.parameters()).device}") # Import inference modules from layoutlmv3FineTuning.Layoutlm_inference.ocr import prepare_batch_for_inference from layoutlmv3FineTuning.Layoutlm_inference.inference_handler import handle # Create FastAPI instance app = FastAPI(debug=True) # Enable CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Configure directories UPLOAD_FOLDER = './uploads/' processing_folder = "./processed_images" os.makedirs(UPLOAD_FOLDER, exist_ok=True) # Ensure the main upload folder exists os.makedirs(processing_folder,exist_ok=True) UPLOAD_DIRS = { "aadhar_file": "uploads/aadhar/", "pan_file": "uploads/pan/", "cheque_file": "uploads/cheque/", "gst_file": "uploads/gst/", } process_dirs = { "aadhar_file": "processed_images/aadhar/", "pan_file": "processed_images/pan/", "cheque_file": "processed_images/cheque/", "gst_file": "processed_images/gst/", } # Ensure individual directories exist for dir_path in UPLOAD_DIRS.values(): os.makedirs(dir_path, exist_ok=True) for dir_path in process_dirs.values(): os.makedirs(dir_path, exist_ok=True) # Logger configuration logging.basicConfig(level=logging.INFO) # Perform Inference def perform_inference(file_paths: Dict[str, str]): # Dictionary to map document types to their respective model directories model_dirs = { "aadhar_file": aadhar_model, "pan_file": pan_model, "cheque_file": cheque_model, "gst_file": gst_model, } # Dictionary to store results for each document type inference_results = {} # Loop through the file paths and perform inference for doc_type, file_path in file_paths.items(): if doc_type in model_dirs: print(f"Processing {doc_type} using model at {model_dirs[doc_type]}") # Prepare batch for inference images_path = [file_path] inference_batch = prepare_batch_for_inference(images_path) # Prepare context for the specific document type # context = {"model_dir": model_dirs[doc_type]} # context = aadhar_model if doc_type == "aadhar_file": context = aadhar_model processor = processor_aadhar name = "aadhar" attachemnt_num = 3 if doc_type == "pan_file": context = pan_model processor = processor_pan name = "pan" attachemnt_num = 2 if doc_type == "gst_file": context = gst_model processor = processor_gst name = "gst" attachemnt_num = 4 if doc_type == "cheque_file": context = cheque_model processor = processor_cheque name = "cheque" attachemnt_num = 8 # Perform inference (replace `handle` with your actual function) result = handle(inference_batch, context,processor,name) # Store the result inference_results["attachment_{}".format(attachemnt_num)] = result else: print(f"Model directory not found for {doc_type}. Skipping.") return inference_results # Routes @app.get("/") def greet_json(): return {"Hello": "World!"} @app.post("/api/aadhar_ocr") async def aadhar_ocr( aadhar_file: UploadFile = File(None), pan_file: UploadFile = File(None), cheque_file: UploadFile = File(None), gst_file: UploadFile = File(None), ): try: # Handle file uploads file_paths = {} for file_type, folder in UPLOAD_DIRS.items(): file = locals()[file_type] # Dynamically access the file arguments if file: # Save the file in the respective directory file_path = os.path.join(folder, file.filename) with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) file_paths[file_type] = file_path # Log received files logging.info(f"Received files: {list(file_paths.keys())}") print("file_paths",file_paths) files = {} for key, value in file_paths.items(): name = value.split("/")[-1].split(".")[0] id_type = key.split("_")[0] doc_type = value.split("/")[-1].split(".")[1] f_path = value preprocessing = doc_processing(name,id_type,doc_type,f_path) response = preprocessing.process() files[key] = response["output_p"] print("response",response) # Perform inference result = perform_inference(files) return {"status": "success", "result": result} except Exception as e: logging.error(f"Error processing files: {e}") # raise HTTPException(status_code=500, detail="Internal Server Error") return {"status":400}