eagle0504 commited on
Commit
59d3355
·
1 Parent(s): c8d4109
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Yiqiao Yin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+ from typing import Any, Dict, List
6
+
7
+ import chromadb
8
+ import google.generativeai as palm
9
+ import pandas as pd
10
+ import requests
11
+ import streamlit as st
12
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
13
+ from langchain.text_splitter import (
14
+ RecursiveCharacterTextSplitter,
15
+ SentenceTransformersTokenTextSplitter,
16
+ )
17
+ from PIL import Image, ImageDraw, ImageFont
18
+ from pypdf import PdfReader
19
+ from transformers import pipeline
20
+
21
+ from utils.cnn_transformer import *
22
+ from utils.helpers import *
23
+
24
+ # API Key (You should set this in your environment variables)
25
+ api_key = st.secrets["PALM_API_KEY"]
26
+ palm.configure(api_key=api_key)
27
+
28
+
29
+ # Load YOLO pipeline
30
+ yolo_pipe = pipeline("object-detection", model="hustvl/yolos-small")
31
+
32
+
33
+ # Function to draw bounding boxes and labels on image
34
+ def draw_boxes(image, predictions):
35
+ draw = ImageDraw.Draw(image)
36
+ font = ImageFont.load_default()
37
+
38
+ for pred in predictions:
39
+ label = pred["label"]
40
+ score = pred["score"]
41
+ box = pred["box"]
42
+ xmin, ymin, xmax, ymax = box.values()
43
+ draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
44
+ draw.text((xmin, ymin), f"{label} ({score:.2f})", fill="red", font=font)
45
+
46
+ return image
47
+
48
+
49
+ # Main function of the Streamlit app
50
+ def main():
51
+ st.title("Generative AI Demo on Camera Input/Image/PDF 💻")
52
+
53
+ # Dropdown for user to choose the input method
54
+ input_method = st.sidebar.selectbox(
55
+ "Choose input method:", ["Camera", "Upload Image", "Upload PDF"]
56
+ )
57
+
58
+ image, uploaded_file = None, None
59
+ if input_method == "Camera":
60
+ # Streamlit widget to capture an image from the user's webcam
61
+ image = st.sidebar.camera_input("Take a picture 📸")
62
+ elif input_method == "Upload Image":
63
+ # Create a file uploader in the sidebar
64
+ image = st.sidebar.file_uploader("Upload a JPG image", type=["jpg"])
65
+ elif input_method == "Upload PDF":
66
+ # File uploader widget
67
+ uploaded_file = st.sidebar.file_uploader("Choose a PDF file", type="pdf")
68
+
69
+ # Add instruction
70
+ st.sidebar.markdown(
71
+ """
72
+ # 🌟 How to Use the App 🌟
73
+
74
+ 1) **🌈 User Input Magic**:
75
+ - 📸 **Camera Snap**: Tap to capture a moment with your device's camera. Say cheese!
76
+ - 🖼️ **Image Upload Extravaganza**: Got a cool pic? Upload it from your computer and let the magic begin!
77
+ - 📄 **PDF Adventure**: Use gen AI as ctrl+F to search information on any PDF, like opening a treasure chest of information!
78
+ - 📄 **YOLO Algorithm**: Wanna detect the object in the image? Use our object detection algorithm to see if the objects can be detected.
79
+
80
+ 2) **🤖 AI Interaction Wonderland**:
81
+ - 🌟 **Gemini's AI**: Google's Gemini AI is your companion, ready to dive deep into your uploads.
82
+ - 🌐 **Chroma Database**: As you upload, we're crafting a colorful Chroma database in our secret lab, making your interaction even more awesome!
83
+
84
+ 3) **💬 Chit-Chat with AI Post-Upload**:
85
+ - 🌍 Once your content is up in the app, ask away! Any question, any time.
86
+ - 💡 Light up the conversation with Gemini AI. It is like having a chat with a wise wizard from the digital realm!
87
+
88
+ Enjoy exploring and have fun! 😄🎉
89
+ """
90
+ )
91
+
92
+ if image is not None:
93
+ # Display the captured image
94
+ st.image(image, caption="Captured Image", use_column_width=True)
95
+
96
+ # Convert the image to PIL format and resize
97
+ pil_image = Image.open(image)
98
+ resized_image = resize_image(pil_image)
99
+
100
+ # Convert the resized image to base64
101
+ image_base64 = convert_image_to_base64(resized_image)
102
+
103
+ # OCR by API Call of AWS Textract via Post Method
104
+ if input_method == "Upload Image":
105
+ st.success("Running textract!")
106
+ url = "https://2tsig211e0.execute-api.us-east-1.amazonaws.com/my_textract"
107
+ payload = {"image": image_base64}
108
+ result_dict = post_request_and_parse_response(url, payload)
109
+ output_data = extract_line_items(result_dict)
110
+ df = pd.DataFrame(output_data)
111
+
112
+ # Using an expander to hide the json
113
+ with st.expander("Show/Hide Raw Json"):
114
+ st.write(result_dict)
115
+
116
+ # Using an expander to hide the table
117
+ with st.expander("Show/Hide Table"):
118
+ st.table(df)
119
+
120
+ if api_key:
121
+ # Make API call
122
+ st.success("Running Gemini!")
123
+ with st.spinner('Wait for it...'):
124
+ response = call_gemini_api(image_base64, api_key)
125
+
126
+ with st.expander("Raw output from Gemini"):
127
+ st.write(response)
128
+
129
+ # Display the response
130
+ if response["candidates"][0]["content"]["parts"][0]["text"]:
131
+ text_from_response = response["candidates"][0]["content"]["parts"][0][
132
+ "text"
133
+ ]
134
+ with st.spinner("Wait for it..."):
135
+ st.write(text_from_response)
136
+
137
+ # Text input for the question
138
+ input_prompt = st.text_input(
139
+ "Type your question here:",
140
+ )
141
+
142
+ # Display the entered question
143
+ if input_prompt:
144
+ updated_text_from_response = call_gemini_api(
145
+ image_base64, api_key, prompt=input_prompt
146
+ )
147
+
148
+ if updated_text_from_response is not None:
149
+ # Do something with the text
150
+ updated_ans = updated_text_from_response["candidates"][0][
151
+ "content"
152
+ ]["parts"][0]["text"]
153
+ with st.spinner("Wait for it..."):
154
+ st.write(f"Gemini: {updated_ans}")
155
+ else:
156
+ st.warning("Check gemini's API.")
157
+
158
+ else:
159
+ st.write("No response from API.")
160
+ else:
161
+ st.write("API Key is not set. Please set the API Key.")
162
+
163
+ # YOLO
164
+ if image is not None:
165
+ st.sidebar.success("Check the following box to run YOLO algorithm if desired!")
166
+ use_yolo = st.sidebar.checkbox("Use YOLO!", value=False)
167
+
168
+ if use_yolo:
169
+ # Process image with YOLO
170
+ image = Image.open(image)
171
+ with st.spinner("Wait for it..."):
172
+ st.success("Running YOLO algorithm!")
173
+ predictions = yolo_pipe(image)
174
+ st.success("YOLO running successfully.")
175
+
176
+ # Draw bounding boxes and labels
177
+ image_with_boxes = draw_boxes(image.copy(), predictions)
178
+ st.success("Bounding boxes drawn.")
179
+
180
+ # Display annotated image
181
+ st.image(image_with_boxes, caption="Annotated Image", use_column_width=True)
182
+
183
+ # File uploader widget
184
+ if uploaded_file is not None:
185
+ # To read file as bytes:
186
+ bytes_data = uploaded_file.getvalue()
187
+ st.success("Your PDF is uploaded successfully.")
188
+
189
+ # Get the file name
190
+ file_name = uploaded_file.name
191
+
192
+ # Save the file temporarily
193
+ with open(file_name, "wb") as f:
194
+ f.write(uploaded_file.getbuffer())
195
+
196
+ # Display PDF
197
+ # displayPDF(file_name)
198
+
199
+ # Read file
200
+ reader = PdfReader(file_name)
201
+ pdf_texts = [p.extract_text().strip() for p in reader.pages]
202
+
203
+ # Filter the empty strings
204
+ pdf_texts = [text for text in pdf_texts if text]
205
+ st.success("PDF extracted successfully.")
206
+
207
+ # Split the texts
208
+ character_splitter = RecursiveCharacterTextSplitter(
209
+ separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=0
210
+ )
211
+ character_split_texts = character_splitter.split_text("\n\n".join(pdf_texts))
212
+ st.success("Texts splitted successfully.")
213
+
214
+ # Tokenize it
215
+ st.warning("Start tokenzing ...")
216
+ token_splitter = SentenceTransformersTokenTextSplitter(
217
+ chunk_overlap=0, tokens_per_chunk=256
218
+ )
219
+ token_split_texts = []
220
+ for text in character_split_texts:
221
+ token_split_texts += token_splitter.split_text(text)
222
+ st.success("Tokenized successfully.")
223
+
224
+ # Add to vector database
225
+ embedding_function = SentenceTransformerEmbeddingFunction()
226
+ chroma_client = chromadb.Client()
227
+ chroma_collection = chroma_client.create_collection(
228
+ "tmp", embedding_function=embedding_function
229
+ )
230
+ ids = [str(i) for i in range(len(token_split_texts))]
231
+ chroma_collection.add(ids=ids, documents=token_split_texts)
232
+ st.success("Vector database loaded successfully.")
233
+
234
+ # User input
235
+ query = st.text_input("Ask me anything!", "What is the document about?")
236
+ results = chroma_collection.query(query_texts=[query], n_results=5)
237
+ retrieved_documents = results["documents"][0]
238
+ results_as_table = pd.DataFrame(
239
+ {
240
+ "ids": results["ids"][0],
241
+ "documents": results["documents"][0],
242
+ "distances": results["distances"][0],
243
+ }
244
+ )
245
+
246
+ # API of a foundation model
247
+ output = rag(query=query, retrieved_documents=retrieved_documents)
248
+ st.write(output)
249
+ st.success(
250
+ "Please see where the chatbot got the information from the document below.👇"
251
+ )
252
+ with st.expander("Raw query outputs:"):
253
+ st.write(results)
254
+ with st.expander("Processed tabular form query outputs:"):
255
+ st.table(results_as_table)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ main()
figs/false-insurance-policy.jpeg ADDED
figs/labcorp_accessioning.jpg ADDED
figs/system-architect.drawio ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <mxfile host="65bd71144e">
2
+ <diagram id="6I0VWqCgP7JPpdnrNpuH" name="Page-1">
3
+ <mxGraphModel dx="721" dy="917" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
4
+ <root>
5
+ <mxCell id="0"/>
6
+ <mxCell id="1" parent="0"/>
7
+ <mxCell id="37" value="" style="rounded=0;whiteSpace=wrap;html=1;" vertex="1" parent="1">
8
+ <mxGeometry x="80" y="110" width="720" height="480" as="geometry"/>
9
+ </mxCell>
10
+ <mxCell id="32" style="edgeStyle=none;html=1;" parent="1" source="2" target="29" edge="1">
11
+ <mxGeometry relative="1" as="geometry">
12
+ <mxPoint x="237.5" y="380" as="targetPoint"/>
13
+ </mxGeometry>
14
+ </mxCell>
15
+ <mxCell id="2" value="&lt;b&gt;PDF&lt;/b&gt;" style="html=1;verticalLabelPosition=bottom;align=center;labelBackgroundColor=#ffffff;verticalAlign=top;strokeWidth=2;strokeColor=#0080F0;shadow=0;dashed=0;shape=mxgraph.ios7.icons.documents;" parent="1" vertex="1">
16
+ <mxGeometry x="137.5" y="195" width="55" height="60" as="geometry"/>
17
+ </mxCell>
18
+ <mxCell id="12" style="html=1;entryX=1;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="3" target="11" edge="1">
19
+ <mxGeometry relative="1" as="geometry"/>
20
+ </mxCell>
21
+ <mxCell id="3" value="&lt;b&gt;Textract&lt;/b&gt;" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;gradientColor=#4AB29A;gradientDirection=north;fillColor=#116D5B;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.textract;" parent="1" vertex="1">
22
+ <mxGeometry x="700" y="337.5" width="78" height="78" as="geometry"/>
23
+ </mxCell>
24
+ <mxCell id="15" style="edgeStyle=none;html=1;" parent="1" source="10" edge="1">
25
+ <mxGeometry relative="1" as="geometry">
26
+ <mxPoint x="590" y="420" as="targetPoint"/>
27
+ <Array as="points">
28
+ <mxPoint x="460" y="520"/>
29
+ <mxPoint x="520" y="520"/>
30
+ <mxPoint x="570" y="520"/>
31
+ </Array>
32
+ </mxGeometry>
33
+ </mxCell>
34
+ <mxCell id="31" style="edgeStyle=none;html=1;entryX=1.04;entryY=0.492;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="10" target="29" edge="1">
35
+ <mxGeometry relative="1" as="geometry"/>
36
+ </mxCell>
37
+ <mxCell id="10" value="&lt;b&gt;API Gateway&lt;/b&gt;" style="outlineConnect=0;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;shape=mxgraph.aws3.api_gateway;fillColor=#D9A741;gradientColor=none;" parent="1" vertex="1">
38
+ <mxGeometry x="400" y="330" width="76.5" height="93" as="geometry"/>
39
+ </mxCell>
40
+ <mxCell id="13" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="11" target="3" edge="1">
41
+ <mxGeometry relative="1" as="geometry"/>
42
+ </mxCell>
43
+ <mxCell id="16" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;exitPerimeter=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="11" target="10" edge="1">
44
+ <mxGeometry relative="1" as="geometry">
45
+ <Array as="points">
46
+ <mxPoint x="570" y="240"/>
47
+ <mxPoint x="510" y="240"/>
48
+ <mxPoint x="450" y="240"/>
49
+ </Array>
50
+ </mxGeometry>
51
+ </mxCell>
52
+ <mxCell id="11" value="&lt;b&gt;AWS Lambda&lt;/b&gt;" style="sketch=0;points=[[0,0,0],[0.25,0,0],[0.5,0,0],[0.75,0,0],[1,0,0],[0,1,0],[0.25,1,0],[0.5,1,0],[0.75,1,0],[1,1,0],[0,0.25,0],[0,0.5,0],[0,0.75,0],[1,0.25,0],[1,0.5,0],[1,0.75,0]];outlineConnect=0;fontColor=#232F3E;gradientColor=#F78E04;gradientDirection=north;fillColor=#D05C17;strokeColor=#ffffff;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;fontSize=12;fontStyle=0;aspect=fixed;shape=mxgraph.aws4.resourceIcon;resIcon=mxgraph.aws4.lambda;" parent="1" vertex="1">
53
+ <mxGeometry x="546" y="337.5" width="78" height="78" as="geometry"/>
54
+ </mxCell>
55
+ <mxCell id="22" value="&lt;b&gt;OCR Output&lt;/b&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
56
+ <mxGeometry x="491" y="208" width="60" height="30" as="geometry"/>
57
+ </mxCell>
58
+ <mxCell id="25" value="&lt;b&gt;base64&amp;nbsp;&lt;br&gt;Encoded&lt;br&gt;Image&lt;br&gt;&lt;/b&gt;" style="text;html=1;align=center;verticalAlign=middle;resizable=0;points=[];autosize=1;strokeColor=none;fillColor=none;" parent="1" vertex="1">
59
+ <mxGeometry x="486" y="513" width="70" height="60" as="geometry"/>
60
+ </mxCell>
61
+ <mxCell id="26" value="&lt;b&gt;base64&amp;nbsp;&lt;br&gt;Encoded&lt;br&gt;Image&lt;br&gt;&lt;/b&gt;" style="text;html=1;align=center;verticalAlign=middle;resizable=0;points=[];autosize=1;strokeColor=none;fillColor=none;" parent="1" vertex="1">
62
+ <mxGeometry x="265" y="374.25" width="70" height="60" as="geometry"/>
63
+ </mxCell>
64
+ <mxCell id="28" value="&lt;b&gt;Extracted&lt;br&gt;Text&lt;br&gt;&lt;/b&gt;" style="text;html=1;align=center;verticalAlign=middle;resizable=0;points=[];autosize=1;strokeColor=none;fillColor=none;" parent="1" vertex="1">
65
+ <mxGeometry x="260" y="334.25" width="80" height="40" as="geometry"/>
66
+ </mxCell>
67
+ <mxCell id="30" style="edgeStyle=none;html=1;" parent="1" source="29" target="10" edge="1">
68
+ <mxGeometry relative="1" as="geometry"/>
69
+ </mxCell>
70
+ <mxCell id="29" value="&lt;b&gt;User&lt;/b&gt;" style="html=1;verticalLabelPosition=bottom;align=center;labelBackgroundColor=#ffffff;verticalAlign=top;strokeWidth=2;strokeColor=#0080F0;shadow=0;dashed=0;shape=mxgraph.ios7.icons.user;" parent="1" vertex="1">
71
+ <mxGeometry x="125" y="334.25" width="80" height="84.5" as="geometry"/>
72
+ </mxCell>
73
+ <mxCell id="33" value="Streamlit App" style="swimlane;whiteSpace=wrap;html=1;align=left;" vertex="1" parent="1">
74
+ <mxGeometry x="100" y="150" width="690" height="430" as="geometry"/>
75
+ </mxCell>
76
+ <mxCell id="34" value="&lt;b&gt;EC2&lt;/b&gt;" style="outlineConnect=0;dashed=0;verticalLabelPosition=bottom;verticalAlign=top;align=center;html=1;shape=mxgraph.aws3.ec2;fillColor=#F58534;gradientColor=none;" vertex="1" parent="1">
77
+ <mxGeometry x="95" y="80" width="35" height="43" as="geometry"/>
78
+ </mxCell>
79
+ </root>
80
+ </mxGraphModel>
81
+ </diagram>
82
+ </mxfile>
figs/system-architect.png ADDED
lambda/my_textract.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Purpose
3
+ An AWS lambda function that analyzes documents with Amazon Textract.
4
+ """
5
+ import json
6
+ import base64
7
+ import logging
8
+ import boto3
9
+
10
+ from botocore.exceptions import ClientError
11
+
12
+ # Set up logging.
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Get the boto3 client.
16
+ textract_client = boto3.client("textract")
17
+
18
+
19
+ def lambda_handler(event, context):
20
+ """
21
+ Lambda handler function
22
+ param: event: The event object for the Lambda function.
23
+ param: context: The context object for the lambda function.
24
+ return: The list of Block objects recognized in the document
25
+ passed in the event object.
26
+ """
27
+
28
+ # raw_image = json.loads(event['body'])['image']
29
+ # message = f"i love {country}"
30
+
31
+ # return message
32
+
33
+ try:
34
+ # Determine document source.
35
+ # event['image'] = event["queryStringParameters"]['image']
36
+ # event['image'] = json.loads(event['body'])["queryStringParameters"]['image']
37
+ event["image"] = json.loads(event["body"])["image"]
38
+ if "image" in event:
39
+ # Decode the image
40
+ image_bytes = event["image"].encode("utf-8")
41
+ img_b64decoded = base64.b64decode(image_bytes)
42
+ image = {"Bytes": img_b64decoded}
43
+
44
+ elif "S3Object" in event:
45
+ image = {
46
+ "S3Object": {
47
+ "Bucket": event["S3Object"]["Bucket"],
48
+ "Name": event["S3Object"]["Name"],
49
+ }
50
+ }
51
+
52
+ else:
53
+ raise ValueError(
54
+ "Invalid source. Only image base 64 encoded image bytes or S3Object are supported."
55
+ )
56
+
57
+ # Analyze the document.
58
+ response = textract_client.detect_document_text(Document=image)
59
+
60
+ # Get the Blocks
61
+ blocks = response["Blocks"]
62
+
63
+ lambda_response = {"statusCode": 200, "body": json.dumps(blocks)}
64
+
65
+ except ClientError as err:
66
+ error_message = "Couldn't analyze image. " + err.response["Error"]["Message"]
67
+
68
+ lambda_response = {
69
+ "statusCode": 400,
70
+ "body": {
71
+ "Error": err.response["Error"]["Code"],
72
+ "ErrorMessage": error_message,
73
+ },
74
+ }
75
+ logger.error(
76
+ "Error function %s: %s", context.invoked_function_arn, error_message
77
+ )
78
+
79
+ except ValueError as val_error:
80
+ lambda_response = {
81
+ "statusCode": 400,
82
+ "body": {"Error": "ValueError", "ErrorMessage": format(val_error)},
83
+ }
84
+ logger.error(
85
+ "Error function %s: %s", context.invoked_function_arn, format(val_error)
86
+ )
87
+
88
+ # Create return body
89
+ http_resp = {}
90
+ http_resp["statusCode"] = 200
91
+ http_resp["headers"] = {}
92
+ http_resp["headers"]["Content-Type"] = "application/json"
93
+ http_resp["body"] = json.dumps(lambda_response)
94
+
95
+ return http_resp
models/cnn_transformer/tf_keras_image_captioning_cnn+transformer_flicker8k.index ADDED
Binary file (28.9 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ chromadb==0.3.29
2
+ langchain==0.0.343
3
+ matplotlib
4
+ numpy
5
+ google-generativeai>=0.1.0
6
+ pandas
7
+ pypdf==3.17.1
8
+ Pillow
9
+ sentence-transformers==2.2.2
10
+ streamlit
11
+ transformers
12
+ torch
13
+ tensorflow
utils/cnn_transformer.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["KERAS_BACKEND"] = "tensorflow"
4
+
5
+ import re
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ import tensorflow as tf
10
+ import keras
11
+ from keras import layers
12
+ from keras.applications import efficientnet
13
+ from keras.layers import TextVectorization
14
+
15
+ keras.utils.set_random_seed(111)
16
+
17
+
18
+ # Desired image dimensions
19
+ IMAGE_SIZE = (299, 299)
20
+
21
+ # Dimension for the image embeddings and token embeddings
22
+ EMBED_DIM = 512
23
+
24
+ # Per-layer units in the feed-forward network
25
+ FF_DIM = 512
26
+
27
+ # Fixed length allowed for any sequence
28
+ SEQ_LENGTH = 25
29
+
30
+ # Vocabulary size
31
+ VOCAB_SIZE = 10000
32
+
33
+ # Data augmentation for image data
34
+ image_augmentation = keras.Sequential(
35
+ [
36
+ layers.RandomFlip("horizontal"),
37
+ layers.RandomRotation(0.2),
38
+ layers.RandomContrast(0.3),
39
+ ]
40
+ )
41
+
42
+
43
+ def get_cnn_model():
44
+ base_model = efficientnet.EfficientNetB0(
45
+ input_shape=(*IMAGE_SIZE, 3),
46
+ include_top=False,
47
+ weights="imagenet",
48
+ )
49
+ # We freeze our feature extractor
50
+ base_model.trainable = False
51
+ base_model_out = base_model.output
52
+ base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
53
+ cnn_model = keras.models.Model(base_model.input, base_model_out)
54
+ return cnn_model
55
+
56
+
57
+ class TransformerEncoderBlock(layers.Layer):
58
+ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
59
+ super().__init__(**kwargs)
60
+ self.embed_dim = embed_dim
61
+ self.dense_dim = dense_dim
62
+ self.num_heads = num_heads
63
+ self.attention_1 = layers.MultiHeadAttention(
64
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.0
65
+ )
66
+ self.layernorm_1 = layers.LayerNormalization()
67
+ self.layernorm_2 = layers.LayerNormalization()
68
+ self.dense_1 = layers.Dense(embed_dim, activation="relu")
69
+
70
+ def call(self, inputs, training, mask=None):
71
+ inputs = self.layernorm_1(inputs)
72
+ inputs = self.dense_1(inputs)
73
+
74
+ attention_output_1 = self.attention_1(
75
+ query=inputs,
76
+ value=inputs,
77
+ key=inputs,
78
+ attention_mask=None,
79
+ training=training,
80
+ )
81
+ out_1 = self.layernorm_2(inputs + attention_output_1)
82
+ return out_1
83
+
84
+
85
+ class PositionalEmbedding(layers.Layer):
86
+ def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
87
+ super().__init__(**kwargs)
88
+ self.token_embeddings = layers.Embedding(
89
+ input_dim=vocab_size, output_dim=embed_dim
90
+ )
91
+ self.position_embeddings = layers.Embedding(
92
+ input_dim=sequence_length, output_dim=embed_dim
93
+ )
94
+ self.sequence_length = sequence_length
95
+ self.vocab_size = vocab_size
96
+ self.embed_dim = embed_dim
97
+ self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))
98
+
99
+ def call(self, inputs):
100
+ length = tf.shape(inputs)[-1]
101
+ positions = tf.range(start=0, limit=length, delta=1)
102
+ embedded_tokens = self.token_embeddings(inputs)
103
+ embedded_tokens = embedded_tokens * self.embed_scale
104
+ embedded_positions = self.position_embeddings(positions)
105
+ return embedded_tokens + embedded_positions
106
+
107
+ def compute_mask(self, inputs, mask=None):
108
+ return tf.math.not_equal(inputs, 0)
109
+
110
+
111
+ class TransformerDecoderBlock(layers.Layer):
112
+ def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
113
+ super().__init__(**kwargs)
114
+ self.embed_dim = embed_dim
115
+ self.ff_dim = ff_dim
116
+ self.num_heads = num_heads
117
+ self.attention_1 = layers.MultiHeadAttention(
118
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
119
+ )
120
+ self.attention_2 = layers.MultiHeadAttention(
121
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
122
+ )
123
+ self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
124
+ self.ffn_layer_2 = layers.Dense(embed_dim)
125
+
126
+ self.layernorm_1 = layers.LayerNormalization()
127
+ self.layernorm_2 = layers.LayerNormalization()
128
+ self.layernorm_3 = layers.LayerNormalization()
129
+
130
+ self.embedding = PositionalEmbedding(
131
+ embed_dim=EMBED_DIM,
132
+ sequence_length=SEQ_LENGTH,
133
+ vocab_size=VOCAB_SIZE,
134
+ )
135
+ self.out = layers.Dense(VOCAB_SIZE, activation="softmax")
136
+
137
+ self.dropout_1 = layers.Dropout(0.3)
138
+ self.dropout_2 = layers.Dropout(0.5)
139
+ self.supports_masking = True
140
+
141
+ def call(self, inputs, encoder_outputs, training, mask=None):
142
+ inputs = self.embedding(inputs)
143
+ causal_mask = self.get_causal_attention_mask(inputs)
144
+
145
+ if mask is not None:
146
+ padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
147
+ combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
148
+ combined_mask = tf.minimum(combined_mask, causal_mask)
149
+
150
+ attention_output_1 = self.attention_1(
151
+ query=inputs,
152
+ value=inputs,
153
+ key=inputs,
154
+ attention_mask=combined_mask,
155
+ training=training,
156
+ )
157
+ out_1 = self.layernorm_1(inputs + attention_output_1)
158
+
159
+ attention_output_2 = self.attention_2(
160
+ query=out_1,
161
+ value=encoder_outputs,
162
+ key=encoder_outputs,
163
+ attention_mask=padding_mask,
164
+ training=training,
165
+ )
166
+ out_2 = self.layernorm_2(out_1 + attention_output_2)
167
+
168
+ ffn_out = self.ffn_layer_1(out_2)
169
+ ffn_out = self.dropout_1(ffn_out, training=training)
170
+ ffn_out = self.ffn_layer_2(ffn_out)
171
+
172
+ ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
173
+ ffn_out = self.dropout_2(ffn_out, training=training)
174
+ preds = self.out(ffn_out)
175
+ return preds
176
+
177
+ def get_causal_attention_mask(self, inputs):
178
+ input_shape = tf.shape(inputs)
179
+ batch_size, sequence_length = input_shape[0], input_shape[1]
180
+ i = tf.range(sequence_length)[:, tf.newaxis]
181
+ j = tf.range(sequence_length)
182
+ mask = tf.cast(i >= j, dtype="int32")
183
+ mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
184
+ mult = tf.concat(
185
+ [
186
+ tf.expand_dims(batch_size, -1),
187
+ tf.constant([1, 1], dtype=tf.int32),
188
+ ],
189
+ axis=0,
190
+ )
191
+ return tf.tile(mask, mult)
192
+
193
+
194
+ class ImageCaptioningModel(keras.Model):
195
+ def __init__(
196
+ self,
197
+ cnn_model,
198
+ encoder,
199
+ decoder,
200
+ num_captions_per_image=5,
201
+ image_aug=None,
202
+ ):
203
+ super().__init__()
204
+ self.cnn_model = cnn_model
205
+ self.encoder = encoder
206
+ self.decoder = decoder
207
+ self.loss_tracker = keras.metrics.Mean(name="loss")
208
+ self.acc_tracker = keras.metrics.Mean(name="accuracy")
209
+ self.num_captions_per_image = num_captions_per_image
210
+ self.image_aug = image_aug
211
+
212
+ def calculate_loss(self, y_true, y_pred, mask):
213
+ loss = self.loss(y_true, y_pred)
214
+ mask = tf.cast(mask, dtype=loss.dtype)
215
+ loss *= mask
216
+ return tf.reduce_sum(loss) / tf.reduce_sum(mask)
217
+
218
+ def calculate_accuracy(self, y_true, y_pred, mask):
219
+ accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
220
+ accuracy = tf.math.logical_and(mask, accuracy)
221
+ accuracy = tf.cast(accuracy, dtype=tf.float32)
222
+ mask = tf.cast(mask, dtype=tf.float32)
223
+ return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
224
+
225
+ def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
226
+ encoder_out = self.encoder(img_embed, training=training)
227
+ batch_seq_inp = batch_seq[:, :-1]
228
+ batch_seq_true = batch_seq[:, 1:]
229
+ mask = tf.math.not_equal(batch_seq_true, 0)
230
+ batch_seq_pred = self.decoder(
231
+ batch_seq_inp, encoder_out, training=training, mask=mask
232
+ )
233
+ loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
234
+ acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
235
+ return loss, acc
236
+
237
+ def train_step(self, batch_data):
238
+ batch_img, batch_seq = batch_data
239
+ batch_loss = 0
240
+ batch_acc = 0
241
+
242
+ if self.image_aug:
243
+ batch_img = self.image_aug(batch_img)
244
+
245
+ # 1. Get image embeddings
246
+ img_embed = self.cnn_model(batch_img)
247
+
248
+ # 2. Pass each of the five captions one by one to the decoder
249
+ # along with the encoder outputs and compute the loss as well as accuracy
250
+ # for each caption.
251
+ for i in range(self.num_captions_per_image):
252
+ with tf.GradientTape() as tape:
253
+ loss, acc = self._compute_caption_loss_and_acc(
254
+ img_embed, batch_seq[:, i, :], training=True
255
+ )
256
+
257
+ # 3. Update loss and accuracy
258
+ batch_loss += loss
259
+ batch_acc += acc
260
+
261
+ # 4. Get the list of all the trainable weights
262
+ train_vars = (
263
+ self.encoder.trainable_variables + self.decoder.trainable_variables
264
+ )
265
+
266
+ # 5. Get the gradients
267
+ grads = tape.gradient(loss, train_vars)
268
+
269
+ # 6. Update the trainable weights
270
+ self.optimizer.apply_gradients(zip(grads, train_vars))
271
+
272
+ # 7. Update the trackers
273
+ batch_acc /= float(self.num_captions_per_image)
274
+ self.loss_tracker.update_state(batch_loss)
275
+ self.acc_tracker.update_state(batch_acc)
276
+
277
+ # 8. Return the loss and accuracy values
278
+ return {
279
+ "loss": self.loss_tracker.result(),
280
+ "acc": self.acc_tracker.result(),
281
+ }
282
+
283
+ def test_step(self, batch_data):
284
+ batch_img, batch_seq = batch_data
285
+ batch_loss = 0
286
+ batch_acc = 0
287
+
288
+ # 1. Get image embeddings
289
+ img_embed = self.cnn_model(batch_img)
290
+
291
+ # 2. Pass each of the five captions one by one to the decoder
292
+ # along with the encoder outputs and compute the loss as well as accuracy
293
+ # for each caption.
294
+ for i in range(self.num_captions_per_image):
295
+ loss, acc = self._compute_caption_loss_and_acc(
296
+ img_embed, batch_seq[:, i, :], training=False
297
+ )
298
+
299
+ # 3. Update batch loss and batch accuracy
300
+ batch_loss += loss
301
+ batch_acc += acc
302
+
303
+ batch_acc /= float(self.num_captions_per_image)
304
+
305
+ # 4. Update the trackers
306
+ self.loss_tracker.update_state(batch_loss)
307
+ self.acc_tracker.update_state(batch_acc)
308
+
309
+ # 5. Return the loss and accuracy values
310
+ return {
311
+ "loss": self.loss_tracker.result(),
312
+ "acc": self.acc_tracker.result(),
313
+ }
314
+
315
+ @property
316
+ def metrics(self):
317
+ # We need to list our metrics here so the `reset_states()` can be
318
+ # called automatically.
319
+ return [self.loss_tracker, self.acc_tracker]
320
+
321
+
322
+
323
+ strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
324
+ strip_chars = strip_chars.replace("<", "")
325
+ strip_chars = strip_chars.replace(">", "")
326
+
327
+
328
+ def custom_standardization(input_string):
329
+ lowercase = tf.strings.lower(input_string)
330
+ return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
331
+
332
+
333
+ vectorization = TextVectorization(
334
+ max_tokens=VOCAB_SIZE,
335
+ output_mode="int",
336
+ output_sequence_length=SEQ_LENGTH,
337
+ standardize=custom_standardization,
338
+ )
339
+
340
+
341
+ def generate_caption(caption_model: None):
342
+ # Select a random image from the validation dataset
343
+ # sample_img = np.random.choice(valid_images)
344
+
345
+ # # Read the image from the disk
346
+ # sample_img = decode_and_resize(sample_img)
347
+ # img = sample_img.numpy().clip(0, 255).astype(np.uint8)
348
+ # plt.imshow(img)
349
+ # plt.show()
350
+
351
+ # Pass the image to the CNN
352
+ # img = tf.expand_dims(sample_img, 0)
353
+ #TOOD
354
+ img = None
355
+ img = caption_model.cnn_model(img)
356
+
357
+ # Pass the image features to the Transformer encoder
358
+ encoded_img = caption_model.encoder(img, training=False)
359
+
360
+ # Generate the caption using the Transformer decoder
361
+ decoded_caption = "<start> "
362
+ vocab = vectorization.get_vocabulary()
363
+ index_lookup = dict(zip(range(len(vocab)), vocab))
364
+ max_decoded_sentence_length = SEQ_LENGTH - 1
365
+ for i in range(max_decoded_sentence_length):
366
+ tokenized_caption = vectorization([decoded_caption])[:, :-1]
367
+ mask = tf.math.not_equal(tokenized_caption, 0)
368
+ predictions = caption_model.decoder(
369
+ tokenized_caption, encoded_img, training=False, mask=mask
370
+ )
371
+ sampled_token_index = np.argmax(predictions[0, i, :])
372
+ sampled_token = index_lookup[sampled_token_index]
373
+ if sampled_token == "<end>":
374
+ break
375
+ decoded_caption += " " + sampled_token
376
+
377
+ decoded_caption = decoded_caption.replace("<start> ", "")
378
+ decoded_caption = decoded_caption.replace(" <end>", "").strip()
379
+ print("Predicted Caption: ", decoded_caption)
utils/helpers.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+ from typing import Any, Dict, List
6
+
7
+ import pandas as pd
8
+ import requests
9
+ import streamlit as st
10
+ from PIL import Image
11
+ import google.generativeai as palm
12
+ from pypdf import PdfReader
13
+ from langchain.text_splitter import (
14
+ RecursiveCharacterTextSplitter,
15
+ SentenceTransformersTokenTextSplitter,
16
+ )
17
+ import chromadb
18
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
19
+
20
+
21
+ # API Key (You should set this in your environment variables)
22
+ api_key = st.secrets["PALM_API_KEY"]
23
+ palm.configure(api_key=api_key)
24
+
25
+
26
+ # Function to convert the image to bytes for download
27
+ def convert_image_to_bytes(image):
28
+ buffered = io.BytesIO()
29
+ image.save(buffered, format="JPEG")
30
+ return buffered.getvalue()
31
+
32
+
33
+ # Function to resize the image
34
+ def resize_image(image):
35
+ return image.resize((512, int(image.height * 512 / image.width)))
36
+
37
+
38
+ # Function to convert the image to base64
39
+ def convert_image_to_base64(image):
40
+ buffered = io.BytesIO()
41
+ image.save(buffered, format="JPEG")
42
+ return base64.b64encode(buffered.getvalue()).decode()
43
+
44
+
45
+ # Function to make an API call to Palm
46
+ def call_palm(prompt: str) -> str:
47
+ completion = palm.generate_text(
48
+ model="models/text-bison-001",
49
+ prompt=prompt,
50
+ temperature=0,
51
+ max_output_tokens=800,
52
+ )
53
+
54
+ return completion.result
55
+
56
+
57
+ # Function to make an API call to Google's Gemini API
58
+ def call_gemini_api(image_base64, api_key=api_key, prompt="What is this picture?"):
59
+ headers = {
60
+ "Content-Type": "application/json",
61
+ }
62
+ data = {
63
+ "contents": [
64
+ {
65
+ "parts": [
66
+ {"text": prompt},
67
+ {"inline_data": {"mime_type": "image/jpeg", "data": image_base64}},
68
+ ]
69
+ }
70
+ ]
71
+ }
72
+ response = requests.post(
73
+ f"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro-vision:generateContent?key={api_key}",
74
+ headers=headers,
75
+ json=data,
76
+ )
77
+ return response.json()
78
+
79
+
80
+ def safely_get_text(response):
81
+ try:
82
+ response
83
+ except Exception as e:
84
+ print(f"An error occurred: {e}")
85
+
86
+ # Return None or a default value if the path does not exist
87
+ return None
88
+
89
+
90
+ def post_request_and_parse_response(
91
+ url: str, payload: Dict[str, Any]
92
+ ) -> Dict[str, Any]:
93
+ """
94
+ Sends a POST request to the specified URL with the given payload,
95
+ then parses the byte response to a dictionary.
96
+
97
+ Args:
98
+ url (str): The URL to which the POST request is sent.
99
+ payload (Dict[str, Any]): The payload to send in the POST request.
100
+
101
+ Returns:
102
+ Dict[str, Any]: The parsed dictionary from the response.
103
+ """
104
+ # Set headers for the POST request
105
+ headers = {"Content-Type": "application/json"}
106
+
107
+ # Send the POST request and get the response
108
+ response = requests.post(url, json=payload, headers=headers)
109
+
110
+ # Extract the byte data from the response
111
+ byte_data = response.content
112
+
113
+ # Decode the byte data to a string
114
+ decoded_string = byte_data.decode("utf-8")
115
+
116
+ # Convert the JSON string to a dictionary
117
+ dict_data = json.loads(decoded_string)
118
+
119
+ return dict_data
120
+
121
+
122
+ def extract_line_items(input_data: Dict[str, Any]) -> List[Dict[str, Any]]:
123
+ """
124
+ Extracts items with "BlockType": "LINE" from the provided JSON data.
125
+
126
+ Args:
127
+ input_data (Dict[str, Any]): The input JSON data as a dictionary.
128
+
129
+ Returns:
130
+ List[Dict[str, Any]]: A list of dictionaries with the extracted data.
131
+ """
132
+ # Initialize an empty list to hold the extracted line items
133
+ line_items: List[Dict[str, Any]] = []
134
+
135
+ # Get the list of items from the 'body' key in the input data
136
+ body_items = json.loads(input_data.get("body", "[]"))
137
+
138
+ # Iterate through each item in the body
139
+ for item in body_items:
140
+ # Check if the BlockType of the item is 'LINE'
141
+ if item.get("BlockType") == "LINE":
142
+ # Add the item to the line_items list
143
+ line_items.append(item)
144
+
145
+ return line_items
146
+
147
+
148
+ def rag(query: str, retrieved_documents: list, api_key: str = api_key) -> str:
149
+ """
150
+ Function to process a query and a list of retrieved documents using the Gemini API.
151
+
152
+ Args:
153
+ query (str): The user's query or question.
154
+ retrieved_documents (list): A list of documents retrieved as relevant information to the query.
155
+ api_key (str): API key for accessing the Gemini API. Default is a predefined 'api_key'.
156
+
157
+ Returns:
158
+ str: The cleaned output from the Gemini API response.
159
+ """
160
+ # Combine the retrieved documents into a single string, separated by two newlines.
161
+ information = "\n\n".join(retrieved_documents)
162
+
163
+ # Format the query and combined information into a single message.
164
+ messages = f"Question: {query}. \n Information: {information}"
165
+
166
+ # Call the Gemini API with the formatted message and the API key.
167
+ gemini_output = call_palm(prompt=messages)
168
+
169
+ # Placeholder for processing the Gemini output. Currently, it simply assigns the raw output to 'cleaned_output'.
170
+ cleaned_output = gemini_output # ["candidates"][0]["content"]["parts"][0]["text"]
171
+
172
+ return cleaned_output
173
+
174
+
175
+ def displayPDF(file: str) -> None:
176
+ """
177
+ Displays a PDF file in a Streamlit application.
178
+
179
+ Parameters:
180
+ - file (str): The path to the PDF file to be displayed.
181
+ """
182
+
183
+ # Opening the PDF file in binary read mode
184
+ with open(file, "rb") as f:
185
+ # Encoding the PDF file content to base64
186
+ base64_pdf: str = base64.b64encode(f.read()).decode('utf-8')
187
+
188
+ # Creating an HTML embed string for displaying the PDF
189
+ pdf_display: str = F'<embed src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf">'
190
+
191
+ # Using Streamlit to display the HTML embed string as unsafe HTML
192
+ st.markdown(pdf_display, unsafe_allow_html=True)