import base64
import io
import json
import os
from typing import Any, Dict, List

import chromadb
import google.generativeai as palm
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)
from PIL import Image, ImageDraw, ImageFont
from pypdf import PdfReader


api_key = os.environ["PALM_API_KEY"]
palm.configure(api_key=api_key)
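
# NOTE: PALM_API_KEY must be set in the environment before this module is imported;
# os.environ[...] raises a KeyError otherwise. For example (shell, illustrative):
#   export PALM_API_KEY="your-key-here"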


def convert_image_to_bytes(image):
    """Serialize a PIL image to raw JPEG bytes."""
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return buffered.getvalue()


def resize_image(image):
    """Resize a PIL image to a 512-pixel width, preserving the aspect ratio."""
    return image.resize((512, int(image.height * 512 / image.width)))


def convert_image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string for the Gemini REST payload."""
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()
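
# Example (illustrative): preparing an uploaded image for the Gemini REST call.
# The file name is hypothetical; JPEG saving requires an RGB image.
#
#   img = resize_image(Image.open("receipt.jpg").convert("RGB"))
#   img_b64 = convert_image_to_base64(img)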


def call_palm(prompt: str) -> str:
    """Generate text with the PaLM text-bison model (temperature 0) and return the result string."""
    completion = palm.generate_text(
        model="models/text-bison-001",
        prompt=prompt,
        temperature=0,
        max_output_tokens=800,
    )
    return completion.result


def call_gemini_api(image_base64, api_key=api_key, prompt="What is this picture?"):
    """Send a prompt and a base64-encoded JPEG to the Gemini Pro Vision REST endpoint."""
    headers = {
        "Content-Type": "application/json",
    }
    data = {
        "contents": [
            {
                "parts": [
                    {"text": prompt},
                    {"inline_data": {"mime_type": "image/jpeg", "data": image_base64}},
                ]
            }
        ]
    }
    response = requests.post(
        f"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:generateContent?key={api_key}",
        headers=headers,
        json=data,
    )
    return response.json()
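
# Example (illustrative): image question-answering end to end. The file name is
# hypothetical; safely_get_text (defined below) guards against malformed responses.
#
#   img_b64 = convert_image_to_base64(resize_image(Image.open("photo.jpg").convert("RGB")))
#   raw = call_gemini_api(img_b64, prompt="Describe this picture in one sentence.")
#   text = safely_get_text(raw)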


def safely_get_text(response):
    """Return the generated text from a Gemini REST response dict, or None on failure."""
    try:
        # Expected shape: {"candidates": [{"content": {"parts": [{"text": ...}]}}]}
        return response["candidates"][0]["content"]["parts"][0]["text"]
    except Exception as e:
        print(f"An error occurred: {e}")

    return None


def post_request_and_parse_response(
    url: str, payload: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Sends a POST request to the specified URL with the given payload,
    then parses the byte response to a dictionary.

    Args:
        url (str): The URL to which the POST request is sent.
        payload (Dict[str, Any]): The payload to send in the POST request.

    Returns:
        Dict[str, Any]: The parsed dictionary from the response.
    """
    headers = {"Content-Type": "application/json"}
    response = requests.post(url, json=payload, headers=headers)

    byte_data = response.content
    decoded_string = byte_data.decode("utf-8")
    dict_data = json.loads(decoded_string)

    return dict_data


def extract_line_items(input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Extracts items with "BlockType": "LINE" from the provided JSON data.

    Args:
        input_data (Dict[str, Any]): The input JSON data as a dictionary.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries with the extracted data.
    """
    line_items: List[Dict[str, Any]] = []
    body_items = json.loads(input_data.get("body", "[]"))

    for item in body_items:
        if item.get("BlockType") == "LINE":
            line_items.append(item)

    return line_items
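
# Example (illustrative): fetching Textract-style OCR output from a backend endpoint
# and keeping only the LINE blocks. The URL and payload shown are hypothetical.
#
#   result = post_request_and_parse_response(
#       "https://example.com/textract", {"image": img_b64}
#   )
#   lines = extract_line_items(result)
#   documents = [item["Text"] for item in lines if "Text" in item]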


def rag(query: str, retrieved_documents: list, api_key: str = api_key) -> str:
    """
    Answers a query using the PaLM text model, grounded on the retrieved documents.

    Args:
        query (str): The user's query or question.
        retrieved_documents (list): A list of documents retrieved as relevant information for the query.
        api_key (str): API key for the generative language API. Defaults to the module-level 'api_key'.

    Returns:
        str: The model's response text.
    """
    # Join the retrieved documents into one context block and prepend the question.
    information = "\n\n".join(retrieved_documents)
    messages = f"Question: {query}. \n Information: {information}"

    output = call_palm(prompt=messages)
    return output
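
# Example (illustrative): retrieval with an in-memory Chroma collection feeding rag().
# The collection name, ids, and documents below are made up.
#
#   embedding_fn = SentenceTransformerEmbeddingFunction()
#   client = chromadb.Client()
#   collection = client.create_collection("docs", embedding_function=embedding_fn)
#   collection.add(ids=["1", "2"], documents=["First chunk...", "Second chunk..."])
#   hits = collection.query(query_texts=["What is the invoice total?"], n_results=2)
#   answer = rag("What is the invoice total?", retrieved_documents=hits["documents"][0])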


def displayPDF(file: str) -> None:
    """
    Displays a PDF file in a Streamlit application.

    Parameters:
    - file (str): The path to the PDF file to be displayed.
    """
    with open(file, "rb") as f:
        base64_pdf: str = base64.b64encode(f.read()).decode("utf-8")

    pdf_display: str = f'<embed src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf">'
    st.markdown(pdf_display, unsafe_allow_html=True)


def draw_boxes(image: Any, predictions: List[Dict[str, Any]]) -> Any:
    """
    Draws bounding boxes and labels onto an image based on provided predictions.

    Parameters:
    - image (Any): The image to annotate, which should support the PIL drawing interface.
    - predictions (List[Dict[str, Any]]): A list of predictions where each prediction is a dictionary
      containing 'label', 'score', and 'box' keys. The 'box' is another dictionary with 'xmin',
      'ymin', 'xmax', and 'ymax' as keys representing coordinates for the bounding box.

    Returns:
    - Any: The annotated image with bounding boxes and labels drawn on it.

    Note:
    - This function assumes that the incoming image supports the PIL ImageDraw interface.
    - The function directly modifies the input image and returns it.
    """
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    for pred in predictions:
        label = pred["label"]
        score = pred["score"]
        box = pred["box"]
        # Index the box keys explicitly rather than relying on dict ordering.
        xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]

        draw.rectangle([xmin, ymin, xmax, ymax], outline="green", width=1)
        draw.text((xmin, ymin), f"{label} ({score:.2f})", fill="red", font=font)

    return image
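
# Example (illustrative): predictions shaped like the output of a Hugging Face
# object-detection pipeline; the values below are made up.
#
#   preds = [{"label": "cat", "score": 0.98,
#             "box": {"xmin": 10, "ymin": 20, "xmax": 200, "ymax": 180}}]
#   annotated = draw_boxes(Image.open("photo.jpg").convert("RGB"), preds)
#   st.image(annotated)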


def draw_bounding_boxes_for_textract(
    image: Image.Image, json_data: Dict[str, Any]
) -> Image.Image:
    """
    Draws bounding boxes on an image based on the provided JSON data from Textract.

    Args:
        image: The PIL image on which to draw bounding boxes.
        json_data: The parsed response dictionary; its "body" field holds the Textract blocks as a JSON string.

    Returns:
        A PIL Image object with bounding boxes drawn.
    """
    draw = ImageDraw.Draw(image)

    try:
        blocks = json.loads(json_data["body"]) if "body" in json_data else None
    except json.JSONDecodeError:
        st.error("Invalid JSON data.")
        return image

    if blocks is None:
        st.error("No bounding box data found.")
        return image

    for item in blocks:
        if "BlockType" in item and item["BlockType"] in ["LINE", "WORD"]:
            bbox = item["Geometry"]["BoundingBox"]
            left, top, width, height = (
                bbox["Left"],
                bbox["Top"],
                bbox["Width"],
                bbox["Height"],
            )

            # Textract geometry is normalized to [0, 1]; scale to pixel coordinates.
            left_top = (left * image.width, top * image.height)
            right_bottom = ((left + width) * image.width, (top + height) * image.height)

            draw.rectangle([left_top, right_bottom], outline="red", width=2)

    return image
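
# Example (illustrative): overlaying Textract geometry on an uploaded page inside
# the Streamlit app. The variable names below are hypothetical.
#
#   page_img = Image.open("page.jpg").convert("RGB")
#   annotated = draw_bounding_boxes_for_textract(page_img, textract_response)
#   st.image(annotated, caption="Detected text regions")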