|
import base64 |
|
import json |
|
import os |
|
from google.oauth2 import service_account |
|
import vertexai |
|
from remittance_pdf_processing_utils import remittance_logger |
|
from vertexai.generative_models import GenerativeModel, Part |
|
import vertexai.preview.generative_models as generative_models |
|
from remittance_pdf_processing_types import InvoiceNumbers,PaymentAmount |
|
from remittance_pdf_processing_utils import remove_duplicate_lists |
|
|
|
|
|
def initialize_vertexai():
    """
    Initialize the Vertex AI SDK from a base64-encoded service account key.

    The key is read from the ``VERTEX_AI_SERVICE_ACCOUNT_JSON`` environment
    variable, base64-decoded, parsed as JSON and turned into credentials
    scoped to ``cloud-platform`` before ``vertexai.init`` is called.

    Raises:
        ValueError: If the environment variable is unset/empty, or if the
            decoded content is not valid JSON.
        Exception: For any other failure while decoding the key, building
            the credentials or initializing the SDK. The original error is
            attached as ``__cause__``.
    """
    encoded_sa_json = os.environ.get('VERTEX_AI_SERVICE_ACCOUNT_JSON')
    if not encoded_sa_json:
        raise ValueError("VERTEX_AI_SERVICE_ACCOUNT_JSON environment variable is not set")

    try:
        # The environment variable carries the JSON key file as base64 text.
        sa_json_str = base64.b64decode(encoded_sa_json).decode('utf-8')
        sa_info = json.loads(sa_json_str)

        credentials = service_account.Credentials.from_service_account_info(
            sa_info,
            scopes=['https://www.googleapis.com/auth/cloud-platform'],
        )

        vertexai.init(project="saltech-ai-sandbox", location="us-central1", credentials=credentials)

        print("Vertex AI initialized successfully.")
    except json.JSONDecodeError as err:
        # Chain explicitly so the underlying parse error stays visible.
        raise ValueError("Invalid JSON format in the decoded service account information") from err
    except Exception as err:
        # Base64/UTF-8 decoding errors and SDK failures all end up here.
        raise Exception(f"Error initializing Vertex AI: {str(err)}") from err
|
|
|
|
|
|
|
# Configure the Vertex AI SDK as soon as this module is imported.
initialize_vertexai()
|
|
|
def extract_invoice_numbers_with_vertex_ai(base64_image: str, multi_hop: bool = False) -> list[InvoiceNumbers]:
    """
    Route image-based invoice number extraction to the single-hop or multi-hop implementation.

    Args:
        base64_image (str): The base64-encoded image string.
        multi_hop (bool): When true, use the multi-hop implementation;
            otherwise use the single-hop one.

    Returns:
        list[InvoiceNumbers]: Whatever the selected implementation returns,
            i.e. a list of invoice number lists.
    """
    # Only the selected branch of the conditional expression is evaluated.
    extractor = (
        extract_invoice_numbers_with_vertex_ai_multi_hop
        if multi_hop
        else extract_invoice_numbers_with_vertex_ai_single_hop
    )
    return extractor(base64_image)
|
|
|
def extract_invoice_numbers_with_vertex_ai_single_hop(base64_image: str) -> list[InvoiceNumbers]:
    """
    Extract invoice numbers from one base64-encoded image with a single Gemini Flash call.

    Args:
        base64_image (str): The base64-encoded image string; it is decoded
            and sent to the model as ``image/png``.

    Returns:
        list[InvoiceNumbers]: A one-element list wrapping the invoice
            numbers parsed from the model's reply.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    image_part = Part.from_data(
        mime_type="image/png",
        data=base64.b64decode(base64_image),
    )

    text_prompt = (
        "Given the remittance letter image, extract all invoice numbers.\n"
        "Respond with a comma-separated list of invoice numbers only.\n"
        "If no invoice numbers are found, respond with 'No invoice numbers found'."
    )

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    responses = model.generate_content(
        [image_part, text_prompt],
        generation_config=generation_config,
        safety_settings={},
        stream=True,
    )

    # Concatenate the streamed chunks into the complete reply.
    full_response = "".join(chunk.text for chunk in responses)

    remittance_logger.debug(f"Extracted invoice numbers (raw model response): {full_response}")

    extracted_numbers = parse_gemini_response(full_response)
    return [extracted_numbers]
|
|
|
def extract_column_headers(base64_image: str) -> list[str]:
    """
    Ask Gemini Flash which column headers in the image could contain invoice numbers.

    Args:
        base64_image (str): The base64-encoded image string; it is decoded
            and sent to the model as ``image/png``.

    Returns:
        list[str]: The comma-separated header names from the model's reply,
            stripped of surrounding whitespace. Blank entries (empty reply,
            trailing or doubled commas) are dropped.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    image_part = Part.from_data(
        mime_type="image/png",
        data=base64.b64decode(base64_image),
    )

    text_prompt = (
        "Given the remittance letter image, extract all column header names that could contain invoice numbers.\n"
        "Respond with a comma-separated list only."
    )

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    responses = model.generate_content(
        [image_part, text_prompt],
        generation_config=generation_config,
        safety_settings={},
        stream=True,
    )

    # Concatenate the streamed chunks into the complete reply.
    full_response = "".join(chunk.text for chunk in responses)

    remittance_logger.debug(f"Extracted column headers (raw model response): {full_response}")

    # Skip blank tokens so callers never receive an empty column name.
    stripped_headers = (header.strip() for header in full_response.split(','))
    return [header for header in stripped_headers if header]
|
|
|
def extract_invoice_numbers_for_column(base64_image: str, column_name: str) -> InvoiceNumbers:
    """
    Extract the invoice numbers listed under one named column of the image.

    Args:
        base64_image (str): The base64-encoded image string; it is decoded
            and sent to the model as ``image/png``.
        column_name (str): The name of the column to extract invoice numbers
            from; it is quoted verbatim in the prompt.

    Returns:
        InvoiceNumbers: The non-empty, whitespace-stripped entries of the
            model's comma-separated reply.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    image_part = Part.from_data(
        mime_type="image/png",
        data=base64.b64decode(base64_image),
    )

    text_prompt = (
        f'Given the remittance letter image, extract all invoice numbers from the column "{column_name}".\n'
        "Respond with a comma-separated list only."
    )

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    responses = model.generate_content(
        [image_part, text_prompt],
        generation_config=generation_config,
        safety_settings={},
        stream=True,
    )

    # Concatenate the streamed chunks into the complete reply.
    full_response = "".join(chunk.text for chunk in responses)

    remittance_logger.debug(f"Extracted invoice numbers for column '{column_name}' (raw model response): {full_response}")

    return [number.strip() for number in full_response.split(',') if number.strip()]
|
|
|
def extract_invoice_numbers_with_vertex_ai_multi_hop(base64_image: str) -> list[InvoiceNumbers]:
    """
    Extract invoice numbers from an image in two hops: find candidate columns, then read each one.

    Args:
        base64_image (str): The base64-encoded image string.

    Returns:
        list[InvoiceNumbers]: One invoice number list per queried column
            that produced a non-empty result, after duplicate lists have
            been removed by ``remove_duplicate_lists``.
    """
    # Hop 1: ask which columns may hold invoice numbers.
    column_headers = extract_column_headers(base64_image)
    remittance_logger.debug(f"Extracted column headers: {column_headers}")

    # Hop 2: query at most the first three candidate columns.
    per_column_results = []
    for column_name in column_headers[:3]:
        invoice_numbers = extract_invoice_numbers_for_column(base64_image, column_name)
        remittance_logger.debug(f"Extracted invoice numbers for column '{column_name}': {invoice_numbers}")
        if not invoice_numbers:
            continue
        per_column_results.append(invoice_numbers)

    return remove_duplicate_lists(per_column_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_gemini_response(response: str) -> list[str]:
    """
    Parse a comma-separated Gemini Flash reply into a list of invoice numbers.

    Args:
        response (str): The response string from Gemini Flash model.

    Returns:
        list[str]: The whitespace-stripped invoice numbers. The list is
            empty when the reply is blank or starts (case-insensitively)
            with 'no invoice numbers found'. Blank entries caused by
            trailing or doubled commas are dropped.
    """
    cleaned = response.strip()
    if not cleaned or cleaned.lower().startswith('no invoice numbers found'):
        return []

    # Drop empty tokens so "A1, B2," does not yield an empty invoice number.
    stripped_tokens = (token.strip() for token in cleaned.split(','))
    return [token for token in stripped_tokens if token]
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_invoice_numbers_from_text_with_vertex_ai(text: str, multi_hop: bool = False) -> list[InvoiceNumbers]:
    """
    Route text-based invoice number extraction to the single-hop or multi-hop implementation.

    Args:
        text (str): The text of the remittance letter.
        multi_hop (bool): When true, use the multi-hop implementation;
            otherwise use the single-hop one.

    Returns:
        list[InvoiceNumbers]: Whatever the selected implementation returns,
            i.e. a list of invoice number lists.
    """
    # Only the selected branch of the conditional expression is evaluated.
    extractor = (
        extract_invoice_numbers_from_text_with_vertex_ai_multi_hop
        if multi_hop
        else extract_invoice_numbers_from_text_with_vertex_ai_single_hop
    )
    return extractor(text)
|
|
|
def extract_invoice_numbers_from_text_with_vertex_ai_single_hop(text: str) -> list[InvoiceNumbers]:
    """
    Extract invoice numbers from remittance letter text with a single Gemini Flash call.

    Args:
        text (str): The text of the remittance letter; it is embedded
            verbatim at the end of the prompt.

    Returns:
        list[InvoiceNumbers]: A one-element list wrapping the invoice
            numbers parsed from the model's reply.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    prompt = (
        "Given the following remittance letter text, extract all invoice numbers.\n"
        "Respond with a comma-separated list of invoice numbers only.\n"
        "If no invoice numbers are found, respond with 'No invoice numbers found'.\n"
        "\n"
        "Remittance letter text:\n"
        f"{text}\n"
    )

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    responses = model.generate_content(
        prompt,
        generation_config=generation_config,
        safety_settings={},
        stream=True,
    )

    # Concatenate the streamed chunks into the complete reply.
    full_response = "".join(chunk.text for chunk in responses)

    remittance_logger.debug(f"Vertex AI invoice numbers full response (single-hop): {full_response}")

    extracted_numbers = parse_gemini_response(full_response)
    return [extracted_numbers]
|
|
|
def extract_invoice_numbers_from_text_with_vertex_ai_multi_hop(text: str) -> list[InvoiceNumbers]:
    """
    Extract invoice numbers from text in two hops: find candidate columns, then read each one.

    Args:
        text (str): The text of the remittance letter.

    Returns:
        list[InvoiceNumbers]: One invoice number list per queried column
            that produced a non-empty result, after duplicate lists have
            been removed by ``remove_duplicate_lists``.
    """
    # Hop 1: ask which columns or sections may hold invoice numbers.
    column_headers = extract_column_headers_from_text(text)
    remittance_logger.debug(f"Extracted column headers: {column_headers}")

    # Hop 2: query at most the first three candidates.
    per_column_results = []
    for column_name in column_headers[:3]:
        invoice_numbers = extract_invoice_numbers_for_column_from_text(text, column_name)
        remittance_logger.debug(f"Extracted invoice numbers for column '{column_name}': {invoice_numbers}")
        if not invoice_numbers:
            continue
        per_column_results.append(invoice_numbers)

    return remove_duplicate_lists(per_column_results)
|
|
|
def extract_column_headers_from_text(text: str) -> list[str]:
    """
    Ask Gemini Flash which column headers or section titles in the text could contain invoice numbers.

    Args:
        text (str): The text of the remittance letter; it is embedded
            verbatim at the end of the prompt.

    Returns:
        list[str]: The comma-separated header names from the model's reply,
            stripped of surrounding whitespace. Blank entries (empty reply,
            trailing or doubled commas) are dropped.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    prompt = (
        "Given the following remittance letter text, extract all column header names or section titles that could contain invoice numbers.\n"
        "Respond with a comma-separated list only.\n"
        "\n"
        "Remittance letter text:\n"
        f"{text}\n"
    )

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    response = model.generate_content(
        prompt,
        generation_config=generation_config,
        safety_settings={},
    )

    remittance_logger.debug(f"Extracted column headers (raw model response): {response.text}")

    # Skip blank tokens so callers never receive an empty column name.
    stripped_headers = (header.strip() for header in response.text.split(','))
    return [header for header in stripped_headers if header]
|
|
|
def extract_invoice_numbers_for_column_from_text(text: str, column_name: str) -> InvoiceNumbers:
    """
    Extract the invoice numbers listed under one named column or section of the text.

    Args:
        text (str): The text of the remittance letter; it is embedded
            verbatim at the end of the prompt.
        column_name (str): The name of the column or section to extract
            invoice numbers from; it is quoted verbatim in the prompt.

    Returns:
        InvoiceNumbers: The invoice numbers obtained by passing the model's
            reply through ``parse_gemini_response``.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    prompt = (
        f'Given the following remittance letter text, extract all invoice numbers from the column or section "{column_name}".\n'
        "Respond with a comma-separated list only. If no invoice numbers are found, respond with 'No invoice numbers found'.\n"
        "\n"
        "Remittance letter text:\n"
        f"{text}\n"
    )

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    response = model.generate_content(
        prompt,
        generation_config=generation_config,
        safety_settings={},
    )

    remittance_logger.debug(f"Extracted invoice numbers for column '{column_name}' (raw model response): {response.text}")

    return parse_gemini_response(response.text)
|
|
|
def extract_payment_amounts_with_vertex_ai(base64_image: str) -> list[PaymentAmount]:
    """
    Extract the total payment amount from one base64-encoded image using Gemini Flash.

    Args:
        base64_image (str): The base64-encoded image string; it is decoded
            and sent to the model as ``image/png``.

    Returns:
        list[PaymentAmount]: The result of passing the model's reply
            through ``parse_gemini_payment_response``.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    image_part = Part.from_data(
        mime_type="image/png",
        data=base64.b64decode(base64_image),
    )

    text_prompt = (
        "Given the remittance letter image, extract the total payment amount.\n"
        "Respond with the payment amount only.\n"
        "If no payment amounts are found, respond with 'No payment amounts found'."
    )

    # A single amount is expected, so the output budget is kept small.
    generation_config = {
        "max_output_tokens": 256,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    responses = model.generate_content(
        [image_part, text_prompt],
        generation_config=generation_config,
        safety_settings={},
        stream=True,
    )

    # Concatenate the streamed chunks into the complete reply.
    full_response = "".join(chunk.text for chunk in responses)

    remittance_logger.debug(f"Vertex AI payment amount full response: {full_response}")

    extracted_amounts = parse_gemini_payment_response(full_response)
    return extracted_amounts
|
|
|
def extract_payment_amounts_from_text_with_vertex_ai(text: str) -> list[PaymentAmount]:
    """
    Extract the total payment amount from remittance letter text using Gemini Flash.

    Args:
        text (str): The text of the remittance letter; it is embedded
            verbatim at the end of the prompt.

    Returns:
        list[PaymentAmount]: The result of passing the model's reply
            through ``parse_gemini_payment_response``.
    """
    # NOTE(review): called again here without explicit credentials;
    # presumably relies on the module-level initialization - confirm.
    vertexai.init(project="saltech-ai-sandbox", location="us-central1")
    model = GenerativeModel("gemini-1.5-flash-001")

    prompt = (
        "Given the following remittance letter text, extract the total payment amount.\n"
        "Respond with the payment amount only.\n"
        "If no payment amounts are found, respond with 'No payment amounts found'.\n"
        "\n"
        "Remittance letter text:\n"
        f"{text}\n"
    )

    # A single amount is expected, so the output budget is kept small.
    generation_config = {
        "max_output_tokens": 256,
        "temperature": 0.1,
        "top_p": 0.95,
    }

    # Only an empty safety-settings mapping ever reached the API; the
    # per-category thresholds that used to be built first were dead code.
    response = model.generate_content(
        prompt,
        generation_config=generation_config,
        safety_settings={},
    )

    remittance_logger.debug(f"Vertex AI payment amount full response: {response.text}")

    extracted_amounts = parse_gemini_payment_response(response.text)
    return extracted_amounts
|
|
|
def parse_gemini_payment_response(response: str) -> list[PaymentAmount]:
    """
    Parse a Gemini Flash reply that should contain a single payment amount.

    Args:
        response (str): The response string from Gemini Flash model.

    Returns:
        list[PaymentAmount]: A one-element list holding the stripped reply,
            or an empty list when the reply is blank or starts
            (case-insensitively) with 'no payment amounts found'.
    """
    cleaned = response.strip()

    # Prefix match, like parse_gemini_response, so that a reply such as
    # "No payment amounts found." is not mistaken for an amount.
    if not cleaned or cleaned.lower().startswith('no payment amounts found'):
        return []

    return [cleaned]
|
|
|
def extract_payment_amounts_from_base64_images(base64_images: list[str]) -> list[PaymentAmount]:
    """
    Placeholder for extracting payment amounts from several base64-encoded images.

    Args:
        base64_images (list[str]): Base64-encoded image strings; currently unused.

    Returns:
        list[PaymentAmount]: Always an empty list; no extraction is performed.
    """
    empty_result = []
    return empty_result