Spaces:
Running
Running
phyloforfun
commited on
Commit
·
e91ac58
1
Parent(s):
3cef87b
Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
Browse files- README.md +340 -10
- api_cost/api_cost.yaml +75 -2
- app.py +0 -0
- custom_prompts/required_structure.yaml +0 -65
- custom_prompts/version_2.yaml +0 -232
- custom_prompts/version_2_OSU.yaml +0 -233
- demo/NY_1928185102_Heliotropiaceae_Heliotropium_indicum.jpg +3 -0
- demo/ba/ba2.jpg +3 -0
- demo/ba/ocr.jpg +3 -0
- demo/{demo_images → demo_gallery}/UM_1807464860_Phellinaceae_Phelline_dumbeensis.jpg +0 -0
- demo/demo_images/MICH_29667680_Hypericaceae_Hypericum_prolificum.jpg +3 -0
- requirements.txt +0 -0
- run_VoucherVision.py +7 -3
- settings/bottom.yaml +82 -0
- vouchervision/API_validation.py +224 -0
- vouchervision/DEP_prompt_catalog.py +1322 -0
- vouchervision/LLM_GoogleGemini.py +152 -0
- vouchervision/LLM_GooglePalm2.py +162 -0
- vouchervision/LLM_MistralAI.py +139 -0
- vouchervision/LLM_OpenAI.py +160 -0
- vouchervision/LLM_PaLM.py +0 -209
- vouchervision/LLM_chatGPT_3_5.py +0 -427
- vouchervision/LLM_local_MistralAI.py +211 -0
- vouchervision/LLM_local_MistralAI_batch.py +256 -0
- vouchervision/LLM_local_MistralAI_batch_async.py +210 -0
- vouchervision/LLM_local_cpu_MistralAI.py +205 -0
- vouchervision/LM2_logger.py +17 -6
- vouchervision/OCR_google_cloud_vision.py +677 -20
- vouchervision/OCR_trOCR.py +0 -0
- vouchervision/VoucherVision_Config_Builder.py +132 -112
- vouchervision/VoucherVision_GUI.py +0 -0
- vouchervision/embed_occ.py +1 -1
- vouchervision/embeddings_db.py +1 -1
- vouchervision/general_utils.py +37 -39
- vouchervision/model_maps.py +232 -0
- vouchervision/prompt_catalog.py +35 -1235
- vouchervision/utils_LLM.py +102 -0
- vouchervision/utils_LLM_JSON_validation.py +170 -0
- vouchervision/utils_VoucherVision.py +491 -388
- vouchervision/utils_VoucherVision_batch.py +863 -0
- vouchervision/utils_geolocate_HERE.py +319 -0
- vouchervision/utils_geolocate_OpenCage.py +95 -0
- vouchervision/{utils.py → utils_hf.py} +0 -0
- vouchervision/utils_taxonomy_WFO.py +319 -0
- vouchervision/vouchervision_main.py +10 -35
README.md
CHANGED
@@ -1,13 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# VoucherVision
|
2 |
+
|
3 |
+
[![VoucherVision](https://LeafMachine.org/img/VV_Logo.png "VoucherVision")](https://LeafMachine.org/)
|
4 |
+
|
5 |
+
Table of Contents
|
6 |
+
=================
|
7 |
+
|
8 |
+
* [Table of Contents](#table-of-contents)
|
9 |
+
* [About](#about)
|
10 |
+
* [Roadmap and New Features List](#roadmap-and-new-features-list)
|
11 |
+
* [Try our public demo!](#try-our-public-demo)
|
12 |
+
* [Installing VoucherVision](#installing-VoucherVision)
|
13 |
+
* [Prerequisites](#prerequisites)
|
14 |
+
* [Installation - Cloning the VoucherVision Repository](#installation---cloning-the-VoucherVision-repository)
|
15 |
+
* [About Python Virtual Environments](#about-python-virtual-environments)
|
16 |
+
* [Installation - Windows 10+](#installation---windows-10)
|
17 |
+
* [Virtual Environment](#virtual-environment)
|
18 |
+
* [Installing Packages](#installing-packages)
|
19 |
+
* [Troubleshooting CUDA](#troubleshooting-cuda)
|
20 |
+
* [Create a Desktop Shortcut to Launch VoucherVision GUI](#create-a-desktop-shortcut-to-launch-vouchervision-gui)
|
21 |
+
* [Run VoucherVision](#run-vouchervision)
|
22 |
+
* [Setting up API key](#setting-up-api-key)
|
23 |
+
* [Check GPU](#check-gpu)
|
24 |
+
* [Run Tests](#run-tests)
|
25 |
+
* [Starting VoucherVision](#starting-vouchervision)
|
26 |
+
* [Azure Instances of OpenAI](#azure-instances-of-openai)
|
27 |
+
* [Custom Prompt Builder](#custom-prompt-builder)
|
28 |
+
* [Load, Build, Edit](#load-build-edit)
|
29 |
+
* [Instructions](#instructions)
|
30 |
+
* [Defining Column Names Field-Specific Instructions](#defining-column-names-field-specific-instructions)
|
31 |
+
* [Prompting Structure](#prompting-structure)
|
32 |
+
* [Mapping Columns for VoucherVisionEditor](#mapping-columns-for-vouchervisioneditor)
|
33 |
+
* [Expense Reporting](#expense-reporting)
|
34 |
+
* [Expense Report Dashboard](#expense-report-dashboard)
|
35 |
+
* [User Interface Images](#user-interface-images)
|
36 |
+
|
37 |
+
---
|
38 |
+
|
39 |
+
# About
|
40 |
+
## **VoucherVision** - In Beta Testing Phase 🚀
|
41 |
+
|
42 |
+
For inquiries, feedback (or if you want to get involved!) [please complete our form](https://docs.google.com/forms/d/e/1FAIpQLSe2E9zU1bPJ1BW4PMakEQFsRmLbQ0WTBI2UXHIMEFm4WbnAVw/viewform?usp=sf_link).
|
43 |
+
|
44 |
+
## **Overview:**
|
45 |
+
Initiated by the **University of Michigan Herbarium**, VoucherVision harnesses the power of large language models (LLMs) to transform the transcription process of natural history specimen labels. Our workflow is as follows:
|
46 |
+
- Text extraction from specimen labels with **LeafMachine2**.
|
47 |
+
- Text interpretation using **Google Vision OCR**.
|
48 |
+
- LLMs, including ***GPT-3.5***, ***GPT-4***, ***PaLM 2***, and Azure instances of OpenAI models, standardize the OCR output into a consistent spreadsheet format. This data can then be integrated into various databases like Specify, Symbiota, and BRAHMS.
|
49 |
+
|
50 |
+
For ensuring accuracy and consistency, the [VoucherVisionEditor](https://github.com/Gene-Weaver/VoucherVisionEditor) serves as a quality control tool.
|
51 |
+
|
52 |
+
## Roadmap and New Features List
|
53 |
+
|
54 |
+
#### VoucherVision
|
55 |
+
- [X] Update to GPT 1106 builds
|
56 |
+
- [ ] Option to zip output files for simpler import into VVE
|
57 |
+
- [ ] Instead of saving a copy of the original image in place of the OCR/collage images when they are not selected, just change the path to the original image.
|
58 |
+
- [x] Expense tracking
|
59 |
+
- [x] Dashboard
|
60 |
+
- [X] More granular support for different GPT versions
|
61 |
+
- [x] Project-based and cumulative tracking
|
62 |
+
- [x] Hugging Face Spaces
|
63 |
+
- [x] Working and refactored
|
64 |
+
- [ ] Visualize locations on a map (verbatim and decimal)
|
65 |
+
- [x] Tested with batch of 300 images
|
66 |
+
- [x] GPT 3.5
|
67 |
+
- [ ] GPT 4
|
68 |
+
- [ ] PaLM 2
|
69 |
+
- [ ] Optimize for +300 images at a time
|
70 |
+
- [x] Modular Prompt Builder
|
71 |
+
- [x] Build, save, load, submit to VV library
|
72 |
+
- [ ] Assess whether order of column matters
|
73 |
+
- [ ] Assess shorter prompt effectiveness
|
74 |
+
- [ ] Restrict special columns to conform with VVE requirements (catalog_number, coordinates)
|
75 |
+
- [ ] Option to load existing OCR into VoucherVision workflow
|
76 |
+
#### Supported LLM APIs
|
77 |
+
- [x] OpenAI
|
78 |
+
- [x] GPT 4
|
79 |
+
- [x] GPT 4 Turbo 1106-preview
|
80 |
+
- [x] GPT 4 32k
|
81 |
+
- [x] GPT 3.5
|
82 |
+
- [x] GPT 3.5 Instruct
|
83 |
+
- [x] OpenAI (Microsoft Azure Endpoints)
|
84 |
+
- [x] GPT 4
|
85 |
+
- [x] GPT 4 Turbo 1106-preview
|
86 |
+
- [x] GPT 4 32k
|
87 |
+
- [x] GPT 3.5
|
88 |
+
- [x] GPT 3.5 Instruct
|
89 |
+
- [x] MistralAI
|
90 |
+
- [x] Mistral Tiny
|
91 |
+
- [x] Mistral Small
|
92 |
+
- [x] Mistral Medium
|
93 |
+
- [x] Google PaLM2
|
94 |
+
- [x] text-bison@001
|
95 |
+
- [x] text-bison@002
|
96 |
+
- [x] text-unicorn@001
|
97 |
+
- [x] Google Gemini
|
98 |
+
- [x] Gemini-Pro
|
99 |
+
#### Supported Locally Hosted LLMs
|
100 |
+
- [x] MistralAI (24GB+ VRAM GPU Required)
|
101 |
+
- [x] Mixtral 8x7B Instruct v0.1
|
102 |
+
- [x] Mistral 7B Instruct v0.2
|
103 |
+
- [x] MistralAI (CPU Inference) (can run on almost any computer!)
|
104 |
+
- [x] Mistral 7B Instruct v0.2 GGUF via llama.cpp
|
105 |
+
- [x] Meta-Llama2 7B
|
106 |
+
- [ ] Llama2 7B chat hf
|
107 |
+
|
108 |
+
#### VoucherVisionEditor
|
109 |
+
- [ ] Streamline the startup procedure
|
110 |
+
- [ ] Add configurable dropdown menus for certain fields
|
111 |
+
- [ ] Make sure that VVE can accommodate arbitrary column names
|
112 |
+
- [ ] Remove legacy support (version 1 prompts)
|
113 |
+
- [ ] Taxonomy validation helper
|
114 |
+
- [x] Visualize locations on a map (verbatim and decimal)
|
115 |
+
- [ ] More support for datum and verbatim coordinates
|
116 |
+
- [ ] Compare raw OCR to values in form to flag hallucinations/generated content
|
117 |
+
- [ ] Accept zipped folders as input
|
118 |
+
- [ ] Flag user when multiple people/names/determinations are present
|
119 |
+
|
120 |
+
### **Package Information:**
|
121 |
+
The main VoucherVision tool and the VoucherVisionEditor are packaged separately. This separation ensures that lower-performance computers can still install and utilize the editor. While VoucherVision is optimized to function smoothly on virtually any modern system, maximizing its capabilities (like using LeafMachine2 label collages or running Retrieval Augmented Generation (RAG) prompts) mandates a GPU.
|
122 |
+
|
123 |
+
> ***NOTE:*** You can absolutely run VoucherVision on non-GPU systems, but RAG will not be possible (luckily, the apparent best prompts, 'Version2+', do not use RAG).
|
124 |
+
|
125 |
+
---
|
126 |
+
|
127 |
+
# Try our public demo!
|
128 |
+
Our public demo, while lacking several quality control and reliability features found in the full VoucherVision module, provides an exciting glimpse into its capabilities. Feel free to upload your herbarium specimen and see what happens!
|
129 |
+
[VoucherVision Demo](https://huggingface.co/spaces/phyloforfun/VoucherVision)
|
130 |
+
|
131 |
+
---
|
132 |
+
|
133 |
+
# Installing VoucherVision
|
134 |
+
|
135 |
+
## Prerequisites
|
136 |
+
- Python 3.10 or later
|
137 |
+
- Optional: an Nvidia GPU + CUDA for running LeafMachine2
|
138 |
+
|
139 |
+
## Installation - Cloning the VoucherVision Repository
|
140 |
+
1. First, install Python 3.10, or greater, on your machine of choice. We have validated up to Python 3.11.
|
141 |
+
- Make sure that you can use `pip` to install packages on your machine, or at least inside of a virtual environment.
|
142 |
+
- Simply type `pip` into your terminal or PowerShell. If you see a list of options, you are all set. Otherwise, see
|
143 |
+
either this [PIP Documentation](https://pip.pypa.io/en/stable/installation/) or [this help page](https://www.geeksforgeeks.org/how-to-install-pip-on-windows/)
|
144 |
+
2. Open a terminal window and `cd` into the directory where you want to install VoucherVision.
|
145 |
+
3. In the [Git BASH terminal](https://gitforwindows.org/), clone the VoucherVision repository from GitHub by running the command:
|
146 |
+
<pre><code class="language-python">git clone https://github.com/Gene-Weaver/VoucherVision.git</code></pre>
|
147 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
148 |
+
4. Move into the VoucherVision directory by running `cd VoucherVision` in the terminal.
|
149 |
+
5. To run VoucherVision we need to install its dependencies inside of a Python virtual environment. Follow the instructions below for your operating system.
|
150 |
+
|
151 |
+
## About Python Virtual Environments
|
152 |
+
A virtual environment is a tool to keep the dependencies required by different projects in separate places, by creating isolated python virtual environments for them. This avoids any conflicts between the packages that you have installed for different projects. It makes it easier to maintain different versions of packages for different projects.
|
153 |
+
|
154 |
+
For more information about virtual environments, please see [Creation of virtual environments](https://docs.python.org/3/library/venv.html)
|
155 |
+
|
156 |
+
---
|
157 |
+
|
158 |
+
## Installation - Windows 10+
|
159 |
+
Installation should basically be the same for Linux.
|
160 |
+
### Virtual Environment
|
161 |
+
|
162 |
+
1. Still inside the VoucherVision directory, show that a venv is currently not active
|
163 |
+
<pre><code class="language-python">python --version</code></pre>
|
164 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
165 |
+
2. Then create the virtual environment (venv_VV is the name of our new virtual environment)
|
166 |
+
<pre><code class="language-python">python3 -m venv venv_VV</code></pre>
|
167 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
168 |
+
Or depending on your Python version...
|
169 |
+
<pre><code class="language-python">python -m venv venv_VV</code></pre>
|
170 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
171 |
+
3. Activate the virtual environment
|
172 |
+
<pre><code class="language-python">.\venv_VV\Scripts\activate</code></pre>
|
173 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
174 |
+
4. Confirm that the venv is active (should be different from step 1)
|
175 |
+
<pre><code class="language-python">python --version</code></pre>
|
176 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
177 |
+
5. If you want to exit the venv later for some reason, deactivate the venv using
|
178 |
+
<pre><code class="language-python">deactivate</code></pre>
|
179 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
180 |
+
|
181 |
+
### Installing Packages
|
182 |
+
|
183 |
+
1. Install the required dependencies to use VoucherVision
|
184 |
+
- Option A - If you are using Windows PowerShell:
|
185 |
+
<pre><code class="language-python">pip install wheel streamlit streamlit-extras plotly pyyaml Pillow pandas matplotlib matplotlib-inline tqdm openai langchain tiktoken openpyxl google-generativeai google-cloud-storage google-cloud-vision opencv-python chromadb chroma-migrate InstructorEmbedding transformers sentence-transformers seaborn dask psutil py-cpuinfo azureml-sdk azure-identity ; if ($?) { pip install numpy -U } ; if ($?) { pip install -U scikit-learn } ; if ($?) { pip install --upgrade numpy scikit-learn streamlit google-generativeai google-cloud-storage google-cloud-vision azureml-sdk azure-identity openai langchain }</code></pre>
|
186 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
187 |
+
|
188 |
+
- Option B:
|
189 |
+
<pre><code class="language-python">pip install wheel streamlit streamlit-extras plotly pyyaml Pillow pandas matplotlib matplotlib-inline tqdm openai langchain tiktoken openpyxl google-generativeai google-cloud-storage google-cloud-vision opencv-python chromadb chroma-migrate InstructorEmbedding transformers sentence-transformers seaborn dask psutil py-cpuinfo azureml-sdk azure-identity</code></pre>
|
190 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
191 |
+
|
192 |
+
2. Upgrade important packages. Run this if there is an update to VoucherVision.
|
193 |
+
<pre><code class="language-python">pip install --upgrade numpy scikit-learn streamlit google-generativeai google-cloud-storage google-cloud-vision azureml-sdk azure-identity openai langchain</code></pre>
|
194 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
195 |
+
|
196 |
+
3. Install PyTorch
|
197 |
+
- The LeafMachine2 machine learning algorithm requires PyTorch. If your computer does not have a GPU, then please install a version of PyTorch that is for CPU only. If your computer does have an Nvidia GPU, then please determine which version of PyTorch matches your current CUDA version. Please see [Troubleshooting CUDA](#troubleshooting-cuda) for help. PyTorch is large and will take a bit to install.
|
198 |
+
|
199 |
+
- WITH GPU (or visit [PyTorch.org](https://pytorch.org/get-started/locally/) to find the appropriate version of PyTorch for your CUDA version)
|
200 |
+
<pre><code class="language-python">pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113</code></pre>
|
201 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
202 |
+
- WITHOUT GPU, CPU ONLY
|
203 |
+
<pre><code class="language-python">pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cpu</code></pre>
|
204 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
205 |
+
|
206 |
+
|
207 |
+
> If you need help, please submit an inquiry in the form at [LeafMachine.org](https://LeafMachine.org/)
|
208 |
+
|
209 |
+
---
|
210 |
+
|
211 |
+
## Troubleshooting CUDA
|
212 |
+
|
213 |
+
- If your system already has another version of CUDA (e.g., CUDA 11.7) then it can be complicated to switch to CUDA 11.3.
|
214 |
+
- The simplest solution is to install pytorch with CPU only, avoiding the CUDA problem entirely.
|
215 |
+
- Alternatively, you can install the [latest pytorch release](https://pytorch.org/get-started/locally/) for your specific system, either using the cpu only version `pip3 install torch`, `pip3 install torchvision`, `pip3 install torchaudio` or by matching the PyTorch version to your CUDA version.
|
216 |
+
- We have not validated CUDA 11.6 or CUDA 11.7, but our code is likely to work with them too. If you have success with other versions of CUDA/pytorch, let us know and we will update our instructions.
|
217 |
+
|
218 |
+
---
|
219 |
+
|
220 |
+
# Create a Desktop Shortcut to Launch VoucherVision GUI
|
221 |
+
We can create a desktop shortcut to launch VoucherVision. In the `../VoucherVision/` directory is a file called `create_desktop_shortcut.py`. In the terminal, move into the `../VoucherVision/` directory and type:
|
222 |
+
<pre><code class="language-python">python create_desktop_shortcut.py</code></pre>
|
223 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
224 |
+
Or...
|
225 |
+
<pre><code class="language-python">python3 create_desktop_shortcut.py</code></pre>
|
226 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
227 |
+
Follow the instructions, select where you want the shortcut to be created, then where the virtual environment is located.
|
228 |
+
|
229 |
---
|
230 |
+
|
231 |
+
# Run VoucherVision
|
232 |
+
1. In the terminal, make sure that you `cd` into the `VoucherVision` directory and that your virtual environment is active (you should see venv_VV on the command line).
|
233 |
+
2. Type:
|
234 |
+
<pre><code class="language-python">python run_VoucherVision.py</code></pre>
|
235 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
236 |
+
or depending on your Python installation:
|
237 |
+
<pre><code class="language-python">python3 run_VoucherVision.py</code></pre>
|
238 |
+
<button class="btn" data-clipboard-target="#code-snippet"></button>
|
239 |
+
3. If you ever see an error that says that a "port is not available", open `run_VoucherVision.py` in a plain text editor and change the `--port` value to something different but close, like 8502.
|
240 |
+
|
241 |
+
## Setting up API key
|
242 |
+
VoucherVision requires access to Google Vision OCR and at least one of the following LLMs: OpenAI API, Google PaLM 2, a private instance of OpenAI through Microsoft Azure. On first startup, you will see a page with instructions on how to get these API keys. ***Nothing will work until*** you get at least the Google Vision OCR API key and one LLM API key.
|
243 |
+
|
244 |
+
## Check GPU
|
245 |
+
Press the "Check GPU" button to see if you have a GPU available. If you know that your computer has an Nvidia GPU, but the check fails, then you need to install a different version of PyTorch in the virtual environment.
|
246 |
+
|
247 |
+
## Run Tests
|
248 |
+
Once you have provided API keys, you can test all available prompts and LLMs by pressing the test buttons. Every combination of LLM, prompt, and LeafMachine2 collage will run on the image in the `../VoucherVision/demo/demo_images` folder. A grid will appear letting you know which combinations are working on your system.
|
249 |
+
|
250 |
+
## Starting VoucherVision
|
251 |
+
1. "Run name" - Set a run name for your project. This will be the name of the new folder that contains the output files.
|
252 |
+
2. "Output directory" - Paste the full file path of where you would like to save the folder that will be created in step 1.
|
253 |
+
3. "Input images directory" - Paste the full file path of where the input images are located. This folder can only have JPG or JPEG images inside of it.
|
254 |
+
4. "Select an LLM" - Pick the LLM you want to use to parse the unstructured OCR text.
|
255 |
+
- As of Nov. 1, 2023 PaLM 2 is free to use.
|
256 |
+
5. "Prompt Version" - Pick your prompt version. We recommend "Version 2" for production use, but you can experiment with our other prompts.
|
257 |
+
6. "Cropped Components" - Check the box to use LeafMachine2 collage images as the input file. LeafMachine2 can often find small handwritten text that may be missed by Google Vision OCR's text detection algorithm. But, the difference in performance is not that big. You will still get good performance without using the LeafMachine2 collage images.
|
258 |
+
7. "Domain Knowledge" is only used for "Version 1" prompts.
|
259 |
+
8. "Component Detector" sets basic LeafMachine2 parameters, but the default is likely good enough.
|
260 |
+
9. "Processing Options"
|
261 |
+
- The image file name defines the row name in the final output spreadsheet.
|
262 |
+
- We provide some basic options to clean/parse the image file name to produce the desired output.
|
263 |
+
- For example, if the input image name is `MICH-V-3819482.jpg` but the desired name is just `3819482` you can add `MICH-V-` to the "Remove prefix from catalog number" input box. Alternatively, you can check the "Require Catalog..." box and achieve the same result.
|
264 |
+
|
265 |
+
10. ***Finally*** you can press the start processing button.
|
266 |
+
|
267 |
+
## Azure Instances of OpenAI
|
268 |
+
If your institution has an enterprise instance of OpenAI's services, [like at the University of Michigan](https://its.umich.edu/computing/ai), you can use Azure instead of the OpenAI servers. Your institution should be able to provide you with the required keys (there are 5 required keys for this service).
|
269 |
+
|
270 |
+
# Custom Prompt Builder
|
271 |
+
VoucherVision empowers individual institutions to customize the format of the LLM output. Using our pre-defined prompts you can transcribe the label text into 20 columns, but using our Prompt Builder you can load one of our default prompts and adjust the output to meet your needs. More instructions will come soon, but for now here are a few more details.
|
272 |
+
|
273 |
+
### Load, Build, Edit
|
274 |
+
|
275 |
+
The Prompt Builder creates a prompt in the structure that VoucherVision expects. This information is stored as a configuration yaml file in `../VoucherVision/custom_prompts/`. We provide a few versions to get started. You can load one of our examples and then use the Prompt Builder to edit or add new columns.
|
276 |
+
|
277 |
+
![prompt_1](https://LeafMachine.org/img/prompt_1.PNG)
|
278 |
+
|
279 |
+
### Instructions
|
280 |
+
|
281 |
+
Right now, the prompting instructions are not configurable, but that may change in the future.
|
282 |
+
|
283 |
+
![prompt_2](https://LeafMachine.org/img/prompt_1.PNG)
|
284 |
+
|
285 |
+
### Defining Column Names Field-Specific Instructions
|
286 |
+
|
287 |
+
The central JSON object shows the structure of the columns that you are requesting the LLM to create and populate with information from the specimen's labels. These will become the rows in the final xlsx file that VoucherVision generates. You can pick formatting instructions, set default values, and give detailed instructions.
|
288 |
+
|
289 |
+
> Note: formatting instructions are not always followed precisely by the LLM. For example, GPT-4 is capable of granular instructions like converting ALL CAPS TEXT to sentence-case, but GPT-3.5 and PaLM 2 might not be capable of following that instruction every time (which is why we have the VoucherVisionEditor and are working to link these instructions so that humans editing the output can quickly/easily fix these errors).
|
290 |
+
|
291 |
+
![prompt_3](https://LeafMachine.org/img/prompt_3.PNG)
|
292 |
+
|
293 |
+
### Prompting Structure
|
294 |
+
|
295 |
+
The rightmost JSON object is the entire prompt structure. If you load the `required_structure.yaml` prompt, you will see the bare-bones version of what VoucherVision expects to see. All of the parts are there for a reason. The Prompt Builder UI may be a little unruly right now thanks to quirks with Streamlit, but we still recommend using the UI to build your own prompts to make sure that all of the required components are present.
|
296 |
+
|
297 |
+
![prompt_4](https://LeafMachine.org/img/prompt_4.PNG)
|
298 |
+
|
299 |
+
### Mapping Columns for VoucherVisionEditor
|
300 |
+
|
301 |
+
Finally, we need to map columns to a VoucherVisionEditor category.
|
302 |
+
|
303 |
+
![prompt_5](https://LeafMachine.org/img/prompt_5.PNG)
|
304 |
+
|
305 |
+
# Expense Reporting
|
306 |
+
VoucherVision logs the number of input and output tokens (using [tiktoken](https://github.com/openai/tiktoken)) from every call. We store the publicly listed prices of the LLM APIs in `../VoucherVision/api_cost/api_cost.yaml`. Then we do some simple math to estimate the cost of the run, which is stored inside of your project's output directory `../run_name/Cost/run_name.csv` and all runs are accumulated in a csv file stored in `../VoucherVision/expense_report/expense_report.csv`. VoucherVision only manages `expense_report.csv`, so if you want to split costs by month/quarter then copy and rename `expense_report.csv`. Deleting `expense_report.csv` will let you accumulate more stats.
|
307 |
+
|
308 |
+
> This should be treated as an estimate. The true cost may be slightly different.
|
309 |
+
|
310 |
+
This is an example of the stats that we track:
|
311 |
+
| run | date | api_version | total_cost | n_images | tokens_in | tokens_out | rate_in | rate_out | cost_in | cost_out |
|
312 |
+
|----------------------------|--------------------------|-------------|------------|----------|-----------|------------|---------|----------|-----------|----------|
|
313 |
+
| GPT4_test_run1 | 2023_11_05__17-44-31 | GPT_4 | 0.23931 | 2 | 6749 | 614 | 0.03 | 0.06 | 0.20247 | 0.03684 |
|
314 |
+
| GPT_3_5_test_run | 2023_11_05__17-48-48 | GPT_3_5 | 0.0189755 | 4 | 12033 | 463 | 0.0015 | 0.002 | 0.0180495 | 0.000926 |
|
315 |
+
| PALM2_test_run | 2023_11_05__17-50-35 | PALM2 | 0 | 4 | 13514 | 771 | 0 | 0 | 0 | 0 |
|
316 |
+
| GPT4_test_run2 | 2023_11_05__18-49-24 | GPT_4 | 0.40962 | 4 | 12032 | 811 | 0.03 | 0.06 | 0.36096 | 0.04866 |
|
317 |
+
|
318 |
+
## Expense Report Dashboard
|
319 |
+
The sidebar in VoucherVision displays summary stats taken from `expense_report.csv`.
|
320 |
+
![Expense Report Dashboard](https://LeafMachine.org/img/expense_report.PNG)
|
321 |
+
|
322 |
+
# User Interface Images
|
323 |
+
Validation test when the OpenAI key is not provided, but keys for PaLM 2 and Azure OpenAI are present:
|
324 |
+
![Validation 1](https://LeafMachine.org/img/validation_1.PNG)
|
325 |
+
|
326 |
---
|
327 |
|
328 |
+
Validation test when all versions of the OpenAI keys are provided:
|
329 |
+
![Validation GPT](https://LeafMachine.org/img/validation_gpt.PNG)
|
330 |
+
|
331 |
+
---
|
332 |
+
|
333 |
+
A successful GPU test:
|
334 |
+
![Validation GPU](https://LeafMachine.org/img/validation_gpu.PNG)
|
335 |
+
|
336 |
+
---
|
337 |
+
|
338 |
+
Successful PaLM 2 test:
|
339 |
+
![Validation PaLM](https://LeafMachine.org/img/validation_palm.PNG)
|
340 |
+
|
341 |
+
|
342 |
+
|
343 |
+
|
api_cost/api_cost.yaml
CHANGED
@@ -1,9 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
GPT_3_5:
|
2 |
in: 0.0010
|
3 |
out: 0.0020
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
in: 0.01
|
6 |
out: 0.03
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
in: 0.0
|
9 |
out: 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# OpenAI
|
2 |
+
# https://openai.com/pricing
|
3 |
+
|
4 |
+
GPT_4_32K:
|
5 |
+
in: 0.06
|
6 |
+
out: 0.12
|
7 |
+
GPT_4:
|
8 |
+
in: 0.03
|
9 |
+
out: 0.06
|
10 |
+
GPT_4_TURBO:
|
11 |
+
in: 0.01
|
12 |
+
out: 0.03
|
13 |
+
GPT_3_5_INSTRUCT:
|
14 |
+
in: 0.0015
|
15 |
+
out: 0.0020
|
16 |
GPT_3_5:
|
17 |
in: 0.0010
|
18 |
out: 0.0020
|
19 |
+
|
20 |
+
# Azure
|
21 |
+
AZURE_GPT_4_32K:
|
22 |
+
in: 0.06
|
23 |
+
out: 0.12
|
24 |
+
AZURE_GPT_4:
|
25 |
+
in: 0.03
|
26 |
+
out: 0.06
|
27 |
+
AZURE_GPT_4_TURBO:
|
28 |
in: 0.01
|
29 |
out: 0.03
|
30 |
+
AZURE_GPT_3_5_INSTRUCT:
|
31 |
+
in: 0.0015
|
32 |
+
out: 0.0020
|
33 |
+
AZURE_GPT_3_5:
|
34 |
+
in: 0.0010
|
35 |
+
out: 0.0020
|
36 |
+
|
37 |
+
|
38 |
+
# Google Gemini
|
39 |
+
# https://ai.google.dev/pricing
|
40 |
+
GEMINI_PRO:
|
41 |
+
in: 0.00025
|
42 |
+
out: 0.0005
|
43 |
+
|
44 |
+
|
45 |
+
# Google PaLM 2 (text-bison, text-unicorn)
|
46 |
+
# https://cloud.google.com/vertex-ai/docs/generative-ai/pricing
|
47 |
+
PALM2_TU_1:
|
48 |
+
in: 0.00025
|
49 |
+
out: 0.00020
|
50 |
+
PALM2_TB_1:
|
51 |
+
in: 0.00025
|
52 |
+
out: 0.0005
|
53 |
+
PALM2_TB_2:
|
54 |
+
in: 0.00025
|
55 |
+
out: 0.0005
|
56 |
+
|
57 |
+
|
58 |
+
# Mistral AI
|
59 |
+
# https://docs.mistral.ai/platform/pricing/
|
60 |
+
MISTRAL_MEDIUM:
|
61 |
+
in: 0.00250
|
62 |
+
out: 0.00750
|
63 |
+
MISTRAL_SMALL:
|
64 |
+
in: 0.00060
|
65 |
+
out: 0.00180
|
66 |
+
MISTRAL_TINY:
|
67 |
+
in: 0.00014
|
68 |
+
out: 0.00042
|
69 |
+
|
70 |
+
|
71 |
+
################
|
72 |
+
# Local Models
|
73 |
+
################
|
74 |
+
LOCAL_MIXTRAL_8X7B_INSTRUCT_V01:
|
75 |
in: 0.0
|
76 |
out: 0.0
|
77 |
+
LOCAL_MISTRAL_7B_INSTRUCT_V02:
|
78 |
+
in: 0.0
|
79 |
+
out: 0.0
|
80 |
+
LOCAL_CPU_MISTRAL_7B_INSTRUCT_V02_GGUF:
|
81 |
+
in: 0.0
|
82 |
+
out: 0.0
|
app.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
custom_prompts/required_structure.yaml
DELETED
@@ -1,65 +0,0 @@
|
|
1 |
-
prompt_author: unknown
|
2 |
-
prompt_author_institution: unknown
|
3 |
-
prompt_description: unknown
|
4 |
-
LLM: gpt
|
5 |
-
instructions: '1. Refactor the unstructured OCR text into a dictionary based on the
|
6 |
-
JSON structure outlined below.
|
7 |
-
|
8 |
-
2. You should map the unstructured OCR text to the appropriate JSON key and then
|
9 |
-
populate the field based on its rules.
|
10 |
-
|
11 |
-
3. Some JSON key fields are permitted to remain empty if the corresponding information
|
12 |
-
is not found in the unstructured OCR text.
|
13 |
-
|
14 |
-
4. Ignore any information in the OCR text that doesn''t fit into the defined JSON
|
15 |
-
structure.
|
16 |
-
|
17 |
-
5. Duplicate dictionary fields are not allowed.
|
18 |
-
|
19 |
-
6. Ensure that all JSON keys are in lowercase.
|
20 |
-
|
21 |
-
7. Ensure that new JSON field values follow sentence case capitalization.
|
22 |
-
|
23 |
-
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format
|
24 |
-
and data types specified in the template.
|
25 |
-
|
26 |
-
9. Ensure the output JSON string is valid JSON format. It should not have trailing
|
27 |
-
commas or unquoted keys.
|
28 |
-
|
29 |
-
10. Only return a JSON dictionary represented as a string. You should not explain
|
30 |
-
your answer.'
|
31 |
-
json_formatting_instructions: "The next section of instructions outlines how to format\
|
32 |
-
\ the JSON dictionary. The keys are the same as those of the final formatted JSON\
|
33 |
-
\ object.\nFor each key there is a format requirement that specifies how to transcribe\
|
34 |
-
\ the information for that key. \nThe possible formatting options are:\n1. \"verbatim\
|
35 |
-
\ transcription\" - field is populated with verbatim text from the unformatted OCR.\n\
|
36 |
-
2. \"spell check transcription\" - field is populated with spelling corrected text\
|
37 |
-
\ from the unformatted OCR.\n3. \"boolean yes no\" - field is populated with only\
|
38 |
-
\ yes or no.\n4. \"boolean 1 0\" - field is populated with only 1 or 0.\n5. \"integer\"\
|
39 |
-
\ - field is populated with only an integer.\n6. \"[list]\" - field is populated\
|
40 |
-
\ from one of the values in the list.\n7. \"yyyy-mm-dd\" - field is populated with\
|
41 |
-
\ a date in the format year-month-day.\nThe desired null value is also given. Populate\
|
42 |
-
\ the field with the null value of the information for that key is not present in\
|
43 |
-
\ the unformatted OCR text."
|
44 |
-
mapping:
|
45 |
-
# Add column names to the desired category. This is used to map the VV Editor.
|
46 |
-
COLLECTING: []
|
47 |
-
GEOGRAPHY: []
|
48 |
-
LOCALITY: []
|
49 |
-
MISCELLANEOUS: []
|
50 |
-
TAXONOMY:
|
51 |
-
- catalog_number
|
52 |
-
rules:
|
53 |
-
Dictionary:
|
54 |
-
# Manually add rows here. You MUST keep 'catalog_number' unchanged. Use 'catalog_number' as a guide for adding more columns.
|
55 |
-
# The only values allowed in the 'format' key are those outlines above in the 'json_formatting_instructions' section.
|
56 |
-
# If you want an empty cell by default, use '' for the 'null_value'.
|
57 |
-
catalog_number:
|
58 |
-
description: The barcode identifier, typically a number with at least 6 digits,
|
59 |
-
but fewer than 30 digits.
|
60 |
-
format: verbatim transcription
|
61 |
-
null_value: ''
|
62 |
-
# Do not change or remove below. This is required for some LLMs
|
63 |
-
SpeciesName:
|
64 |
-
taxonomy:
|
65 |
-
- Genus_species
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
custom_prompts/version_2.yaml
DELETED
@@ -1,232 +0,0 @@
|
|
1 |
-
prompt_author: Will Weaver
|
2 |
-
prompt_author_institution: UM
|
3 |
-
prompt_description: Basic prompt used by the University of Michigan. Designed to be a starting point for more complex prompts.
|
4 |
-
LLM: gpt
|
5 |
-
instructions: '1. Refactor the unstructured OCR text into a dictionary based on the
|
6 |
-
JSON structure outlined below.
|
7 |
-
|
8 |
-
2. You should map the unstructured OCR text to the appropriate JSON key and then
|
9 |
-
populate the field based on its rules.
|
10 |
-
|
11 |
-
3. Some JSON key fields are permitted to remain empty if the corresponding information
|
12 |
-
is not found in the unstructured OCR text.
|
13 |
-
|
14 |
-
4. Ignore any information in the OCR text that doesn''t fit into the defined JSON
|
15 |
-
structure.
|
16 |
-
|
17 |
-
5. Duplicate dictionary fields are not allowed.
|
18 |
-
|
19 |
-
6. Ensure that all JSON keys are in lowercase.
|
20 |
-
|
21 |
-
7. Ensure that new JSON field values follow sentence case capitalization.
|
22 |
-
|
23 |
-
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format
|
24 |
-
and data types specified in the template.
|
25 |
-
|
26 |
-
9. Ensure the output JSON string is valid JSON format. It should not have trailing
|
27 |
-
commas or unquoted keys.
|
28 |
-
|
29 |
-
10. Only return a JSON dictionary represented as a string. You should not explain
|
30 |
-
your answer.'
|
31 |
-
json_formatting_instructions: "The next section of instructions outlines how to format\
|
32 |
-
\ the JSON dictionary. The keys are the same as those of the final formatted JSON\
|
33 |
-
\ object.\nFor each key there is a format requirement that specifies how to transcribe\
|
34 |
-
\ the information for that key. \nThe possible formatting options are:\n1. \"verbatim\
|
35 |
-
\ transcription\" - field is populated with verbatim text from the unformatted OCR.\n\
|
36 |
-
2. \"spell check transcription\" - field is populated with spelling corrected text\
|
37 |
-
\ from the unformatted OCR.\n3. \"boolean yes no\" - field is populated with only\
|
38 |
-
\ yes or no.\n4. \"boolean 1 0\" - field is populated with only 1 or 0.\n5. \"integer\"\
|
39 |
-
\ - field is populated with only an integer.\n6. \"[list]\" - field is populated\
|
40 |
-
\ from one of the values in the list.\n7. \"yyyy-mm-dd\" - field is populated with\
|
41 |
-
\ a date in the format year-month-day.\nThe desired null value is also given. Populate\
|
42 |
-
\ the field with the null value of the information for that key is not present in\
|
43 |
-
\ the unformatted OCR text."
|
44 |
-
mapping:
|
45 |
-
COLLECTING:
|
46 |
-
- collectors
|
47 |
-
- collector_number
|
48 |
-
- determined_by
|
49 |
-
- multiple_names
|
50 |
-
- verbatim_date
|
51 |
-
- date
|
52 |
-
- end_date
|
53 |
-
GEOGRAPHY:
|
54 |
-
- country
|
55 |
-
- state
|
56 |
-
- county
|
57 |
-
- min_elevation
|
58 |
-
- max_elevation
|
59 |
-
- elevation_units
|
60 |
-
LOCALITY:
|
61 |
-
- locality_name
|
62 |
-
- verbatim_coordinates
|
63 |
-
- decimal_coordinates
|
64 |
-
- datum
|
65 |
-
- plant_description
|
66 |
-
- cultivated
|
67 |
-
- habitat
|
68 |
-
MISCELLANEOUS: []
|
69 |
-
TAXONOMY:
|
70 |
-
- catalog_number
|
71 |
-
- genus
|
72 |
-
- species
|
73 |
-
- subspecies
|
74 |
-
- variety
|
75 |
-
- forma
|
76 |
-
rules:
|
77 |
-
Dictionary:
|
78 |
-
catalog_number:
|
79 |
-
description: The barcode identifier, typically a number with at least 6 digits,
|
80 |
-
but fewer than 30 digits.
|
81 |
-
format: verbatim transcription
|
82 |
-
null_value: ''
|
83 |
-
collector_number:
|
84 |
-
description: Unique identifier or number that denotes the specific collecting
|
85 |
-
event and associated with the collector.
|
86 |
-
format: verbatim transcription
|
87 |
-
null_value: s.n.
|
88 |
-
collectors:
|
89 |
-
description: Full name(s) of the individual(s) responsible for collecting the
|
90 |
-
specimen. When multiple collectors are involved, their names should be separated
|
91 |
-
by commas.
|
92 |
-
format: verbatim transcription
|
93 |
-
null_value: not present
|
94 |
-
country:
|
95 |
-
description: Country that corresponds to the current geographic location of
|
96 |
-
collection. Capitalize first letter of each word. If abbreviation is given
|
97 |
-
populate field with the full spelling of the country's name.
|
98 |
-
format: spell check transcription
|
99 |
-
null_value: ''
|
100 |
-
county:
|
101 |
-
description: Administrative division 2 that corresponds to the current geographic
|
102 |
-
location of collection; capitalize first letter of each word. Administrative
|
103 |
-
division 2 is equivalent to a U.S. county, parish, borough.
|
104 |
-
format: spell check transcription
|
105 |
-
null_value: ''
|
106 |
-
cultivated:
|
107 |
-
description: Cultivated plants are intentionally grown by humans. In text descriptions,
|
108 |
-
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
109 |
-
or farm to indicate cultivated plant.
|
110 |
-
format: boolean yes no
|
111 |
-
null_value: ''
|
112 |
-
date:
|
113 |
-
description: 'Date the specimen was collected formatted as year-month-day. If
|
114 |
-
specific components of the date are unknown, they should be replaced with
|
115 |
-
zeros. Examples: ''0000-00-00'' if the entire date is unknown, ''YYYY-00-00''
|
116 |
-
if only the year is known, and ''YYYY-MM-00'' if year and month are known
|
117 |
-
but day is not.'
|
118 |
-
format: yyyy-mm-dd
|
119 |
-
null_value: ''
|
120 |
-
datum:
|
121 |
-
description: Datum of location coordinates. Possible values are include in the
|
122 |
-
format list. Leave field blank if unclear. [WGS84, WGS72, WGS66, WGS60, NAD83,
|
123 |
-
NAD27, OSGB36, ETRS89, ED50, GDA94, JGD2011, Tokyo97, KGD2002, TWD67, TWD97,
|
124 |
-
BJS54, XAS80, GCJ-02, BD-09, PZ-90.11, GTRF, CGCS2000, ITRF88, ITRF89, ITRF90,
|
125 |
-
ITRF91, ITRF92, ITRF93, ITRF94, ITRF96, ITRF97, ITRF2000, ITRF2005, ITRF2008,
|
126 |
-
ITRF2014, Hong Kong Principal Datum, SAD69]
|
127 |
-
format: '[list]'
|
128 |
-
null_value: ''
|
129 |
-
decimal_coordinates:
|
130 |
-
description: Correct and convert the verbatim location coordinates to conform
|
131 |
-
with the decimal degrees GPS coordinate format.
|
132 |
-
format: spell check transcription
|
133 |
-
null_value: ''
|
134 |
-
determined_by:
|
135 |
-
description: Full name of the individual responsible for determining the taxanomic
|
136 |
-
name of the specimen. Sometimes the name will be near to the characters 'det'
|
137 |
-
to denote determination. This name may be isolated from other names in the
|
138 |
-
unformatted OCR text.
|
139 |
-
format: verbatim transcription
|
140 |
-
null_value: ''
|
141 |
-
elevation_units:
|
142 |
-
description: 'Elevation units must be meters. If min_elevation field is populated,
|
143 |
-
then elevation_units: ''m''. Otherwise elevation_units: ''''.'
|
144 |
-
format: spell check transcription
|
145 |
-
null_value: ''
|
146 |
-
end_date:
|
147 |
-
description: 'If a date range is provided, this represents the later or ending
|
148 |
-
date of the collection period, formatted as year-month-day. If specific components
|
149 |
-
of the date are unknown, they should be replaced with zeros. Examples: ''0000-00-00''
|
150 |
-
if the entire end date is unknown, ''YYYY-00-00'' if only the year of the
|
151 |
-
end date is known, and ''YYYY-MM-00'' if year and month of the end date are
|
152 |
-
known but the day is not.'
|
153 |
-
format: yyyy-mm-dd
|
154 |
-
null_value: ''
|
155 |
-
forma:
|
156 |
-
description: Taxonomic determination to form (f.).
|
157 |
-
format: verbatim transcription
|
158 |
-
null_value: ''
|
159 |
-
genus:
|
160 |
-
description: Taxonomic determination to genus. Genus must be capitalized. If
|
161 |
-
genus is not present use the taxonomic family name followed by the word 'indet'.
|
162 |
-
format: verbatim transcription
|
163 |
-
null_value: ''
|
164 |
-
habitat:
|
165 |
-
description: Description of a plant's habitat or the location where the specimen
|
166 |
-
was collected. Ignore descriptions of the plant itself.
|
167 |
-
format: verbatim transcription
|
168 |
-
null_value: ''
|
169 |
-
locality_name:
|
170 |
-
description: Description of geographic location, landscape, landmarks, regional
|
171 |
-
features, nearby places, or any contextual information aiding in pinpointing
|
172 |
-
the exact origin or site of the specimen.
|
173 |
-
format: verbatim transcription
|
174 |
-
null_value: ''
|
175 |
-
max_elevation:
|
176 |
-
description: Maximum elevation or altitude in meters. If only one elevation
|
177 |
-
is present, then max_elevation should be set to the null_value. Only if units
|
178 |
-
are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m'
|
179 |
-
or 'm.' or 'meters'). Round to integer.
|
180 |
-
format: integer
|
181 |
-
null_value: ''
|
182 |
-
min_elevation:
|
183 |
-
description: Minimum elevation or altitude in meters. Only if units are explicit
|
184 |
-
then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or
|
185 |
-
'meters'). Round to integer.
|
186 |
-
format: integer
|
187 |
-
null_value: ''
|
188 |
-
multiple_names:
|
189 |
-
description: Indicate whether multiple people or collector names are present
|
190 |
-
in the unformatted OCR text. If you see more than one person's name the value
|
191 |
-
is 'yes'; otherwise the value is 'no'.
|
192 |
-
format: boolean yes no
|
193 |
-
null_value: ''
|
194 |
-
plant_description:
|
195 |
-
description: Description of plant features such as leaf shape, size, color,
|
196 |
-
stem texture, height, flower structure, scent, fruit or seed characteristics,
|
197 |
-
root system type, overall growth habit and form, any notable aroma or secretions,
|
198 |
-
presence of hairs or bristles, and any other distinguishing morphological
|
199 |
-
or physiological characteristics.
|
200 |
-
format: verbatim transcription
|
201 |
-
null_value: ''
|
202 |
-
species:
|
203 |
-
description: Taxonomic determination to species, do not capitalize species.
|
204 |
-
format: verbatim transcription
|
205 |
-
null_value: ''
|
206 |
-
state:
|
207 |
-
description: Administrative division 1 that corresponds to the current geographic
|
208 |
-
location of collection. Capitalize first letter of each word. Administrative
|
209 |
-
division 1 is equivalent to a U.S. State.
|
210 |
-
format: spell check transcription
|
211 |
-
null_value: ''
|
212 |
-
subspecies:
|
213 |
-
description: Taxonomic determination to subspecies (subsp.).
|
214 |
-
format: verbatim transcription
|
215 |
-
null_value: ''
|
216 |
-
variety:
|
217 |
-
description: Taxonomic determination to variety (var).
|
218 |
-
format: verbatim transcription
|
219 |
-
null_value: ''
|
220 |
-
verbatim_coordinates:
|
221 |
-
description: Verbatim location coordinates as they appear on the label. Do not
|
222 |
-
convert formats. Possible coordinate types are one of [Lat, Long, UTM, TRS].
|
223 |
-
format: verbatim transcription
|
224 |
-
null_value: ''
|
225 |
-
verbatim_date:
|
226 |
-
description: Date of collection exactly as it appears on the label. Do not change
|
227 |
-
the format or correct typos.
|
228 |
-
format: verbatim transcription
|
229 |
-
null_value: s.d.
|
230 |
-
SpeciesName:
|
231 |
-
taxonomy:
|
232 |
-
- Genus_species
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
custom_prompts/version_2_OSU.yaml
DELETED
@@ -1,233 +0,0 @@
|
|
1 |
-
prompt_author: Will Weaver
|
2 |
-
prompt_author_institution: UM
|
3 |
-
prompt_description: Version 2 slightly modified for OSU, but still unfinished
|
4 |
-
LLM: gpt
|
5 |
-
instructions: '1. Refactor the unstructured OCR text into a dictionary based on the
|
6 |
-
JSON structure outlined below.
|
7 |
-
|
8 |
-
2. You should map the unstructured OCR text to the appropriate JSON key and then
|
9 |
-
populate the field based on its rules.
|
10 |
-
|
11 |
-
3. Some JSON key fields are permitted to remain empty if the corresponding information
|
12 |
-
is not found in the unstructured OCR text.
|
13 |
-
|
14 |
-
4. Ignore any information in the OCR text that doesn''t fit into the defined JSON
|
15 |
-
structure.
|
16 |
-
|
17 |
-
5. Duplicate dictionary fields are not allowed.
|
18 |
-
|
19 |
-
6. Ensure that all JSON keys are in lowercase.
|
20 |
-
|
21 |
-
7. Ensure that new JSON field values follow sentence case capitalization.
|
22 |
-
|
23 |
-
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format
|
24 |
-
and data types specified in the template.
|
25 |
-
|
26 |
-
9. Ensure the output JSON string is valid JSON format. It should not have trailing
|
27 |
-
commas or unquoted keys.
|
28 |
-
|
29 |
-
10. Only return a JSON dictionary represented as a string. You should not explain
|
30 |
-
your answer.'
|
31 |
-
json_formatting_instructions: "The next section of instructions outlines how to format\
|
32 |
-
\ the JSON dictionary. The keys are the same as those of the final formatted JSON\
|
33 |
-
\ object.\nFor each key there is a format requirement that specifies how to transcribe\
|
34 |
-
\ the information for that key. \nThe possible formatting options are:\n1. \"verbatim\
|
35 |
-
\ transcription\" - field is populated with verbatim text from the unformatted OCR.\n\
|
36 |
-
2. \"spell check transcription\" - field is populated with spelling corrected text\
|
37 |
-
\ from the unformatted OCR.\n3. \"boolean yes no\" - field is populated with only\
|
38 |
-
\ yes or no.\n4. \"boolean 1 0\" - field is populated with only 1 or 0.\n5. \"integer\"\
|
39 |
-
\ - field is populated with only an integer.\n6. \"[list]\" - field is populated\
|
40 |
-
\ from one of the values in the list.\n7. \"yyyy-mm-dd\" - field is populated with\
|
41 |
-
\ a date in the format year-month-day.\nThe desired null value is also given. Populate\
|
42 |
-
\ the field with the null value of the information for that key is not present in\
|
43 |
-
\ the unformatted OCR text."
|
44 |
-
mapping:
|
45 |
-
COLLECTING:
|
46 |
-
- collectors
|
47 |
-
- collector_number
|
48 |
-
- determined_by
|
49 |
-
- multiple_names
|
50 |
-
- verbatim_date
|
51 |
-
- date
|
52 |
-
- end_date
|
53 |
-
GEOGRAPHY:
|
54 |
-
- country
|
55 |
-
- state
|
56 |
-
- county
|
57 |
-
- min_elevation
|
58 |
-
- max_elevation
|
59 |
-
- elevation_units
|
60 |
-
LOCALITY:
|
61 |
-
- locality_name
|
62 |
-
- verbatim_coordinates
|
63 |
-
- decimal_coordinates
|
64 |
-
- datum
|
65 |
-
- plant_description
|
66 |
-
- cultivated
|
67 |
-
- habitat
|
68 |
-
MISCELLANEOUS: []
|
69 |
-
TAXONOMY:
|
70 |
-
- catalog_number
|
71 |
-
- genus
|
72 |
-
- species
|
73 |
-
- subspecies
|
74 |
-
- variety
|
75 |
-
- forma
|
76 |
-
rules:
|
77 |
-
Dictionary:
|
78 |
-
catalog_number:
|
79 |
-
description: The barcode identifier, typically a number with at least 6 digits,
|
80 |
-
but fewer than 30 digits.
|
81 |
-
format: verbatim transcription
|
82 |
-
null_value: ''
|
83 |
-
collector_number:
|
84 |
-
description: Unique identifier or number that denotes the specific collecting
|
85 |
-
event and associated with the collector.
|
86 |
-
format: verbatim transcription
|
87 |
-
null_value: s.n.
|
88 |
-
collectors:
|
89 |
-
description: Full name(s) of the individual(s) responsible for collecting the
|
90 |
-
specimen. When multiple collectors are involved, their names should be separated
|
91 |
-
by commas.
|
92 |
-
format: verbatim transcription
|
93 |
-
null_value: not present
|
94 |
-
country:
|
95 |
-
description: Country that corresponds to the current geographic location of
|
96 |
-
collection. Capitalize first letter of each word. If abbreviation is given
|
97 |
-
populate field with the full spelling of the country's name.
|
98 |
-
format: spell check transcription
|
99 |
-
null_value: ''
|
100 |
-
county:
|
101 |
-
description: Administrative division 2 that corresponds to the current geographic
|
102 |
-
location of collection; capitalize first letter of each word. Administrative
|
103 |
-
division 2 is equivalent to a U.S. county, parish, borough.
|
104 |
-
format: spell check transcription
|
105 |
-
null_value: ''
|
106 |
-
cultivated:
|
107 |
-
description: Cultivated plants are intentionally grown by humans. In text descriptions,
|
108 |
-
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
109 |
-
or farm to indicate cultivated plant. The value 1 indicates that the specimen
|
110 |
-
was cultivated, the value zero otherwise.
|
111 |
-
format: boolean 1 0
|
112 |
-
null_value: '0'
|
113 |
-
date:
|
114 |
-
description: 'Date the specimen was collected formatted as year-month-day. If
|
115 |
-
specific components of the date are unknown, they should be replaced with
|
116 |
-
zeros. Examples: ''0000-00-00'' if the entire date is unknown, ''YYYY-00-00''
|
117 |
-
if only the year is known, and ''YYYY-MM-00'' if year and month are known
|
118 |
-
but day is not.'
|
119 |
-
format: yyyy-mm-dd
|
120 |
-
null_value: ''
|
121 |
-
datum:
|
122 |
-
description: Datum of location coordinates. Possible values are include in the
|
123 |
-
format list. Leave field blank if unclear. [WGS84, WGS72, WGS66, WGS60, NAD83,
|
124 |
-
NAD27, OSGB36, ETRS89, ED50, GDA94, JGD2011, Tokyo97, KGD2002, TWD67, TWD97,
|
125 |
-
BJS54, XAS80, GCJ-02, BD-09, PZ-90.11, GTRF, CGCS2000, ITRF88, ITRF89, ITRF90,
|
126 |
-
ITRF91, ITRF92, ITRF93, ITRF94, ITRF96, ITRF97, ITRF2000, ITRF2005, ITRF2008,
|
127 |
-
ITRF2014, Hong Kong Principal Datum, SAD69]
|
128 |
-
format: '[list]'
|
129 |
-
null_value: ''
|
130 |
-
decimal_coordinates:
|
131 |
-
description: Correct and convert the verbatim location coordinates to conform
|
132 |
-
with the decimal degrees GPS coordinate format.
|
133 |
-
format: spell check transcription
|
134 |
-
null_value: ''
|
135 |
-
determined_by:
|
136 |
-
description: Full name of the individual responsible for determining the taxanomic
|
137 |
-
name of the specimen. Sometimes the name will be near to the characters 'det'
|
138 |
-
to denote determination. This name may be isolated from other names in the
|
139 |
-
unformatted OCR text.
|
140 |
-
format: verbatim transcription
|
141 |
-
null_value: ''
|
142 |
-
elevation_units:
|
143 |
-
description: 'Elevation units must be meters. If min_elevation field is populated,
|
144 |
-
then elevation_units: ''m''. Otherwise elevation_units: ''''.'
|
145 |
-
format: spell check transcription
|
146 |
-
null_value: ''
|
147 |
-
end_date:
|
148 |
-
description: 'If a date range is provided, this represents the later or ending
|
149 |
-
date of the collection period, formatted as year-month-day. If specific components
|
150 |
-
of the date are unknown, they should be replaced with zeros. Examples: ''0000-00-00''
|
151 |
-
if the entire end date is unknown, ''YYYY-00-00'' if only the year of the
|
152 |
-
end date is known, and ''YYYY-MM-00'' if year and month of the end date are
|
153 |
-
known but the day is not.'
|
154 |
-
format: yyyy-mm-dd
|
155 |
-
null_value: ''
|
156 |
-
forma:
|
157 |
-
description: Taxonomic determination to form (f.).
|
158 |
-
format: verbatim transcription
|
159 |
-
null_value: ''
|
160 |
-
genus:
|
161 |
-
description: Taxonomic determination to genus. Genus must be capitalized. If
|
162 |
-
genus is not present use the taxonomic family name followed by the word 'indet'.
|
163 |
-
format: verbatim transcription
|
164 |
-
null_value: ''
|
165 |
-
habitat:
|
166 |
-
description: Description of a plant's habitat or the location where the specimen
|
167 |
-
was collected. Ignore descriptions of the plant itself.
|
168 |
-
format: verbatim transcription
|
169 |
-
null_value: ''
|
170 |
-
locality_name:
|
171 |
-
description: Description of geographic location, landscape, landmarks, regional
|
172 |
-
features, nearby places, or any contextual information aiding in pinpointing
|
173 |
-
the exact origin or site of the specimen.
|
174 |
-
format: verbatim transcription
|
175 |
-
null_value: ''
|
176 |
-
max_elevation:
|
177 |
-
description: Maximum elevation or altitude in meters. If only one elevation
|
178 |
-
is present, then max_elevation should be set to the null_value. Only if units
|
179 |
-
are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m'
|
180 |
-
or 'm.' or 'meters'). Round to integer.
|
181 |
-
format: integer
|
182 |
-
null_value: ''
|
183 |
-
min_elevation:
|
184 |
-
description: Minimum elevation or altitude in meters. Only if units are explicit
|
185 |
-
then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or
|
186 |
-
'meters'). Round to integer.
|
187 |
-
format: integer
|
188 |
-
null_value: ''
|
189 |
-
multiple_names:
|
190 |
-
description: Indicate whether multiple people or collector names are present
|
191 |
-
in the unformatted OCR text. If you see more than one person's name the value
|
192 |
-
is 'yes'; otherwise the value is 'no'.
|
193 |
-
format: boolean yes no
|
194 |
-
null_value: ''
|
195 |
-
plant_description:
|
196 |
-
description: Description of plant features such as leaf shape, size, color,
|
197 |
-
stem texture, height, flower structure, scent, fruit or seed characteristics,
|
198 |
-
root system type, overall growth habit and form, any notable aroma or secretions,
|
199 |
-
presence of hairs or bristles, and any other distinguishing morphological
|
200 |
-
or physiological characteristics.
|
201 |
-
format: verbatim transcription
|
202 |
-
null_value: ''
|
203 |
-
species:
|
204 |
-
description: Taxonomic determination to species, do not capitalize species.
|
205 |
-
format: verbatim transcription
|
206 |
-
null_value: ''
|
207 |
-
state:
|
208 |
-
description: Administrative division 1 that corresponds to the current geographic
|
209 |
-
location of collection. Capitalize first letter of each word. Administrative
|
210 |
-
division 1 is equivalent to a U.S. State.
|
211 |
-
format: spell check transcription
|
212 |
-
null_value: ''
|
213 |
-
subspecies:
|
214 |
-
description: Taxonomic determination to subspecies (subsp.).
|
215 |
-
format: verbatim transcription
|
216 |
-
null_value: ''
|
217 |
-
variety:
|
218 |
-
description: Taxonomic determination to variety (var).
|
219 |
-
format: verbatim transcription
|
220 |
-
null_value: ''
|
221 |
-
verbatim_coordinates:
|
222 |
-
description: Verbatim location coordinates as they appear on the label. Do not
|
223 |
-
convert formats. Possible coordinate types are one of [Lat, Long, UTM, TRS].
|
224 |
-
format: verbatim transcription
|
225 |
-
null_value: ''
|
226 |
-
verbatim_date:
|
227 |
-
description: Date of collection exactly as it appears on the label. Do not change
|
228 |
-
the format or correct typos.
|
229 |
-
format: verbatim transcription
|
230 |
-
null_value: s.d.
|
231 |
-
SpeciesName:
|
232 |
-
taxonomy:
|
233 |
-
- Genus_species
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo/NY_1928185102_Heliotropiaceae_Heliotropium_indicum.jpg
ADDED
Git LFS Details
|
demo/ba/ba2.jpg
ADDED
Git LFS Details
|
demo/ba/ocr.jpg
ADDED
Git LFS Details
|
demo/{demo_images → demo_gallery}/UM_1807464860_Phellinaceae_Phelline_dumbeensis.jpg
RENAMED
File without changes
|
demo/demo_images/MICH_29667680_Hypericaceae_Hypericum_prolificum.jpg
ADDED
Git LFS Details
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
run_VoucherVision.py
CHANGED
@@ -23,9 +23,13 @@ if __name__ == "__main__":
|
|
23 |
sys.argv = [
|
24 |
"streamlit",
|
25 |
"run",
|
26 |
-
resolve_path(os.path.join(dir_home,"
|
|
|
27 |
"--global.developmentMode=false",
|
28 |
-
"--server.port=
|
29 |
-
|
|
|
|
|
|
|
30 |
]
|
31 |
sys.exit(stcli.main())
|
|
|
23 |
sys.argv = [
|
24 |
"streamlit",
|
25 |
"run",
|
26 |
+
resolve_path(os.path.join(dir_home,"app.py")),
|
27 |
+
# resolve_path(os.path.join(dir_home,"vouchervision", "VoucherVision_GUI.py")),
|
28 |
"--global.developmentMode=false",
|
29 |
+
# "--server.port=8545",
|
30 |
+
"--server.port=8546",
|
31 |
+
# Toggle below for HF vs Local
|
32 |
+
"--is_hf=1",
|
33 |
+
# "--is_hf=0",
|
34 |
]
|
35 |
sys.exit(stcli.main())
|
settings/bottom.yaml
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
leafmachine:
|
2 |
+
LLM_version: Azure GPT 4 Turbo 1106-preview
|
3 |
+
archival_component_detector:
|
4 |
+
detector_iteration: PREP_final
|
5 |
+
detector_type: Archival_Detector
|
6 |
+
detector_version: PREP_final
|
7 |
+
detector_weights: best.pt
|
8 |
+
do_save_prediction_overlay_images: true
|
9 |
+
ignore_objects_for_overlay: []
|
10 |
+
minimum_confidence_threshold: 0.5
|
11 |
+
cropped_components:
|
12 |
+
binarize_labels: false
|
13 |
+
binarize_labels_skeletonize: false
|
14 |
+
do_save_cropped_annotations: true
|
15 |
+
save_cropped_annotations:
|
16 |
+
- label
|
17 |
+
- barcode
|
18 |
+
save_per_annotation_class: true
|
19 |
+
save_per_image: false
|
20 |
+
data:
|
21 |
+
do_apply_conversion_factor: false
|
22 |
+
include_darwin_core_data_from_combined_file: false
|
23 |
+
save_individual_csv_files_landmarks: false
|
24 |
+
save_individual_csv_files_measurements: false
|
25 |
+
save_individual_csv_files_rulers: false
|
26 |
+
save_individual_efd_files: false
|
27 |
+
save_json_measurements: false
|
28 |
+
save_json_rulers: false
|
29 |
+
do:
|
30 |
+
check_for_corrupt_images_make_vertical: true
|
31 |
+
check_for_illegal_filenames: false
|
32 |
+
do_create_OCR_helper_image: false
|
33 |
+
logging:
|
34 |
+
log_level: null
|
35 |
+
modules:
|
36 |
+
specimen_crop: true
|
37 |
+
overlay:
|
38 |
+
alpha_transparency_archival: 0.3
|
39 |
+
alpha_transparency_plant: 0
|
40 |
+
alpha_transparency_seg_partial_leaf: 0.3
|
41 |
+
alpha_transparency_seg_whole_leaf: 0.4
|
42 |
+
ignore_archival_detections_classes: []
|
43 |
+
ignore_landmark_classes: []
|
44 |
+
ignore_plant_detections_classes:
|
45 |
+
- leaf_whole
|
46 |
+
- specimen
|
47 |
+
line_width_archival: 12
|
48 |
+
line_width_efd: 12
|
49 |
+
line_width_plant: 12
|
50 |
+
line_width_seg: 12
|
51 |
+
overlay_background_color: black
|
52 |
+
overlay_dpi: 300
|
53 |
+
save_overlay_to_jpgs: true
|
54 |
+
save_overlay_to_pdf: false
|
55 |
+
show_archival_detections: true
|
56 |
+
show_landmarks: true
|
57 |
+
show_plant_detections: true
|
58 |
+
show_segmentations: true
|
59 |
+
print:
|
60 |
+
optional_warnings: true
|
61 |
+
verbose: true
|
62 |
+
project:
|
63 |
+
OCR_option: both
|
64 |
+
batch_size: 500
|
65 |
+
build_new_embeddings_database: false
|
66 |
+
catalog_numerical_only: false
|
67 |
+
continue_run_from_partial_xlsx: ''
|
68 |
+
delete_all_temps: false
|
69 |
+
delete_temps_keep_VVE: false
|
70 |
+
dir_images_local: d:\Dropbox\VoucherVision\demo\demo_images
|
71 |
+
dir_output: C:\Users\Will\Downloads
|
72 |
+
do_use_trOCR: true
|
73 |
+
embeddings_database_name: SLTP_UM_AllAsiaMinimalInRegion
|
74 |
+
image_location: local
|
75 |
+
num_workers: 8
|
76 |
+
path_to_domain_knowledge_xlsx: d:\Dropbox\VoucherVision\domain_knowledge\SLTP_UM_AllAsiaMinimalInRegion.xlsx
|
77 |
+
prefix_removal: ''
|
78 |
+
prompt_version: SLTPvA_long.yaml
|
79 |
+
run_name: test
|
80 |
+
suffix_removal: ''
|
81 |
+
use_domain_knowledge: false
|
82 |
+
use_RGB_label_images: true
|
vouchervision/API_validation.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, io, openai, vertexai
|
2 |
+
from mistralai.client import MistralClient
|
3 |
+
from mistralai.models.chat_completion import ChatMessage
|
4 |
+
from langchain.schema import HumanMessage
|
5 |
+
from langchain_openai import AzureChatOpenAI
|
6 |
+
from vertexai.language_models import TextGenerationModel
|
7 |
+
from vertexai.preview.generative_models import GenerativeModel
|
8 |
+
from google.cloud import vision
|
9 |
+
from datetime import datetime
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
class APIvalidation:
    """Probe the credentials found in the private config and summarize them.

    Each ``check_*`` method performs one small live request against the
    corresponding service and reports success as a boolean (or, for the
    Google Vertex check, a dict of booleans). Failures of the request itself
    are treated as "invalid credentials" and never propagate to the caller.

    Args:
        cfg_private: Nested mapping of private settings. The sections read by
            this class are ``openai``, ``openai_azure``, ``google_palm``,
            ``mistral``, ``here`` and ``open_cage_geocode``.
        dir_home: Project root; ``<dir_home>/img/logo.png`` is used as the
            sample image for the Google OCR check.
    """

    def __init__(self, cfg_private, dir_home) -> None:
        self.cfg_private = cfg_private
        self.dir_home = dir_home
        # Captured once at construction time and returned with every report.
        self.formatted_date = self.get_formatted_date()

    def get_formatted_date(self) -> str:
        """Return today's date as "Month day, year" (e.g. "January 23, 2024")."""
        return datetime.now().strftime("%B %d, %Y")

    def has_API_key(self, val) -> bool:
        """Return True when ``val`` is truthy (i.e. not None and not empty)."""
        return bool(val)

    def check_openai_api_key(self) -> bool:
        """Return True if listing models with the configured OpenAI key works."""
        # Assigned outside the try block on purpose: a missing config entry is
        # a configuration error and should surface, not be reported as invalid.
        openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
        try:
            openai.models.list()
        except Exception:
            # Any API/network failure means the key cannot be used right now.
            # (Exception, not a bare except, so Ctrl-C still interrupts.)
            return False
        return True

    def check_google_ocr_api_key(self) -> bool:
        """Return True if Google Vision OCR extracts text from the logo image."""
        # if os.path.exists(self.cfg_private['google_cloud']['path_json_file']):
        #     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg_private['google_cloud']['path_json_file']
        # elif os.path.exists(self.cfg_private['google_cloud']['path_json_file_service_account2']):
        #     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg_private['google_cloud']['path_json_file_service_account2']
        # else:
        #     return False

        try:
            logo_path = os.path.join(self.dir_home, 'img', 'logo.png')
            client = vision.ImageAnnotatorClient()
            with open(logo_path, 'rb') as image_file:
                content = image_file.read()
            image = vision.Image(content=content)
            response = client.document_text_detection(image=image)
            texts = response.text_annotations
            # The first annotation carries the text of the whole image.
            detected_text = texts[0].description if texts else None
            return bool(detected_text)
        except Exception:
            # Missing logo file, missing credentials, or a failed request.
            return False

    def check_azure_openai_api_key(self) -> bool:
        """Return True if a one-message chat against Azure OpenAI succeeds."""
        try:
            azure_cfg = self.cfg_private['openai_azure']
            model = AzureChatOpenAI(
                deployment_name='gpt-35-turbo',
                openai_api_version=azure_cfg['api_version'],
                openai_api_key=azure_cfg['openai_api_key'],
                azure_endpoint=azure_cfg['openai_api_base'],
                openai_organization=azure_cfg['openai_organization'],
            )
            response = model([HumanMessage(content="hello")])
            return bool(response)
        except Exception:  # Use a more specific exception if possible
            return False

    def check_mistral_api_key(self) -> bool:
        """Return True if a one-message chat against Mistral returns choices."""
        try:
            client = MistralClient(api_key=self.cfg_private['mistral']['mistral_key'])
            messages = [ChatMessage(role="user", content="hello")]
            chat_response = client.chat(
                model="mistral-tiny",
                messages=messages,
            )
            return bool(chat_response and chat_response.choices)
        except Exception:  # Replace with a more specific exception if possible
            return False

    def check_google_vertex_genai_api_key(self) -> dict:
        """Probe PaLM 2 and Gemini through Vertex AI.

        Returns:
            ``{"palm2": bool, "gemini": bool}``; a model is True only when it
            returned non-empty text for a trivial prompt. The two models are
            probed independently, so one failing does not mask the other.
        """
        results = {"palm2": False, "gemini": False}
        try:
            google_cfg = self.cfg_private['google_palm']
            os.environ["GOOGLE_API_KEY"] = google_cfg['google_palm_api']
            # genai.configure(api_key=self.cfg_private['google_palm']['google_palm_api'])
            vertexai.init(project=google_cfg['project_id'], location=google_cfg['location'])

            try:
                palm_model = TextGenerationModel.from_pretrained("text-bison@001")
                if palm_model.predict("Hello").text:
                    results["palm2"] = True
            except Exception:
                # Deliberate best effort: results["palm2"] simply stays False.
                pass

            try:
                gemini_model = GenerativeModel("gemini-pro")
                if gemini_model.generate_content("Hello").text:
                    results["gemini"] = True
            except Exception:
                # Deliberate best effort: results["gemini"] simply stays False.
                pass

            return results
        except Exception:  # Replace with a more specific exception if possible
            # Initialization failed; report whatever was determined so far.
            return results

    @staticmethod
    def _status_label(name: str, is_valid: bool) -> str:
        """Format a present key as "<name> (Valid)" or "<name> (Invalid)"."""
        return f"{name} ({'Valid' if is_valid else 'Invalid'})"

    def report_api_key_status(self):
        """Check every configured service and classify its credentials.

        A service whose config value is empty goes to the missing list without
        any network call; otherwise its ``check_*`` probe decides between
        "(Valid)" and "(Invalid)" in the present list. The two geocoding keys
        are only checked for presence, not probed.

        Returns:
            Tuple ``(present_keys, missing_keys, formatted_date)`` where the
            first two are lists of display labels.
        """
        cfg = self.cfg_private
        present_keys = []
        missing_keys = []

        # OpenAI key check
        if self.has_API_key(cfg['openai']['OPENAI_API_KEY']):
            present_keys.append(self._status_label('OpenAI', self.check_openai_api_key()))
        else:
            missing_keys.append('OpenAI')

        # Azure OpenAI key check
        # NOTE(review): gated on api_version rather than openai_api_key —
        # confirm this is intended.
        if self.has_API_key(cfg['openai_azure']['api_version']):
            present_keys.append(self._status_label('Azure OpenAI', self.check_azure_openai_api_key()))
        else:
            missing_keys.append('Azure OpenAI')

        # The same three google_palm settings gate both Google checks below.
        google_cfg = cfg['google_palm']
        has_google_config = all(
            self.has_API_key(google_cfg[field])
            for field in ('google_palm_api', 'project_id', 'location')
        )

        # Google PALM2/Gemini key check
        if has_google_config:
            google_results = self.check_google_vertex_genai_api_key()
            present_keys.append(self._status_label('Palm2', google_results['palm2']))
            present_keys.append(self._status_label('Gemini', google_results['gemini']))
        else:
            missing_keys.append('Google VertexAI/GenAI')

        # Google OCR key check
        # NOTE(review): gated on the google_palm settings, not on separate
        # OCR credentials — confirm this is intended.
        if has_google_config:
            present_keys.append(self._status_label('Google OCR', self.check_google_ocr_api_key()))
        else:
            missing_keys.append('Google OCR')

        # Mistral key check
        if self.has_API_key(cfg['mistral']['mistral_key']):
            present_keys.append(self._status_label('Mistral', self.check_mistral_api_key()))
        else:
            missing_keys.append('Mistral')

        # Geocoding keys: presence only. The "(Invalid)" suffix on the missing
        # labels is kept as-is because callers may display or match on it.
        if self.has_API_key(cfg['here']['api_key']):
            present_keys.append('HERE Geocode (Valid)')
        else:
            missing_keys.append('HERE Geocode (Invalid)')

        if self.has_API_key(cfg['open_cage_geocode']['api_key']):
            present_keys.append('OpenCage Geocode (Valid)')
        else:
            missing_keys.append('OpenCage Geocode (Invalid)')

        return present_keys, missing_keys, self.formatted_date
|
vouchervision/DEP_prompt_catalog.py
ADDED
@@ -0,0 +1,1322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
3 |
+
import yaml, json
|
4 |
+
|
5 |
+
|
6 |
+
# catalog = PromptCatalog(OCR="Sample OCR text", domain_knowledge_example="Sample domain knowledge", similarity="0.9")
|
7 |
+
|
8 |
+
|
9 |
+
### Required if you want to use the Pydantic JSON parser for langchain
|
10 |
+
class SLTPvA(BaseModel):
    # Pydantic schema for one transcribed specimen label. Every field is a
    # string, and each Field description is the per-field instruction text
    # for that output key.
    #
    # Documented with comments rather than a class docstring on purpose: a
    # pydantic class docstring is emitted as the schema's "description" and
    # could change the text produced from this model.
    catalogNumber: str = Field(description="Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits.")
    order: str = Field(description="The full scientific name of the order in which the taxon is classified. Order must be capitalized.")
    family: str = Field(description="The full scientific name of the family in which the taxon is classified. Family must be capitalized.")
    scientificName: str = Field(description="The scientific name of the taxon including genus, specific epithet, and any lower classifications.")
    scientificNameAuthorship: str = Field(description="The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.")
    genus: str = Field(description="Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'.")
    subgenus: str = Field(description="The full scientific name of the subgenus in which the taxon is classified. Values should include the genus to avoid homonym confusion.")
    specificEpithet: str = Field(description="The name of the first or species epithet of the scientificName. Only include the species epithet.")
    infraspecificEpithet: str = Field(description="The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation.")
    identifiedBy: str = Field(description="A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.")
    recordedBy: str = Field(description="A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen. The primary collector or observer should be listed first.")
    # Fixed: the description previously read "between str = Field notes and",
    # a search-and-replace artifact; the Darwin Core wording is "field notes".
    recordNumber: str = Field(description="An identifier given to the occurrence at the time it was recorded. Often serves as a link between field notes and an occurrence record, such as a specimen collector's number.")
    verbatimEventDate: str = Field(description="The verbatim original representation of the date and time information for when the specimen was collected. Date of collection exactly as it appears on the label. Do not change the format or correct typos.")
    eventDate: str = Field(description=" Date the specimen was collected formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples '0000-00-00' if the entire date is unknown, 'YYYY-00-00' if only the year is known, and 'YYYY-MM-00' if year and month are known but day is not.")
    habitat: str = Field(description="A category or description of the habitat in which the specimen collection event occurred.")
    occurrenceRemarks: str = Field(description="Text describing the specimen's geographic location. Text describing the appearance of the specimen. A statement about the presence or absence of a taxon at a the collection location. Text describing the significance of the specimen, such as a specific expedition or notable collection. Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics.")
    country: str = Field(description="The name of the country or major administrative unit in which the specimen was originally collected.")
    stateProvince: str = Field(description="The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected.")
    county: str = Field(description=" The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected.")
    municipality: str = Field(description="The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected.")
    locality: str = Field(description="Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or location of the specimen.")
    degreeOfEstablishment: str = Field(description="Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant. Other possible designiations include - unknown, native, captive, cultivated, released, failing, casual, reproducing, established, colonising, invasive, widespreadInvasive. Based on the OCR text, assign the most suitable designiation and use the term unknown as the default designiation.")
    decimalLatitude: str = Field(description="Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.")
    decimalLongitude: str = Field(description="Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.")
    verbatimCoordinates: str = Field(description="Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types include [Lat, Long, UTM, TRS].")
    minimumElevationInMeters: str = Field(description="Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ('ft' or 'ft.'' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer.")
    maximumElevationInMeters: str = Field(description="Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer.")
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class PromptCatalog:
|
44 |
+
domain_knowledge_example: str = ""
|
45 |
+
similarity: str = ""
|
46 |
+
OCR: str = ""
|
47 |
+
n_fields: int = 0
|
48 |
+
|
49 |
+
# def PROMPT_UMICH_skeleton_all_asia(self, OCR=None, domain_knowledge_example=None, similarity=None):
|
50 |
+
def prompt_v1_verbose(self, OCR=None, domain_knowledge_example=None, similarity=None):
|
51 |
+
self.OCR = OCR or self.OCR
|
52 |
+
self.domain_knowledge_example = domain_knowledge_example or self.domain_knowledge_example
|
53 |
+
self.similarity = similarity or self.similarity
|
54 |
+
self.n_fields = 22 or self.n_fields
|
55 |
+
|
56 |
+
set_rules = """
|
57 |
+
Please note that your task is to generate a dictionary, following the below rules:
|
58 |
+
1. Refactor the unstructured OCR text into a dictionary based on the reference dictionary structure (ref_dict).
|
59 |
+
2. Each field of OCR corresponds to a column of the ref_dict. You should correctly map the values from OCR to the respective fields in ref_dict.
|
60 |
+
3. If the OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None".
|
61 |
+
4. If there is a field in the ref_dict that does not have a corresponding value in the OCR text, fill it based on your knowledge but don't generate new information.
|
62 |
+
5. Do not use any text from the ref_dict values in the new dict, but you must use the headers from ref_dict.
|
63 |
+
6. Duplicate dictionary fields are not allowed.
|
64 |
+
7. Only return the new dictionary. You should not explain your answer.
|
65 |
+
8. Your output should be a Python dictionary represented as a JSON string.
|
66 |
+
"""
|
67 |
+
|
68 |
+
umich_all_asia_rules = """{
|
69 |
+
"Catalog Number": {
|
70 |
+
"format": "[Catalog Number]",
|
71 |
+
"null_value": "",
|
72 |
+
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits"
|
73 |
+
},
|
74 |
+
"Genus": {
|
75 |
+
"format": "[Genus] or '[Family] indet' if no genus",
|
76 |
+
"null_value": "",
|
77 |
+
"description": "Taxonomic determination to genus, do capitalize genus"
|
78 |
+
},
|
79 |
+
"Species": {
|
80 |
+
"format": "[species] or 'indet' if no species",
|
81 |
+
"null_value": "",
|
82 |
+
"description": "Taxonomic determination to species, do not capitalize species"
|
83 |
+
},
|
84 |
+
"subspecies": {
|
85 |
+
"format": "[subspecies]",
|
86 |
+
"null_value": "",
|
87 |
+
"description": "Taxonomic determination to subspecies (subsp.)"
|
88 |
+
},
|
89 |
+
"variety": {
|
90 |
+
"format": "[variety]",
|
91 |
+
"null_value": "",
|
92 |
+
"description": "Taxonomic determination to variety (var)"
|
93 |
+
},
|
94 |
+
"forma": {
|
95 |
+
"format": "[form]",
|
96 |
+
"null_value": "",
|
97 |
+
"description": "Taxonomic determination to form (f.)"
|
98 |
+
},
|
99 |
+
"Country": {
|
100 |
+
"format": "[Country]",
|
101 |
+
"null_value": "",
|
102 |
+
"description": "Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abbreviation is given"
|
103 |
+
},
|
104 |
+
"State": {
|
105 |
+
"format": "[Adm. Division 1]",
|
106 |
+
"null_value": "",
|
107 |
+
"description": "Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word"
|
108 |
+
},
|
109 |
+
"County": {
|
110 |
+
"format": "[Adm. Division 2]",
|
111 |
+
"null_value": "",
|
112 |
+
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word"
|
113 |
+
},
|
114 |
+
"Locality Name": {
|
115 |
+
"format": "verbatim, if no geographic info: 'no data provided on label of catalog no: [######]', or if illegible: 'locality present but illegible/not translated for catalog no: #######', or if no named locality: 'no named locality for catalog no: #######'",
|
116 |
+
"description": "Description of geographic location or landscape"
|
117 |
+
},
|
118 |
+
"Min Elevation": {
|
119 |
+
"format": "elevation integer",
|
120 |
+
"null_value": "",
|
121 |
+
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given"
|
122 |
+
},
|
123 |
+
"Max Elevation": {
|
124 |
+
"format": "elevation integer",
|
125 |
+
"null_value": "",
|
126 |
+
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise"
|
127 |
+
},
|
128 |
+
"Elevation Units": {
|
129 |
+
"format": "m",
|
130 |
+
"null_value": "",
|
131 |
+
"description": "'m' only if an elevation is present"
|
132 |
+
},
|
133 |
+
"Verbatim Coordinates": {
|
134 |
+
"format": "[Lat, Long | UTM | TRS]",
|
135 |
+
"null_value": "",
|
136 |
+
"description": "Verbatim coordinates as they appear on the label, fix typos to match standardized GPS coordinate format"
|
137 |
+
},
|
138 |
+
"Datum": {
|
139 |
+
"format": "[WGS84, NAD23 etc.]",
|
140 |
+
"null_value": "",
|
141 |
+
"description": "GPS Datum of coordinates on label; empty string "" if GPS coordinates are not in OCR"
|
142 |
+
},
|
143 |
+
"Cultivated": {
|
144 |
+
"format": "yes",
|
145 |
+
"null_value": "",
|
146 |
+
"description": "Indicates if specimen was grown in cultivation"
|
147 |
+
},
|
148 |
+
"Habitat": {
|
149 |
+
"format": "verbatim",
|
150 |
+
"null_value": "",
|
151 |
+
"description": "Description of habitat or location where specimen was collected, ignore descriptions of the plant itself"
|
152 |
+
},
|
153 |
+
"Collectors": {
|
154 |
+
"format": "[Collector]",
|
155 |
+
"null_value": "not present",
|
156 |
+
"description": "Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas"
|
157 |
+
},
|
158 |
+
"Collector Number": {
|
159 |
+
"format": "[Collector No.]",
|
160 |
+
"null_value": "s.n.",
|
161 |
+
"description": "Sequential number assigned to collection, associated with the collector"
|
162 |
+
},
|
163 |
+
"Verbatim Date": {
|
164 |
+
"format": "verbatim",
|
165 |
+
"null_value": "s.d.",
|
166 |
+
"description": "Date of collection exactly as it appears on the label"
|
167 |
+
},
|
168 |
+
"Date": {
|
169 |
+
"format": "[yyyy-mm-dd]",
|
170 |
+
"null_value": "",
|
171 |
+
"description": "Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e., 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day"
|
172 |
+
},
|
173 |
+
"End Date": {
|
174 |
+
"format": "[yyyy-mm-dd]",
|
175 |
+
"null_value": "",
|
176 |
+
"description": "If date range is listed, later date of collection range"
|
177 |
+
}
|
178 |
+
}"""
|
179 |
+
|
180 |
+
structure = """{"Dictionary":
|
181 |
+
{
|
182 |
+
"Catalog Number": [Catalog Number],
|
183 |
+
"Genus": [Genus],
|
184 |
+
"Species": [species],
|
185 |
+
"subspecies": [subspecies],
|
186 |
+
"variety": [variety],
|
187 |
+
"forma": [forma],
|
188 |
+
"Country": [Country],
|
189 |
+
"State": [State],
|
190 |
+
"County": [County],
|
191 |
+
"Locality Name": [Locality Name],
|
192 |
+
"Min Elevation": [Min Elevation],
|
193 |
+
"Max Elevation": [Max Elevation],
|
194 |
+
"Elevation Units": [Elevation Units],
|
195 |
+
"Verbatim Coordinates": [Verbatim Coordinates],
|
196 |
+
"Datum": [Datum],
|
197 |
+
"Cultivated": [Cultivated],
|
198 |
+
"Habitat": [Habitat],
|
199 |
+
"Collectors": [Collectors],
|
200 |
+
"Collector Number": [Collector Number],
|
201 |
+
"Verbatim Date": [Verbatim Date],
|
202 |
+
"Date": [Date],
|
203 |
+
"End Date": [End Date]
|
204 |
+
},
|
205 |
+
"SpeciesName": {"taxonomy": [Genus_species]}}"""
|
206 |
+
|
207 |
+
prompt = f"""I'm providing you with a set of rules, an unstructured OCR text, and a reference dictionary (domain knowledge). Your task is to convert the OCR text into a structured dictionary that matches the structure of the reference dictionary. Please follow the rules strictly.
|
208 |
+
The rules are as follows:
|
209 |
+
{set_rules}
|
210 |
+
The unstructured OCR text is:
|
211 |
+
{self.OCR}
|
212 |
+
The reference dictionary, which provides an example of the output structure and has an embedding distance of {self.similarity} to the OCR, is:
|
213 |
+
{self.domain_knowledge_example}
|
214 |
+
Some dictionary fields have special requirements. These requirements specify the format for each field, and are given below:
|
215 |
+
{umich_all_asia_rules}
|
216 |
+
Please refactor the OCR text into a dictionary, following the rules and the reference structure:
|
217 |
+
{structure}
|
218 |
+
"""
|
219 |
+
|
220 |
+
xlsx_headers = ["Catalog Number","Genus","Species","subspecies","variety","forma","Country","State","County","Locality Name","Min Elevation","Max Elevation","Elevation Units","Verbatim Coordinates","Datum","Cultivated","Habitat","Collectors","Collector Number","Verbatim Date","Date","End Date"]
|
221 |
+
|
222 |
+
|
223 |
+
return prompt, self.n_fields, xlsx_headers
|
224 |
+
|
225 |
+
def prompt_v1_verbose_noDomainKnowledge(self, OCR=None):
    """Assemble the version-1 verbose prompt that carries no domain-knowledge example.

    Args:
        OCR: Unstructured OCR text to embed in the prompt. When falsy, the
            text already stored on ``self.OCR`` is reused.

    Returns:
        A ``(prompt, n_fields, xlsx_headers)`` tuple: the assembled prompt
        string, the field count (always 22 for this template) and the list
        of 22 output column headers.
    """
    self.OCR = OCR or self.OCR
    # This template always defines exactly 22 output fields.
    self.n_fields = 22

    column_names = [
        "Catalog Number",
        "Genus",
        "Species",
        "subspecies",
        "variety",
        "forma",
        "Country",
        "State",
        "County",
        "Locality Name",
        "Min Elevation",
        "Max Elevation",
        "Elevation Units",
        "Verbatim Coordinates",
        "Datum",
        "Cultivated",
        "Habitat",
        "Collectors",
        "Collector Number",
        "Verbatim Date",
        "Date",
        "End Date",
    ]

    # General task rules handed to the model.
    task_rules = """
Please note that your task is to generate a dictionary, following the below rules:
1. Refactor the unstructured OCR text into a dictionary based on the reference dictionary structure (ref_dict).
2. Each field of OCR corresponds to a column of the ref_dict. You should correctly map the values from OCR to the respective fields in ref_dict.
3. If the OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None".
4. If there is a field in the ref_dict that does not have a corresponding value in the OCR text, fill it based on your knowledge but don't generate new information.
5. Do not use any text from the ref_dict values in the new dict, but you must use the headers from ref_dict.
6. Duplicate dictionary fields are not allowed.
7. Only return the new dictionary. You should not explain your answer.
8. Your output should be a Python dictionary represented as a JSON string.
"""

    # Per-field format, null value and description.
    field_requirements = """{
"Catalog Number": {
"format": "[Catalog Number]",
"null_value": "",
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits"
},
"Genus": {
"format": "[Genus] or '[Family] indet' if no genus",
"null_value": "",
"description": "Taxonomic determination to genus, do capitalize genus"
},
"Species": {
"format": "[species] or 'indet' if no species",
"null_value": "",
"description": "Taxonomic determination to species, do not capitalize species"
},
"subspecies": {
"format": "[subspecies]",
"null_value": "",
"description": "Taxonomic determination to subspecies (subsp.)"
},
"variety": {
"format": "[variety]",
"null_value": "",
"description": "Taxonomic determination to variety (var)"
},
"forma": {
"format": "[form]",
"null_value": "",
"description": "Taxonomic determination to form (f.)"
},
"Country": {
"format": "[Country]",
"null_value": "",
"description": "Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abbreviation is given"
},
"State": {
"format": "[Adm. Division 1]",
"null_value": "",
"description": "Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word"
},
"County": {
"format": "[Adm. Division 2]",
"null_value": "",
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word"
},
"Locality Name": {
"format": "verbatim, if no geographic info: 'no data provided on label of catalog no: [######]', or if illegible: 'locality present but illegible/not translated for catalog no: #######', or if no named locality: 'no named locality for catalog no: #######'",
"description": "Description of geographic location or landscape"
},
"Min Elevation": {
"format": "elevation integer",
"null_value": "",
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given"
},
"Max Elevation": {
"format": "elevation integer",
"null_value": "",
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise"
},
"Elevation Units": {
"format": "m",
"null_value": "",
"description": "'m' only if an elevation is present"
},
"Verbatim Coordinates": {
"format": "[Lat, Long | UTM | TRS]",
"null_value": "",
"description": "Verbatim coordinates as they appear on the label, fix typos to match standardized GPS coordinate format"
},
"Datum": {
"format": "[WGS84, NAD23 etc.]",
"null_value": "",
"description": "GPS Datum of coordinates on label; empty string "" if GPS coordinates are not in OCR"
},
"Cultivated": {
"format": "yes",
"null_value": "",
"description": "Indicates if specimen was grown in cultivation"
},
"Habitat": {
"format": "verbatim",
"null_value": "",
"description": "Description of habitat or location where specimen was collected, ignore descriptions of the plant itself"
},
"Collectors": {
"format": "[Collector]",
"null_value": "not present",
"description": "Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas"
},
"Collector Number": {
"format": "[Collector No.]",
"null_value": "s.n.",
"description": "Sequential number assigned to collection, associated with the collector"
},
"Verbatim Date": {
"format": "verbatim",
"null_value": "s.d.",
"description": "Date of collection exactly as it appears on the label"
},
"Date": {
"format": "[yyyy-mm-dd]",
"null_value": "",
"description": "Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e., 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day"
},
"End Date": {
"format": "[yyyy-mm-dd]",
"null_value": "",
"description": "If date range is listed, later date of collection range"
}
}"""

    # Skeleton of the dictionary the model is asked to return.
    output_template = """{"Dictionary":
{
"Catalog Number": [Catalog Number],
"Genus": [Genus],
"Species": [species],
"subspecies": [subspecies],
"variety": [variety],
"forma": [forma],
"Country": [Country],
"State": [State],
"County": [County],
"Locality Name": [Locality Name],
"Min Elevation": [Min Elevation],
"Max Elevation": [Max Elevation],
"Elevation Units": [Elevation Units],
"Verbatim Coordinates": [Verbatim Coordinates],
"Datum": [Datum],
"Cultivated": [Cultivated],
"Habitat": [Habitat],
"Collectors": [Collectors],
"Collector Number": [Collector Number],
"Verbatim Date": [Verbatim Date],
"Date": [Date],
"End Date": [End Date]
},
"SpeciesName": {"taxonomy": [Genus_species]}}"""

    # One entry per prompt line/section; the trailing "" yields the final newline.
    sections = [
        "I'm providing you with a set of rules, an unstructured OCR text, and a reference dictionary (domain knowledge). Your task is to convert the OCR text into a structured dictionary that matches the structure of the reference dictionary. Please follow the rules strictly.",
        "The rules are as follows:",
        task_rules,
        "The unstructured OCR text is:",
        f"{self.OCR}",
        "Some dictionary fields have special requirements. These requirements specify the format for each field, and are given below:",
        field_requirements,
        "Please refactor the OCR text into a dictionary, following the rules and the reference structure:",
        output_template,
        "",
    ]
    assembled = "\n".join(sections)

    return assembled, self.n_fields, column_names
|
394 |
+
|
395 |
+
def prompt_v2_json_rules(self, OCR=None):
    """Build the version-2 prompt whose per-field rules are given as a JSON template.

    Compared with the previous revision, the prompt text is corrected:
    the rules are numbered 1-10 without a duplicate "7.", the JSON
    template no longer carries a trailing comma after the ``end_date``
    entry (which contradicted the "no trailing commas" rule), and a few
    wording mistakes are fixed.

    Args:
        OCR: Unstructured OCR text to embed in the prompt. When falsy, the
            text already stored on ``self.OCR`` is reused.

    Returns:
        A ``(prompt, n_fields, xlsx_headers)`` tuple: the assembled prompt
        string, the field count (always 26 for this template) and the list
        of 26 output column headers.
    """
    self.OCR = OCR or self.OCR
    # This template always defines exactly 26 output fields.
    self.n_fields = 26

    set_rules = """
1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
2. You should map the unstructured OCR text to the appropriate JSON key and then populate the field based on its rules.
3. Some JSON key fields are permitted to remain empty if the corresponding information is not found in the unstructured OCR text.
4. Ignore any information in the OCR text that doesn't fit into the defined JSON structure.
5. Duplicate dictionary fields are not allowed.
6. Ensure that all JSON keys are in lowercase.
7. Ensure that new JSON field values follow sentence case capitalization.
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
9. Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
10. Only return a JSON dictionary represented as a string. You should not explain your answer.
"""

    dictionary_field_format_descriptions = """
The next section of instructions outlines how to format the JSON dictionary. The keys are the same as those of the final formatted JSON object.
For each key there is a format requirement that specifies how to transcribe the information for that key.
The possible formatting options are:
1. "verbatim transcription" - field is populated with verbatim text from the unformatted OCR.
2. "spell check transcription" - field is populated with spelling corrected text from the unformatted OCR.
3. "boolean yes no" - field is populated with only yes or no.
4. "integer" - field is populated with only an integer.
5. "[list]" - field is populated from one of the values in the list.
6. "yyyy-mm-dd" - field is populated with a date in the format year-month-day.
The desired null value is also given. Populate the field with the null value if the information for that key is not present in the unformatted OCR text.
"""

    json_template_rules = """
{"Dictionary":{
"catalog_number": {
"format": "verbatim transcription",
"null_value": "",
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits."
},
"genus": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'."
},
"species": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to species, do not capitalize species."
},
"subspecies": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to subspecies (subsp.)."
},
"variety": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to variety (var)."
},
"forma": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to form (f.)."
},
"country": {
"format": "spell check transcription",
"null_value": "",
"description": "Country that corresponds to the current geographic location of collection. Capitalize first letter of each word. If abbreviation is given populate field with the full spelling of the country's name."
},
"state": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 1 that corresponds to the current geographic location of collection. Capitalize first letter of each word. Administrative division 1 is equivalent to a U.S. State."
},
"county": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word. Administrative division 2 is equivalent to a U.S. county, parish, borough."
},
"locality_name": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or site of the specimen."
},
"min_elevation": {
"format": "integer",
"null_value": "",
"description": "Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"max_elevation": {
"format": "integer",
"null_value": "",
"description": "Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"elevation_units": {
"format": "spell check transcription",
"null_value": "",
"description": "Elevation units must be meters. If min_elevation field is populated, then elevation_units: 'm'. Otherwise elevation_units: ''."
},
"verbatim_coordinates": {
"format": "verbatim transcription",
"null_value": "",
"description": "Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types are one of [Lat, Long, UTM, TRS]."
},
"decimal_coordinates": {
"format": "spell check transcription",
"null_value": "",
"description": "Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format."
},
"datum": {
"format": "[WGS84, WGS72, WGS66, WGS60, NAD83, NAD27, OSGB36, ETRS89, ED50, GDA94, JGD2011, Tokyo97, KGD2002, TWD67, TWD97, BJS54, XAS80, GCJ-02, BD-09, PZ-90.11, GTRF, CGCS2000, ITRF88, ITRF89, ITRF90, ITRF91, ITRF92, ITRF93, ITRF94, ITRF96, ITRF97, ITRF2000, ITRF2005, ITRF2008, ITRF2014, Hong Kong Principal Datum, SAD69]",
"null_value": "",
"description": "Datum of location coordinates. Possible values are included in the format list. Leave field blank if unclear."
},
"cultivated": {
"format": "boolean yes no",
"null_value": "",
"description": "Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant."
},
"habitat": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of a plant's habitat or the location where the specimen was collected. Ignore descriptions of the plant itself."
},
"plant_description": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics."
},
"collectors": {
"format": "verbatim transcription",
"null_value": "not present",
"description": "Full name(s) of the individual(s) responsible for collecting the specimen. When multiple collectors are involved, their names should be separated by commas."
},
"collector_number": {
"format": "verbatim transcription",
"null_value": "s.n.",
"description": "Unique identifier or number that denotes the specific collecting event and associated with the collector."
},
"determined_by": {
"format": "verbatim transcription",
"null_value": "",
"description": "Full name of the individual responsible for determining the taxonomic name of the specimen. Sometimes the name will be near to the characters 'det' to denote determination. This name may be isolated from other names in the unformatted OCR text."
},
"multiple_names": {
"format": "boolean yes no",
"null_value": "",
"description": "Indicate whether multiple people or collector names are present in the unformatted OCR text. If you see more than one person's name the value is 'yes'; otherwise the value is 'no'."
},
"verbatim_date": {
"format": "verbatim transcription",
"null_value": "s.d.",
"description": "Date of collection exactly as it appears on the label. Do not change the format or correct typos."
},
"date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "Date the specimen was collected formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire date is unknown, 'YYYY-00-00' if only the year is known, and 'YYYY-MM-00' if year and month are known but day is not."
},
"end_date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "If a date range is provided, this represents the later or ending date of the collection period, formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire end date is unknown, 'YYYY-00-00' if only the year of the end date is known, and 'YYYY-MM-00' if year and month of the end date are known but the day is not."
}
},
"SpeciesName": {
"taxonomy": [Genus_species]}
}"""

    # Empty dictionary the model is asked to populate.
    structure = """{"Dictionary":
{
"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma": "",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date":"" ,
"date": "",
"end_date": ""
},
"SpeciesName": {"taxonomy": ""}}"""

    # One entry per prompt line/section; the trailing "" yields the final newline.
    sections = [
        "Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.",
        "The rules are:",
        set_rules,
        "The unstructured OCR text is:",
        f"{self.OCR}",
        dictionary_field_format_descriptions,
        "This is the JSON template that includes instructions for each key:",
        json_template_rules,
        "Please populate the following JSON dictionary based on the rules and the unformatted OCR text:",
        structure,
        "",
    ]
    prompt = "\n".join(sections)

    xlsx_headers = [
        "catalog_number",
        "genus",
        "species",
        "subspecies",
        "variety",
        "forma",
        "country",
        "state",
        "county",
        "locality_name",
        "min_elevation",
        "max_elevation",
        "elevation_units",
        "verbatim_coordinates",
        "decimal_coordinates",
        "datum",
        "cultivated",
        "habitat",
        "plant_description",
        "collectors",
        "collector_number",
        "determined_by",
        "multiple_names",
        "verbatim_date",
        "date",
        "end_date",
    ]

    return prompt, self.n_fields, xlsx_headers
|
608 |
+
|
609 |
+
#############################################################################################
|
610 |
+
#############################################################################################
|
611 |
+
#############################################################################################
|
612 |
+
#############################################################################################
|
613 |
+
# These are for dynamically creating your own prompts with n-columns
|
614 |
+
|
615 |
+
|
616 |
+
def prompt_SLTP(self, rules_config_path, OCR=None, is_palm=False):
    """Build a parsing prompt from a user-supplied YAML rules configuration.

    The configuration is loaded through ``self.load_rules_config()`` and
    must provide the keys ``instructions``, ``json_formatting_instructions``
    and ``rules``. The loaded values, the derived rules JSON and the empty
    output structure are stored on the instance.

    Args:
        rules_config_path: Path of the YAML rules file; stored on
            ``self.rules_config_path`` before loading.
        OCR: Unstructured OCR text to embed in the prompt. When ``None``,
            any text already stored on ``self.OCR`` is kept instead of
            being overwritten with ``None`` (matching the other prompt
            builders, which also fall back to the stored text).
        is_palm: When true, the empty output structure is appended three
            times instead of once (presumably to reinforce the expected
            output for that model family -- confirm with the author).

    Returns:
        A ``(prompt, dictionary_structure)`` tuple: the assembled prompt
        string and the dict of empty output fields.
    """
    self.OCR = OCR if OCR is not None else getattr(self, "OCR", None)

    self.rules_config_path = rules_config_path
    self.rules_config = self.load_rules_config()

    self.instructions = self.rules_config['instructions']
    self.json_formatting_instructions = self.rules_config['json_formatting_instructions']

    self.rules_list = self.rules_config['rules']
    self.n_fields = len(self.rules_list)

    # Set the rules for processing OCR into JSON format
    self.rules = self.create_rules(is_palm)

    self.structure, self.dictionary_structure = self.create_structure(is_palm)

    # NOTE: the OCR text used to be repeated between the instructions and the
    # JSON formatting instructions; that made the prompt too long and
    # performance was better without it, so it appears only once below.
    structure_repeats = 3 if is_palm else 1
    sections = [
        "Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.",
        "The rules are:",
        f"{self.instructions}",
        f"{self.json_formatting_instructions}",
        "This is the JSON template that includes instructions for each key:",
        f"{self.rules}",
        "The unstructured OCR text is:",
        f"{self.OCR}",
        "Please populate the following JSON dictionary based on the rules and the unformatted OCR text:",
    ]
    sections.extend([f"{self.structure}"] * structure_repeats)
    prompt = "\n".join(sections) + "\n"

    return prompt, self.dictionary_structure
|
667 |
+
|
668 |
+
def load_rules_config(self):
    """Load the YAML rules file located at ``self.rules_config_path``.

    The file is opened explicitly as UTF-8 so that non-ASCII characters
    in the rules are read the same way on every platform instead of
    depending on the locale's default encoding.

    Returns:
        The parsed YAML content, or ``None`` when the file is not valid
        YAML (the parser error is printed in that case).
    """
    with open(self.rules_config_path, 'r', encoding='utf-8') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Best effort: report the problem and signal failure with None.
            print(exc)
            return None
|
675 |
+
|
676 |
+
def create_rules(self, is_palm=False):
    """Serialize the per-field rules in ``self.rules_list`` to a JSON string.

    Args:
        is_palm: Accepted for signature compatibility; not used.

    Returns:
        A single-line (non-indented) JSON string with the rules in their
        original key order.
    """
    rules_copy = dict(self.rules_list.items())
    # No indent: keep the rules on one line, preserving key order.
    return json.dumps(rules_copy, sort_keys=False)
|
682 |
+
|
683 |
+
def create_structure(self, is_palm=False):
    """Create the empty output structure with one blank field per rule key.

    The previous implementation built two dynamic pydantic models (one of
    which was never used) only to dump a dict of empty-string defaults.
    The same dict is now built directly from the rule keys.

    Args:
        is_palm: Accepted for signature compatibility; not used.

    Returns:
        A ``(structure_json_str, dictionary_structure)`` tuple: the dict
        mapping every key of ``self.rules_list`` to ``''`` (in the original
        key order) and its JSON rendering indented by four spaces.
    """
    dictionary_structure = {key: '' for key in self.rules_list.keys()}
    structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
    return structure_json_str, dictionary_structure
|
701 |
+
|
702 |
+
|
703 |
+
def generate_xlsx_headers(self, is_palm):
    """Return the spreadsheet column headers, i.e. the keys of ``self.rules_list``.

    Args:
        is_palm: Accepted for signature compatibility; the result is the
            same for either value.

    Returns:
        A new list holding the rule keys in their original order.
    """
    return list(self.rules_list.keys())
|
711 |
+
|
712 |
+
def prompt_v2_custom_redo(self, incorrect_json, is_palm):
    """Build a prompt asking the model to repair an invalid JSON response.

    The rules configuration is reloaded and the empty output structure is
    regenerated (from ``self.rules_list``) before the prompt is assembled.

    Args:
        incorrect_json: The malformed JSON text returned by an earlier call.
        is_palm: When true, the expected output structure is stated three
            times instead of once (presumably to reinforce the expected
            output for that model family -- confirm with the author).

    Returns:
        The repair prompt as a string. It ends with an open cue and has no
        trailing newline.
    """
    # Load the existing rules and structure
    self.rules_config = self.load_rules_config()
    self.structure, self.dictionary_structure = self.create_structure(is_palm)

    # Both model families share the same text; only the number of times the
    # target structure is repeated differs.
    structure_repeats = 3 if is_palm else 1
    lines = [
        "The incorrectly formatted JSON dictionary below is not valid. It contains an error that prevents it from loading with the Python command json.loads().",
        "The incorrectly formatted JSON dictionary below is the literal string returned by a previous function and the error may be caused by markdown formatting.",
        "You need to return correct JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
        "Escape all JSON control characters that appear in input including ampersand (&) and other control characters.",
        "Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.",
        "Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.",
        f"The incorrectly formatted JSON dictionary: {incorrect_json}",
    ]
    lines.extend([f"The output JSON structure: {self.structure}"] * structure_repeats)
    lines.append("Please reformat the incorrectly formatted JSON dictionary given the output JSON structure: ")
    prompt = "\n".join(lines)
    return prompt
|
744 |
+
|
745 |
+
#############################################################################################
|
746 |
+
#############################################################################################
|
747 |
+
#############################################################################################
|
748 |
+
#############################################################################################
|
749 |
+
def prompt_gpt_redo_v1(self, incorrect_json):
    """Build a repair prompt for an invalid response to a version-1 prompt.

    The prompt text fixes two spelling mistakes of the previous revision
    ("coorect", "disctionary").

    Args:
        incorrect_json: The malformed JSON text returned by an earlier call.

    Returns:
        The repair prompt as a string. It ends with an open cue and has no
        trailing newline.
    """
    # Target layout using the version-1 headers and [placeholder] notation.
    structure = """Below is the correct JSON formatting. Modify the text to conform to the following format, fixing the incorrect JSON:
{"Dictionary":
{
"Catalog Number": [Catalog Number],
"Genus": [Genus],
"Species": [species],
"subspecies": [subspecies],
"variety": [variety],
"forma": [forma],
"Country": [Country],
"State": [State],
"County": [County],
"Locality Name": [Locality Name],
"Min Elevation": [Min Elevation],
"Max Elevation": [Max Elevation],
"Elevation Units": [Elevation Units],
"Verbatim Coordinates": [Verbatim Coordinates],
"Datum": [Datum],
"Cultivated": [Cultivated],
"Habitat": [Habitat],
"Collectors": [Collectors],
"Collector Number": [Collector Number],
"Verbatim Date": [Verbatim Date],
"Date": [Date],
"End Date": [End Date]
},
"SpeciesName": {"taxonomy": [Genus_species]}}"""

    lines = [
        "This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().",
        "You need to return correct JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
        "Escape all JSON control characters that appear in input including ampersand (&) and other control characters.",
        "Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.",
        "Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.",
        f"The incorrectly formatted JSON dictionary: {incorrect_json}",
        f"The output JSON structure: {structure}",
        "The refactored JSON dictionary: ",
    ]
    prompt = "\n".join(lines)
    return prompt
|
788 |
+
|
789 |
+
def prompt_gpt_redo_v2(self, incorrect_json):
    """Build a repair prompt for a response that failed ``json.loads()``.

    The prompt shows the model the malformed text together with the expected
    key layout (a "Dictionary" object plus a "SpeciesName" object) and asks
    for a corrected JSON dictionary.

    Args:
        incorrect_json: The malformed JSON text; it is interpolated into the
            prompt unchanged.

    Returns:
        str: The assembled prompt text.
    """
    # The template has to be well-formed itself, because the prompt below tells
    # the model to avoid trailing commas and unbalanced quotes. The previous
    # template had `"species": "".` (period instead of comma), `"',` after
    # elevation_units / multiple_names, and a trailing comma after end_date.
    # NOTE(review): [Genus_species] looks like a fill-in placeholder for the
    # model rather than literal JSON; it is kept as-is.
    structure = """
{"Dictionary":{
"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma":"",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date": "",
"date": "",
"end_date": ""
},
"SpeciesName": {"taxonomy": [Genus_species]}}"""

    prompt = f"""This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().
You need to return coorect JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.
Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
Escape all JSON control characters that appear in input including ampersand (&) and other control characters.
Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
The incorrectly formatted JSON dictionary: {incorrect_json}
The output JSON structure: {structure}
The refactored JSON disctionary: """
    return prompt
|
831 |
+
#####################################################################################################################################
|
832 |
+
#####################################################################################################################################
|
833 |
+
def prompt_v1_palm2(self, in_list, out_list, OCR=None):
    """Assemble the version-1 few-shot prompt.

    Stores ``OCR`` on ``self.OCR`` when a truthy value is passed (otherwise the
    existing ``self.OCR`` is reused), then emits the general rules, the
    per-field requirements, three input/output example pairs, and finally the
    OCR text followed by an open ``output:`` slot.

    Args:
        in_list: Example inputs; indices 0, 1 and 2 are read.
        out_list: Example outputs; indices 0, 1 and 2 are read.
        OCR: Optional OCR text that replaces ``self.OCR``.

    Returns:
        str: The assembled prompt text.

    Raises:
        IndexError: If ``in_list`` or ``out_list`` is a sequence with fewer
            than three items.
    """
    self.OCR = OCR or self.OCR

    general_rules = """1. Your job is to return a new dict based on the structure of the reference dict ref_dict and these are your rules.
2. You must look at ref_dict and refactor the new text called OCR to match the same formatting.
3. OCR contains unstructured text inside of [], use your knowledge to put the OCR text into the correct ref_dict column.
4. If OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None" and skip all other steps.
5. If there is a field that does not have a direct proxy in the OCR text, you can fill it in based on your knowledge, but you cannot generate new information.
6. Never put text from the ref_dict values into the new dict, but you must use the headers from ref_dict.
7. There cannot be duplicate dictionary fields.
8. Only return the new dict, do not explain your answer.
9. Do not include quotation marks in content, only use quotation marks to represent values in dictionaries.
10. For GPS coordinates only use Decimal Degrees (D.D°)
11. "Given the input text, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values."""

    field_requirements = """
"Catalog Number" - {"format": "[barcode]", "null_value": "", "description": the barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits}
"Genus" - {"format": "[Genus]" or "[Family] indet" if no genus", "null_value": "", "description": taxonomic determination to genus, do captalize genus}
"Species"- {"format": "[species]" or "indet" if no species, "null_value": "", "description": taxonomic determination to species, do not captalize species}
"subspecies" - {"format": "[subspecies]", "null_value": "", "description": taxonomic determination to subspecies (subsp.)}
"variety" - {"format": "[variety]", "null_value": "", "description": taxonomic determination to variety (var)}
"forma" - {"format": "[form]", "null_value": "", "description": taxonomic determination to form (f.)}

"Country" - {"format": "[Country]", "null_value": "no data", "description": Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abreviation is given}
"State" - {"format": "[Adm. Division 1]", "null_value": "no data", "description": Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"County" - {"format": "[Adm. Division 2]", "null_value": "no data", "description": Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"Locality Name" - {"format": "verbatim", if no geographic info: "no data provided on label of catalog no: [######]", or if illegible: "locality present but illegible/not translated for catalog no: #######", or if no named locality: "no named locality for catalog no: #######", "description": "Description of geographic location or landscape"}

"Min Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given}
"Max Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise}
"Elevation Units" - {format: "m", "null_value": "","description": "m" only if an elevation is present}

"Verbatim Coordinates" - {"format": "[Lat, Long | UTM | TRS]", "null_value": "", "description": Convert coordinates to Decimal Degrees (D.D°) format, do not use Minutes, Seconds or quotation marks}

"Datum" - {"format": "[WGS84, NAD23 etc.]", "null_value": "not present", "description": Datum of coordinates on label; "" is GPS coordinates are not in OCR}
"Cultivated" - {"format": "yes", "null_value": "", "description": Indicates if specimen was grown in cultivation}
"Habitat" - {"format": "verbatim", "null_value": "", "description": Description of habitat or location where specimen was collected, ignore descriptions of the plant itself}
"Collectors" - {"format": "[Collector]", "null_value": "not present", "description": Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas}
"Collector Number" - {"format": "[Collector No.]", "null_value": "s.n.", "description": Sequential number assigned to collection, associated with the collector}
"Verbatim Date" - {"format": "verbatim", "null_value": "s.d.", "description": Date of collection exactly as it appears on the label}
"Date" - {"format": "[yyyy-mm-dd]", "null_value": "", "description": Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e. 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day}
"End Date" - {"format": "[yyyy-mm-dd]", "null_value": "", "description": If date range is listed, later date of collection range}
"""

    # Exactly three worked examples are emitted, read in the order
    # in_list[0], out_list[0], in_list[1], ... so a short list still fails
    # with IndexError at the same position as before.
    example_sections = []
    for idx in range(3):
        example_sections.append(f"input: {in_list[idx]}")
        example_sections.append(f"output: {out_list[idx]}")

    # Every section is separated from the next by one blank line.
    sections = [
        "Given the following set of rules:",
        f"set_rules = {general_rules}",
        "Some dict fields have special requirements listed below. First is the column header. After the - is the format. Do not include the instructions with your response:",
        f"requirements = {field_requirements}",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
        *example_sections,
        f"input: {self.OCR}",
        "output:",
    ]
    return "\n\n".join(sections)
|
903 |
+
|
904 |
+
def prompt_v1_palm2_noDomainKnowledge(self, OCR=None):
    """Assemble the version-1 prompt without few-shot examples.

    Stores ``OCR`` on ``self.OCR`` when a truthy value is passed (otherwise the
    existing ``self.OCR`` is reused), then emits the general rules, the
    per-field requirements, the OCR text, and an empty JSON template for the
    model to fill in.

    Args:
        OCR: Optional OCR text that replaces ``self.OCR``.

    Returns:
        str: The assembled prompt text.
    """
    self.OCR = OCR or self.OCR

    set_rules = """1. Your job is to return a new dict based on the structure of the reference dict ref_dict and these are your rules.
2. You must look at ref_dict and refactor the new text called OCR to match the same formatting.
3. OCR contains unstructured text inside of [], use your knowledge to put the OCR text into the correct ref_dict column.
4. If OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None" and skip all other steps.
5. If there is a field that does not have a direct proxy in the OCR text, you can fill it in based on your knowledge, but you cannot generate new information.
6. Never put text from the ref_dict values into the new dict, but you must use the headers from ref_dict.
7. There cannot be duplicate dictionary fields.
8. Only return the new dict, do not explain your answer.
9. Do not include quotation marks in content, only use quotation marks to represent values in dictionaries.
10. For GPS coordinates only use Decimal Degrees (D.D°)
11. "Given the input text, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values."""

    umich_all_asia_rules = """
"Catalog Number" - {"format": "barcode", "null_value": "", "description": the barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits}
"Genus" - {"format": "Genus" or "Family indet" if no genus", "null_value": "", "description": taxonomic determination to genus, do captalize genus}
"Species"- {"format": "species" or "indet" if no species, "null_value": "", "description": taxonomic determination to species, do not captalize species}
"subspecies" - {"format": "subspecies", "null_value": "", "description": taxonomic determination to subspecies (subsp.)}
"variety" - {"format": "variety", "null_value": "", "description": taxonomic determination to variety (var)}
"forma" - {"format": "form", "null_value": "", "description": taxonomic determination to form (f.)}

"Country" - {"format": "Country", "null_value": "no data", "description": Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abreviation is given}
"State" - {"format": "Adm. Division 1", "null_value": "no data", "description": Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"County" - {"format": "Adm. Division 2", "null_value": "no data", "description": Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"Locality Name" - {"format": "verbatim", if no geographic info: "no data provided on label of catalog no: ######", or if illegible: "locality present but illegible/not translated for catalog no: #######", or if no named locality: "no named locality for catalog no: #######", "description": "Description of geographic location or landscape"}

"Min Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given}
"Max Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise}
"Elevation Units" - {format: "m", "null_value": "","description": "m" only if an elevation is present}

"Verbatim Coordinates" - {"format": "Lat, Long, UTM, TRS", "null_value": "", "description": Convert coordinates to Decimal Degrees (D.D°) format, do not use Minutes, Seconds or quotation marks}

"Datum" - {"format": "WGS84, NAD23 etc.", "null_value": "not present", "description": Datum of coordinates on label; "" is GPS coordinates are not in OCR}
"Cultivated" - {"format": "yes", "null_value": "", "description": Indicates if specimen was grown in cultivation}
"Habitat" - {"format": "verbatim", "null_value": "", "description": Description of habitat or location where specimen was collected, ignore descriptions of the plant itself}
"Collectors" - {"format": "Collector", "null_value": "not present", "description": Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas}
"Collector Number" - {"format": "Collector No.", "null_value": "s.n.", "description": Sequential number assigned to collection, associated with the collector}
"Verbatim Date" - {"format": "verbatim", "null_value": "s.d.", "description": Date of collection exactly as it appears on the label}
"Date" - {"format": "yyyy-mm-dd", "null_value": "", "description": Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e. 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day}
"End Date" - {"format": "yyyy-mm-dd", "null_value": "", "description": If date range is listed, later date of collection range}
"""

    # The empty template shown to the model must itself be valid JSON: the
    # previous version ended with a trailing comma after "End Date", which
    # contradicts the "valid JSON" demand made on the model.
    structure = """{
"Catalog Number": "",
"Genus": "",
"Species": "",
"subspecies": "",
"variety": "",
"forma": "",
"Country": "",
"State": "",
"County": "",
"Locality Name": "",
"Min Elevation": "",
"Max Elevation": "",
"Elevation Units": "",
"Verbatim Coordinates": "",
"Datum": "",
"Cultivated": "",
"Habitat": "",
"Collectors": "",
"Collector Number": "",
"Verbatim Date": "",
"Date": "",
"End Date": ""
}"""

    # NOTE(review): the template line is emitted three times, as in the
    # original; presumably deliberate emphasis for the model - confirm before
    # collapsing it to one.
    prompt = f"""Given the following set of rules:
set_rules = {set_rules}
Some dict fields have special requirements listed below. First is the column header. After the - is the format. Do not include the instructions with your response:
requirements = {umich_all_asia_rules}
Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
The input unformatted OCR text: {self.OCR}
The output JSON structure: {structure}
The output JSON structure: {structure}
The output JSON structure: {structure}
The refactored JSON disctionary:"""

    return prompt
|
1007 |
+
|
1008 |
+
def prompt_v2_palm2(self, OCR=None):
    """Assemble the version-2 prompt (lower-case, snake_case keys).

    Stores ``OCR`` on ``self.OCR`` when a truthy value is passed (otherwise the
    existing ``self.OCR`` is reused) and sets ``self.n_fields`` to 26, the
    number of keys in the template. The prompt contains the general rules,
    the OCR text, the formatting-option legend, the per-key instruction
    template, and an empty JSON dictionary for the model to fill in.

    Args:
        OCR: Optional OCR text that replaces ``self.OCR``.

    Returns:
        str: The assembled prompt text.
    """
    self.OCR = OCR or self.OCR
    # The original wrote `26 or self.n_fields`, which always evaluates to 26.
    self.n_fields = 26

    set_rules = """
1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
2. You should map the unstructured OCR text to the appropriate JSON key and then populate the field based on its rules.
3. Some JSON key fields are permitted to remain empty if the corresponding information is not found in the unstructured OCR text.
4. Ignore any information in the OCR text that doesn't fit into the defined JSON structure.
5. Duplicate dictionary fields are not allowed.
6. Ensure that all JSON keys are in lowercase.
7. Ensure that new JSON field values follow sentence case capitalization.
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
9. Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
10. Only return a JSON dictionary represented as a string. You should not explain your answer.
"""

    dictionary_field_format_descriptions = """
The next section of instructions outlines how to format the JSON dictionary. The keys are the same as those of the final formatted JSON object.
For each key there is a format requirement that specifies how to transcribe the information for that key.
The possible formatting options are:
1. "verbatim transcription" - field is populated with verbatim text from the unformatted OCR.
2. "spell check transcription" - field is populated with spelling corrected text from the unformatted OCR.
3. "boolean yes no" - field is populated with only yes or no.
4. "integer" - field is populated with only an integer.
5. "[list]" - field is populated from one of the values in the list.
6. "yyyy-mm-dd" - field is populated with a date in the format year-month-day.
The desired null value is also given. Populate the field with the null value of the information for that key is not present in the unformatted OCR text.
"""

    # Per-key instructions. Kept as valid JSON (no trailing comma after the
    # last entry) so the example does not contradict rule 9 above.
    json_template_rules = """
{
"catalog_number": {
"format": "verbatim transcription",
"null_value": "",
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits."
},
"genus": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'."
},
"species": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to species, do not capitalize species."
},
"subspecies": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to subspecies (subsp.)."
},
"variety": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to variety (var)."
},
"forma": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to form (f.)."
},
"country": {
"format": "spell check transcription",
"null_value": "",
"description": "Country that corresponds to the current geographic location of collection. Capitalize first letter of each word. If abbreviation is given populate field with the full spelling of the country's name. Use sentence-case to capitalize proper nouns."
},
"state": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 1 that corresponds to the current geographic location of collection. Capitalize first letter of each word. Administrative division 1 is equivalent to a U.S. State. Use sentence-case to capitalize proper nouns."
},
"county": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word. Administrative division 2 is equivalent to a U.S. county, parish, borough. Use sentence-case to capitalize proper nouns."
},
"locality_name": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or site of the specimen. Use sentence-case to capitalize proper nouns."
},
"min_elevation": {
"format": "integer",
"null_value": "",
"description": "Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"max_elevation": {
"format": "integer",
"null_value": "",
"description": "Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"elevation_units": {
"format": "spell check transcription",
"null_value": "",
"description": "Elevation units must be meters. If min_elevation field is populated, then elevation_units: 'm'. Otherwise elevation_units: ''"
},
"verbatim_coordinates": {
"format": "verbatim transcription",
"null_value": "",
"description": "Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types are one of [Lat, Long, UTM, TRS]."
},
"decimal_coordinates": {
"format": "spell check transcription",
"null_value": "",
"description": "Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format."
},
"datum": {
"format": "[WGS84, WGS72, WGS66, WGS60, NAD83, NAD27, OSGB36, ETRS89, ED50, GDA94, JGD2011, Tokyo97, KGD2002, TWD67, TWD97, BJS54, XAS80, GCJ-02, BD-09, PZ-90.11, GTRF, CGCS2000, ITRF88, ITRF89, ITRF90, ITRF91, ITRF92, ITRF93, ITRF94, ITRF96, ITRF97, ITRF2000, ITRF2005, ITRF2008, ITRF2014, Hong Kong Principal Datum, SAD69]",
"null_value": "",
"description": "Datum of location coordinates. Possible values are include in the format list. Leave field blank if unclear."
},
"cultivated": {
"format": "boolean yes no",
"null_value": "",
"description": "Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant."
},
"habitat": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of a plant's habitat or the location where the specimen was collected. Ignore descriptions of the plant itself. Use sentence-case to capitalize proper nouns."
},
"plant_description": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics. Use sentence-case to capitalize proper nouns."
},
"collectors": {
"format": "verbatim transcription",
"null_value": "not present",
"description": "Full name(s) of the individual(s) responsible for collecting the specimen. Use sentence-case to capitalize proper nouns. When multiple collectors are involved, their names should be separated by commas."
},
"collector_number": {
"format": "verbatim transcription",
"null_value": "s.n.",
"description": "Unique identifier or number that denotes the specific collecting event and associated with the collector."
},
"determined_by": {
"format": "verbatim transcription",
"null_value": "",
"description": "Full name of the individual responsible for determining the taxanomic name of the specimen. Use sentence-case to capitalize proper nouns. Sometimes the name will be near to the characters 'det' to denote determination. This name may be isolated from other names in the unformatted OCR text."
},
"multiple_names": {
"format": "boolean yes no",
"null_value": "",
"description": "Indicate whether multiple people or collector names are present in the unformatted OCR text. Use sentence-case to capitalize proper nouns. If you see more than one person's name the value is 'yes'; otherwise the value is 'no'."
},
"verbatim_date": {
"format": "verbatim transcription",
"null_value": "s.d.",
"description": "Date of collection exactly as it appears on the label. Do not change the format or correct typos."
},
"date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "Date the specimen was collected formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire date is unknown, 'YYYY-00-00' if only the year is known, and 'YYYY-MM-00' if year and month are known but day is not."
},
"end_date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "If a date range is provided, this represents the later or ending date of the collection period, formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire end date is unknown, 'YYYY-00-00' if only the year of the end date is known, and 'YYYY-MM-00' if year and month of the end date are known but the day is not."
}
}"""

    # Empty dictionary for the model to populate. The previous template had
    # `"species": "".` (period instead of comma), `"',` after elevation_units /
    # multiple_names, and a trailing comma after end_date - all invalid JSON.
    structure = """{"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma":"",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date": "",
"date": "",
"end_date": ""
}"""

    # NOTE(review): the empty dictionary is emitted three times, as in the
    # original; presumably deliberate emphasis for the model - confirm before
    # collapsing it to one.
    prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
The rules are:
{set_rules}
The unstructured OCR text is:
{self.OCR}
{dictionary_field_format_descriptions}
This is the JSON template that includes instructions for each key:
{json_template_rules}
Please populate the following JSON dictionary based on the rules and the unformatted OCR text. The square brackets denote the locations that you should place the new structured text:
{structure}
{structure}
{structure}
"""

    return prompt
|
1242 |
+
|
1243 |
+
def prompt_palm_redo_v1(self, incorrect_json):
    """Build a repair prompt for a version-1 response that failed to parse.

    The prompt explains that the text could not be loaded with
    ``json.loads()``, shows the malformed text, repeats the expected key
    layout three times, and ends with an open slot for the corrected output.

    Args:
        incorrect_json: The malformed JSON text; it is interpolated into the
            prompt unchanged.

    Returns:
        str: The assembled prompt text.
    """
    field_names = (
        "Catalog Number",
        "Genus",
        "Species",
        "subspecies",
        "variety",
        "forma",
        "Country",
        "State",
        "County",
        "Locality Name",
        "Min Elevation",
        "Max Elevation",
        "Elevation Units",
        "Verbatim Coordinates",
        "Datum",
        "Cultivated",
        "Habitat",
        "Collectors",
        "Collector Number",
        "Verbatim Date",
        "Date",
        "End Date",
    )
    # Each key is paired with a bracketed slot carrying the same text; only
    # "Species" differs, its slot being spelled in lower case.
    slot_lines = []
    for name in field_names:
        slot = "species" if name == "Species" else name
        slot_lines.append(f'"{name}": [{slot}]')
    layout = "{\n" + ",\n".join(slot_lines) + "\n}"

    layout_line = f"The output JSON structure: {layout}"
    prompt_lines = [
        "This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().",
        "You need to return coorect JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
        "Escape all JSON control characters that appear in input including ampersand (&) and other control characters.",
        "Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.",
        "Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.",
        f"The incorrectly formatted JSON dictionary: {incorrect_json}",
        layout_line,
        layout_line,
        layout_line,
        # The trailing space after the colon is part of the original prompt.
        "The refactored JSON disctionary: ",
    ]
    return "\n".join(prompt_lines)
|
1281 |
+
|
1282 |
+
def prompt_palm_redo_v2(self, incorrect_json):
    """Build a prompt asking the LLM to repair JSON that failed to parse.

    Args:
        incorrect_json: The malformed JSON text to be repaired.

    Returns:
        The prompt string: repair instructions, the malformed input, and the
        expected output structure.
    """
    # The template is shown to the model as the target format, so it must be
    # valid JSON itself: every key/value pair ends with a comma, every value
    # is a properly quoted empty string, and there is no trailing comma.
    structure = """{"catalog_number": "",
    "genus": "",
    "species": "",
    "subspecies": "",
    "variety": "",
    "forma": "",
    "country": "",
    "state": "",
    "county": "",
    "locality_name": "",
    "min_elevation": "",
    "max_elevation": "",
    "elevation_units": "",
    "verbatim_coordinates": "",
    "decimal_coordinates": "",
    "datum": "",
    "cultivated": "",
    "habitat": "",
    "plant_description": "",
    "collectors": "",
    "collector_number": "",
    "determined_by": "",
    "multiple_names": "",
    "verbatim_date": "",
    "date": "",
    "end_date": ""
    }"""

    # NOTE(review): the structure is stated three times, as in the original
    # prompt; presumably for emphasis -- confirm before de-duplicating.
    prompt = f"""This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().
    You need to return correct JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.
    Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
    Escape all JSON control characters that appear in input including ampersand (&) and other control characters.
    Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
    Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
    The incorrectly formatted JSON dictionary: {incorrect_json}
    The output JSON structure: {structure}
    The output JSON structure: {structure}
    The output JSON structure: {structure}
    The refactored JSON dictionary: """
    return prompt
|
vouchervision/LLM_GoogleGemini.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, time
|
2 |
+
import vertexai
|
3 |
+
from vertexai.preview.generative_models import GenerativeModel
|
4 |
+
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
5 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
6 |
+
from langchain.schema import HumanMessage
|
7 |
+
from langchain.prompts import PromptTemplate
|
8 |
+
from langchain_core.output_parsers import JsonOutputParser
|
9 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
+
|
11 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
12 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
13 |
+
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
14 |
+
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
15 |
+
|
16 |
+
class GoogleGeminiHandler:
    """Send a prompt to a Google Gemini model and post-process the JSON reply.

    A call makes up to ``MAX_RETRIES`` attempts. Every failed attempt raises
    the sampling temperature by ``temp_increment``; the temperature is put
    back to ``starting_temp`` before the call returns.
    """

    RETRY_DELAY = 10  # Seconds to sleep after an attempt raises
    MAX_RETRIES = 3  # Attempts per call; also given to the retry parser
    TOKENIZER_NAME = 'gpt-4'
    VENDOR = 'google'
    STARTING_TEMP = 0.5

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Store the collaborators, then build the prompt, config and chain.

        Args:
            logger: Logger used for info/error messages.
            model_name: Model name given to ``GenerativeModel``.
            JSON_dict_structure: Template handed to
                ``validate_and_align_JSON_keys_with_template``.
        """
        self.logger = logger
        self.model_name = model_name
        self.JSON_dict_structure = JSON_dict_structure

        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        self.monitor = SystemLoadMonitor(logger)
        self.parser = JsonOutputParser()

        # Prompt template: parser format instructions followed by the query.
        format_instructions = self.parser.get_format_instructions()
        self.prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": format_instructions},
        )
        self._set_config()

    def _set_config(self):
        """Create the generation config and safety settings, then the chain."""
        # Per the original author's note, GOOGLE_API_KEY must also be set in
        # the environment for the retry call (done in the VoucherVision class).
        self.config = dict(
            max_output_tokens=1024,
            temperature=self.starting_temp,
            top_p=1,
        )
        harm_categories = (
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        )
        self.safety_settings = {
            category: HarmBlockThreshold.BLOCK_NONE for category in harm_categories
        }
        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the temperature by one increment and report the change."""
        new_temp = self.adjust_temp + self.temp_increment
        message = f'Incrementing temperature from {self.adjust_temp} to {new_temp}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = new_temp
        self.config['temperature'] = new_temp

    def _reset_config(self):
        """Restore the starting temperature and report the change."""
        message = f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _build_model_chain_parser(self):
        """Build the retry parser's LLM, the retry parser, and the chain."""
        # LangChain wrapper used only by the retry parser; the model name is
        # fixed to 'gemini-pro' here rather than taken from self.model_name.
        retry_llm = ChatGoogleGenerativeAI(
            model='gemini-pro',
            max_output_tokens=self.config.get('max_output_tokens'),
            top_p=self.config.get('top_p'),
        )
        self.llm_model = retry_llm
        self.retry_parser = RetryWithErrorOutputParser.from_llm(
            parser=self.parser,
            llm=retry_llm,
            max_retries=self.MAX_RETRIES,
        )
        # The formatted prompt is piped into the direct Gemini call below.
        self.chain = self.prompt | self.call_google_gemini

    def call_google_gemini(self, prompt_text):
        """Send the formatted prompt's ``.text`` to Gemini; return the reply text."""
        gemini = GenerativeModel(self.model_name)
        reply = gemini.generate_content(
            prompt_text.text,
            generation_config=self.config,
            safety_settings=self.safety_settings,
        )
        return reply.text

    def call_llm_api_GoogleGemini(self, prompt_template, json_report):
        """Run the prompt through Gemini, retrying on failure.

        Args:
            prompt_template: The query inserted into the prompt template.
            json_report: Status reporter exposing ``set_text(text_main=...)``.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` once every attempt failed.
        """
        self.json_report = json_report
        json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()
        nt_in, nt_out = 0, 0

        ind = 0
        for ind in range(1, self.MAX_RETRIES + 1):
            try:
                # NOTE(review): "model_kwargs" is not an input variable of the
                # prompt template; the temperature actually sent comes from
                # self.config -- confirm the extra key is intentional.
                model_kwargs = {"temperature": self.adjust_temp}
                response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})

                # NOTE(review): prompt_template is passed as prompt_value --
                # confirm the retry parser accepts this type.
                output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(response, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                json_report.set_text(text_main='Working on WFO and Geolocation')
                # TODO: make the replace_if_success_* flags configurable.
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON: {output}")
                self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()

                json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record

            except Exception as e:
                self.logger.error(f'{e}')
                self._adjust_config()
                time.sleep(self.RETRY_DELAY)

        failure_message = f"Failed to extract valid JSON after [{ind}] attempts"
        self.logger.info(failure_message)
        self.json_report.set_text(text_main=failure_message)

        self.monitor.stop_monitoring_report_usage()
        self._reset_config()

        json_report.set_text(text_main='LLM call failed')
        return None, nt_in, nt_out, None, None
|
151 |
+
|
152 |
+
|
vouchervision/LLM_GooglePalm2.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, time
|
2 |
+
import vertexai
|
3 |
+
from vertexai.language_models import TextGenerationModel
|
4 |
+
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
5 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
6 |
+
from langchain.schema import HumanMessage
|
7 |
+
from langchain.prompts import PromptTemplate
|
8 |
+
from langchain_core.output_parsers import JsonOutputParser
|
9 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
+
|
11 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
12 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
13 |
+
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
14 |
+
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
15 |
+
|
16 |
+
#https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk
|
17 |
+
#pip install --upgrade google-cloud-aiplatform
|
18 |
+
# from google.cloud import aiplatform
|
19 |
+
|
20 |
+
#### have to authenticate gcloud
|
21 |
+
# gcloud auth login
|
22 |
+
# gcloud config set project XXXXXXXXX
|
23 |
+
# https://cloud.google.com/docs/authentication
|
24 |
+
|
25 |
+
class GooglePalm2Handler:
    """Send a prompt to a Google PaLM 2 text model and post-process the JSON reply.

    A call makes up to ``MAX_RETRIES`` attempts. Every failed attempt raises
    the sampling temperature by ``temp_increment``; the temperature is put
    back to ``starting_temp`` before the call returns.
    """

    RETRY_DELAY = 10  # Seconds to sleep after an attempt raises
    MAX_RETRIES = 3  # Attempts per call; also given to the retry parser
    TOKENIZER_NAME = 'gpt-4'
    VENDOR = 'google'
    STARTING_TEMP = 0.5

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Store the collaborators, then build the prompt, config and chain.

        Args:
            logger: Logger used for info/error messages.
            model_name: Model name given to ``TextGenerationModel`` and to the
                retry parser's LLM.
            JSON_dict_structure: Template handed to
                ``validate_and_align_JSON_keys_with_template``.
        """
        self.logger = logger
        self.model_name = model_name
        self.JSON_dict_structure = JSON_dict_structure

        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        self.monitor = SystemLoadMonitor(logger)
        self.parser = JsonOutputParser()

        # Prompt template: parser format instructions followed by the query.
        format_instructions = self.parser.get_format_instructions()
        self.prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": format_instructions},
        )
        self._set_config()

    def _set_config(self):
        """Create the generation config and safety settings, then the chain."""
        self.config = dict(
            max_output_tokens=1024,
            temperature=self.starting_temp,
            top_p=1,
        )
        # NOTE(review): safety_settings is built here but is not passed to
        # model.predict() in call_google_palm2 -- confirm whether it should be.
        harm_categories = (
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        )
        self.safety_settings = {
            category: HarmBlockThreshold.BLOCK_NONE for category in harm_categories
        }
        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the temperature by one increment and report the change."""
        new_temp = self.adjust_temp + self.temp_increment
        message = f'Incrementing temperature from {self.adjust_temp} to {new_temp}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = new_temp
        self.config['temperature'] = new_temp

    def _reset_config(self):
        """Restore the starting temperature and report the change."""
        message = f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _build_model_chain_parser(self):
        """Build the retry parser's LLM, the retry parser, and the chain."""
        retry_llm = ChatGoogleGenerativeAI(
            model=self.model_name,
            max_output_tokens=self.config.get('max_output_tokens'),
            top_p=self.config.get('top_p'),
        )
        self.llm_model = retry_llm
        self.retry_parser = RetryWithErrorOutputParser.from_llm(
            parser=self.parser,
            llm=retry_llm,
            max_retries=self.MAX_RETRIES,
        )
        # The formatted prompt is piped into the direct PaLM call below.
        self.chain = self.prompt | self.call_google_palm2

    def call_google_palm2(self, prompt_text):
        """Send the formatted prompt's ``.text`` to PaLM 2; return the reply text."""
        palm = TextGenerationModel.from_pretrained(self.model_name)
        reply = palm.predict(
            prompt_text.text,
            max_output_tokens=self.config.get('max_output_tokens'),
            temperature=self.config.get('temperature'),
            top_p=self.config.get('top_p'),
        )
        return reply.text

    def call_llm_api_GooglePalm2(self, prompt_template, json_report):
        """Run the prompt through PaLM 2, retrying on failure.

        Args:
            prompt_template: The query inserted into the prompt template.
            json_report: Status reporter exposing ``set_text(text_main=...)``.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` once every attempt failed.
        """
        self.json_report = json_report
        json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()
        nt_in, nt_out = 0, 0

        ind = 0
        for ind in range(1, self.MAX_RETRIES + 1):
            try:
                # NOTE(review): "model_kwargs" is not an input variable of the
                # prompt template; the temperature actually sent comes from
                # self.config -- confirm the extra key is intentional.
                model_kwargs = {"temperature": self.adjust_temp}
                response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})

                # NOTE(review): prompt_template is passed as prompt_value --
                # confirm the retry parser accepts this type.
                output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(response, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                json_report.set_text(text_main='Working on WFO and Geolocation')
                # TODO: make the replace_if_success_* flags configurable.
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON: {output}")
                self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()

                json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record

            except Exception as e:
                self.logger.error(f'{e}')
                self._adjust_config()
                time.sleep(self.RETRY_DELAY)

        failure_message = f"Failed to extract valid JSON after [{ind}] attempts"
        self.logger.info(failure_message)
        self.json_report.set_text(text_main=failure_message)

        self.monitor.stop_monitoring_report_usage()
        self._reset_config()

        json_report.set_text(text_main='LLM call failed')
        return None, nt_in, nt_out, None, None
|
vouchervision/LLM_MistralAI.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, time, random, torch
|
2 |
+
from langchain_mistralai.chat_models import ChatMistralAI
|
3 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4 |
+
from langchain.prompts import PromptTemplate
|
5 |
+
from langchain_core.output_parsers import JsonOutputParser
|
6 |
+
|
7 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
8 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
9 |
+
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
10 |
+
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
11 |
+
|
12 |
+
|
13 |
+
class MistralHandler:
    """Send a prompt to a hosted Mistral model and post-process the JSON reply.

    A call makes up to ``MAX_RETRIES`` attempts. Every failed attempt raises
    the sampling temperature by ``temp_increment`` and draws a new random
    seed; both are restored to their starting values before the call returns.

    The temperature and seed are constructor arguments of the chat model, so
    the model, retry parser and chain are rebuilt whenever either changes.
    """

    RETRY_DELAY = 2  # Seconds to sleep after an attempt raises
    MAX_RETRIES = 5  # Attempts per call; also given to the retry parser
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    RANDOM_SEED = 2023

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Store the collaborators, then build the prompt, config and chain.

        Args:
            logger: Logger used for info/error messages.
            model_name: Model name given to ``ChatMistralAI``.
            JSON_dict_structure: Template handed to
                ``validate_and_align_JSON_keys_with_template``.
        """
        self.logger = logger
        self.monitor = SystemLoadMonitor(logger)
        self.has_GPU = torch.cuda.is_available()
        self.model_name = model_name
        self.JSON_dict_structure = JSON_dict_structure
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        # Set up a parser
        self.parser = JsonOutputParser()

        # Prompt template: parser format instructions followed by the query.
        self.prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": self.parser.get_format_instructions()},
        )

        self._set_config()

    def _set_config(self):
        """Create the generation config and build the model, parser and chain."""
        self.config = {
            'max_tokens': 1024,
            'temperature': self.starting_temp,
            'random_seed': self.RANDOM_SEED,
            'safe_mode': False,
            'top_p': 1,
        }
        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the temperature, draw a new seed, and rebuild the chain."""
        new_temp = self.adjust_temp + self.temp_increment
        seed = random.randint(1, 1000)
        self.config['random_seed'] = seed
        message = f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {seed}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = new_temp
        self.config['temperature'] = new_temp
        # The new settings only take effect once the model is rebuilt.
        self._build_model_chain_parser()

    def _reset_config(self):
        """Restore the starting temperature and seed, and rebuild the chain."""
        message = f'Resetting temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}'
        self.json_report.set_text(text_main=message)
        self.logger.info(message)
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp
        self.config['random_seed'] = self.RANDOM_SEED
        self._build_model_chain_parser()

    def _build_model_chain_parser(self):
        """Build the chat model from ``self.config``, plus retry parser and chain."""
        # temperature and random_seed are passed here because the chain is
        # ``prompt | llm_model``: values placed in the invoke() input dict are
        # only prompt variables and never reach the model.
        self.llm_model = ChatMistralAI(
            mistral_api_key=os.environ.get("MISTRAL_API_KEY"),
            model=self.model_name,
            max_tokens=self.config.get('max_tokens'),
            temperature=self.config.get('temperature'),
            random_seed=self.config.get('random_seed'),
            safe_mode=self.config.get('safe_mode'),
            top_p=self.config.get('top_p'),
        )

        # Set up the retry parser with the runnable
        self.retry_parser = RetryWithErrorOutputParser.from_llm(
            parser=self.parser,
            llm=self.llm_model,
            max_retries=self.MAX_RETRIES,
        )

        self.chain = self.prompt | self.llm_model

    def call_llm_api_MistralAI(self, prompt_template, json_report):
        """Run the prompt through Mistral, retrying on failure.

        Args:
            prompt_template: The query inserted into the prompt template.
            json_report: Status reporter exposing ``set_text(text_main=...)``.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` once every attempt failed.
        """
        self.json_report = json_report
        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()
        nt_in = 0
        nt_out = 0

        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                # Temperature and seed are already baked into self.llm_model.
                response = self.chain.invoke({"query": prompt_template})

                # NOTE(review): prompt_template is passed as prompt_value --
                # confirm the retry parser accepts this type.
                output = self.retry_parser.parse_with_prompt(response.content, prompt_value=prompt_template)

                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(response.content, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
                    self._adjust_config()
                    continue

                json_report.set_text(text_main='Working on WFO and Geolocation')
                # TODO: make the replace_if_success_* flags configurable.
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON: {output}")

                self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()

                json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record

            except Exception as e:
                self.logger.error(f'{e}')

                self._adjust_config()
                time.sleep(self.RETRY_DELAY)

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

        self.monitor.stop_monitoring_report_usage()
        self._reset_config()
        json_report.set_text(text_main='LLM call failed')

        return None, nt_in, nt_out, None, None
|
vouchervision/LLM_OpenAI.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time, torch
|
2 |
+
from langchain.prompts import PromptTemplate
|
3 |
+
from langchain_openai import ChatOpenAI, OpenAI
|
4 |
+
from langchain.schema import HumanMessage
|
5 |
+
from langchain_core.output_parsers import JsonOutputParser
|
6 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
7 |
+
|
8 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
9 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
10 |
+
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
11 |
+
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
12 |
+
|
13 |
+
class OpenAIHandler:
|
14 |
+
RETRY_DELAY = 10 # Wait 10 seconds before retrying
|
15 |
+
MAX_RETRIES = 3 # Maximum number of retries
|
16 |
+
STARTING_TEMP = 0.5
|
17 |
+
TOKENIZER_NAME = 'gpt-4'
|
18 |
+
VENDOR = 'openai'
|
19 |
+
|
20 |
+
def __init__(self, logger, model_name, JSON_dict_structure, is_azure, llm_object):
|
21 |
+
self.logger = logger
|
22 |
+
self.model_name = model_name
|
23 |
+
self.JSON_dict_structure = JSON_dict_structure
|
24 |
+
self.is_azure = is_azure
|
25 |
+
self.llm_object = llm_object
|
26 |
+
self.name_parts = self.model_name.split('-')
|
27 |
+
|
28 |
+
self.monitor = SystemLoadMonitor(logger)
|
29 |
+
self.has_GPU = torch.cuda.is_available()
|
30 |
+
|
31 |
+
self.starting_temp = float(self.STARTING_TEMP)
|
32 |
+
self.temp_increment = float(0.2)
|
33 |
+
self.adjust_temp = self.starting_temp
|
34 |
+
|
35 |
+
# Set up a parser
|
36 |
+
self.parser = JsonOutputParser()
|
37 |
+
|
38 |
+
self.prompt = PromptTemplate(
|
39 |
+
template="Answer the user query.\n{format_instructions}\n{query}\n",
|
40 |
+
input_variables=["query"],
|
41 |
+
partial_variables={"format_instructions": self.parser.get_format_instructions()},
|
42 |
+
)
|
43 |
+
self._set_config()
|
44 |
+
|
45 |
+
def _set_config(self):
|
46 |
+
self.config = {'max_new_tokens': 1024,
|
47 |
+
'temperature': self.starting_temp,
|
48 |
+
'random_seed': 2023,
|
49 |
+
'top_p': 1,
|
50 |
+
}
|
51 |
+
# Adjusting the LLM settings based on whether Azure is used
|
52 |
+
if self.is_azure:
|
53 |
+
self.llm_object.deployment_name = self.model_name
|
54 |
+
self.llm_object.model_name = self.model_name
|
55 |
+
else:
|
56 |
+
self.llm_object = None
|
57 |
+
self._build_model_chain_parser()
|
58 |
+
|
59 |
+
|
60 |
+
# Define a function to format the input for azure_call
|
61 |
+
def format_input_for_azure(self, prompt_text):
|
62 |
+
msg = HumanMessage(content=prompt_text.text)
|
63 |
+
# self.llm_object.temperature = self.config.get('temperature')
|
64 |
+
return self.llm_object(messages=[msg])
|
65 |
+
|
66 |
+
def _adjust_config(self):
|
67 |
+
new_temp = self.adjust_temp + self.temp_increment
|
68 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
69 |
+
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
70 |
+
self.adjust_temp += self.temp_increment
|
71 |
+
self.config['temperature'] = self.adjust_temp
|
72 |
+
|
73 |
+
def _reset_config(self):
|
74 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
75 |
+
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
76 |
+
self.adjust_temp = self.starting_temp
|
77 |
+
self.config['temperature'] = self.starting_temp
|
78 |
+
|
79 |
+
def _build_model_chain_parser(self):
|
80 |
+
if not self.is_azure and ('instruct' in self.name_parts):
|
81 |
+
# Set up the retry parser with 3 retries
|
82 |
+
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
83 |
+
# parser=self.parser, llm=self.llm_object if self.is_azure else OpenAI(temperature=self.config.get('temperature'), model=self.model_name), max_retries=self.MAX_RETRIES
|
84 |
+
parser=self.parser, llm=self.llm_object if self.is_azure else OpenAI(model=self.model_name), max_retries=self.MAX_RETRIES
|
85 |
+
)
|
86 |
+
else:
|
87 |
+
# Set up the retry parser with 3 retries
|
88 |
+
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
89 |
+
# parser=self.parser, llm=self.llm_object if self.is_azure else ChatOpenAI(temperature=self.config.get('temperature'), model=self.model_name), max_retries=self.MAX_RETRIES
|
90 |
+
parser=self.parser, llm=self.llm_object if self.is_azure else ChatOpenAI(model=self.model_name), max_retries=self.MAX_RETRIES
|
91 |
+
)
|
92 |
+
# Prepare the chain
|
93 |
+
if not self.is_azure and ('instruct' in self.name_parts):
|
94 |
+
# self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else OpenAI(temperature=self.config.get('temperature'), model=self.model_name))
|
95 |
+
self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else OpenAI(model=self.model_name))
|
96 |
+
else:
|
97 |
+
# self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else ChatOpenAI(temperature=self.config.get('temperature'), model=self.model_name))
|
98 |
+
self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else ChatOpenAI(model=self.model_name))
|
99 |
+
|
100 |
+
|
101 |
+
def call_llm_api_OpenAI(self, prompt_template, json_report):
    """Run ``prompt_template`` through ``self.chain`` and return validated JSON.

    Up to ``self.MAX_RETRIES`` attempts are made. After each failed attempt
    ``self._adjust_config()`` is called; when an attempt raises, the method
    also waits ``self.RETRY_DELAY`` seconds before the next one (no wait
    after the final attempt, since nothing follows it).

    Args:
        prompt_template: Prompt text sent to the chain as ``query`` and
            passed to the retry parser and the token counter.
        json_report: Progress reporter exposing ``set_text(text_main=...)``.

    Returns:
        ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
        ``(None, nt_in, nt_out, None, None)`` once every attempt has failed.
    """
    self.json_report = json_report
    self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
    self.monitor.start_monitoring_usage()
    nt_in = 0
    nt_out = 0

    ind = 0
    while ind < self.MAX_RETRIES:
        ind += 1
        try:
            # NOTE(review): "model_kwargs" travels as a chain input next to
            # "query"; confirm the chain really uses it to set the temperature.
            model_kwargs = {"temperature": self.adjust_temp}
            response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})

            # The chain may hand back a plain string or an object with .content.
            response_text = response.content if not isinstance(response, str) else response

            # Use retry_parser to parse the response with retry logic
            output = self.retry_parser.parse_with_prompt(response_text, prompt_value=prompt_template)

            if output is not None:
                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(response_text, self.VENDOR, self.TOKENIZER_NAME)
                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)

            if output is None:
                # Either parsing or key alignment produced nothing usable.
                self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
                self._adjust_config()
                continue

            json_report.set_text(text_main=f'Working on WFO and Geolocation')

            # TODO: make the replace_if_success_* flags configurable.
            output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
            output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

            self.logger.info(f"Formatted JSON: {output}")

            self.monitor.stop_monitoring_report_usage()

            if self.adjust_temp != self.starting_temp:
                self._reset_config()
            json_report.set_text(text_main=f'LLM call successful')
            return output, nt_in, nt_out, WFO_record, GEO_record

        except Exception as e:
            self.logger.error(f'{e}')

            self._adjust_config()
            # Waiting only makes sense when another attempt will follow.
            if ind < self.MAX_RETRIES:
                time.sleep(self.RETRY_DELAY)

    self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
    self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

    self.monitor.stop_monitoring_report_usage()
    self._reset_config()

    json_report.set_text(text_main=f'LLM call failed')
    return None, nt_in, nt_out, None, None
|
vouchervision/LLM_PaLM.py
DELETED
@@ -1,209 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import inspect
|
4 |
-
import json
|
5 |
-
from json import JSONDecodeError
|
6 |
-
import tiktoken
|
7 |
-
import random
|
8 |
-
import google.generativeai as palm
|
9 |
-
|
10 |
-
currentdir = os.path.dirname(os.path.abspath(
|
11 |
-
inspect.getfile(inspect.currentframe())))
|
12 |
-
parentdir = os.path.dirname(currentdir)
|
13 |
-
sys.path.append(parentdir)
|
14 |
-
|
15 |
-
from prompt_catalog import PromptCatalog
|
16 |
-
from general_utils import num_tokens_from_string
|
17 |
-
|
18 |
-
"""
|
19 |
-
DEPRECATED:
|
20 |
-
Safety setting regularly block a response, so set to 4 to disable
|
21 |
-
|
22 |
-
class HarmBlockThreshold(Enum):
|
23 |
-
HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0
|
24 |
-
BLOCK_LOW_AND_ABOVE = 1
|
25 |
-
BLOCK_MEDIUM_AND_ABOVE = 2
|
26 |
-
BLOCK_ONLY_HIGH = 3
|
27 |
-
BLOCK_NONE = 4
|
28 |
-
"""
|
29 |
-
|
30 |
-
# One entry per harm category, each with its threshold set to BLOCK_NONE.
SAFETY_SETTINGS = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DEROGATORY",
        "HARM_CATEGORY_TOXICITY",
        "HARM_CATEGORY_VIOLENCE",
        "HARM_CATEGORY_SEXUAL",
        "HARM_CATEGORY_MEDICAL",
        "HARM_CATEGORY_DANGEROUS",
    )
]

# Keyword arguments for the first-pass generation call (temperature 0).
PALM_SETTINGS = dict(
    model='models/text-bison-001',
    temperature=0,
    candidate_count=1,
    top_k=40,
    top_p=0.95,
    max_output_tokens=8000,
    stop_sequences=[],
    safety_settings=SAFETY_SETTINGS,
)

# Same settings with temperature 0.05; gets its own stop_sequences list so
# the two dicts do not share a mutable value.
PALM_SETTINGS_REDO = dict(PALM_SETTINGS, temperature=0.05, stop_sequences=[])
|
78 |
-
|
79 |
-
def OCR_to_dict_PaLM(logger, OCR, prompt_version, VVE):
    """Build a PaLM 2 prompt from OCR text, call the API and parse the reply.

    Args:
        logger: Logger used for progress messages.
        OCR: Raw OCR text to transcribe.
        prompt_version: Name of the prompt to build. Names other than the
            three built-in ones are handed to ``Prompt.prompt_v2_custom``.
        VVE: Domain-knowledge store; only used for ``'prompt_v1_palm2'``
            (``query_db`` / ``get_similarity``).

    Returns:
        ``(response_valid, nt)``: the value returned by
        ``check_and_redo_JSON`` (or ``{}`` when the API gave no usable text
        result) and the token count of the prompt.
    """
    length_message = f'Length of OCR raw -- {len(OCR)}'
    try:
        logger.info(length_message)
    except Exception:
        # Fall back to stdout when the logger cannot be used.
        print(length_message)

    Prompt = PromptCatalog(OCR)
    if prompt_version == 'prompt_v2_palm2':
        version = 'v2'
        prompt = Prompt.prompt_v2_palm2(OCR)

    elif prompt_version == 'prompt_v1_palm2':
        version = 'v1'
        # PaLM needs its examples as input/output pairs: look up similar
        # records and turn each one into a shuffled-text / JSON pair.
        domain_knowledge_example = VVE.query_db(OCR, 4)
        # The returned similarity is not used; the call is kept as-is.
        # NOTE(review): confirm whether get_similarity() has side effects.
        VVE.get_similarity()
        in_list, out_list = create_OCR_analog_for_input(domain_knowledge_example)
        prompt = Prompt.prompt_v1_palm2(in_list, out_list, OCR)

    elif prompt_version == 'prompt_v1_palm2_noDomainKnowledge':
        version = 'v1'
        prompt = Prompt.prompt_v1_palm2_noDomainKnowledge(OCR)

    else:
        version = 'custom'
        # Only the prompt text is needed; the other two return values are unused.
        prompt, _n_fields, _xlsx_headers = Prompt.prompt_v2_custom(prompt_version, OCR=OCR, is_palm=True)

    nt = num_tokens_from_string(prompt, "cl100k_base")
    logger.info(f'Prompt token length --- {nt}')

    # TODO: Check back later to see if LangChain will support PaLM
    # (a StructuredOutputParser path was planned but never enabled).
    logger.info(f'Waiting for PaLM 2 API call')
    response = palm.generate_text(prompt=prompt, **PALM_SETTINGS)

    if response and response.result and isinstance(response.result, (str, bytes)):
        response_valid = check_and_redo_JSON(response, logger, version)
    else:
        response_valid = {}

    # Guard the attribute access: ``response`` itself may be falsy here.
    candidate = response.result if response else None
    logger.info(f'Candidate JSON\n{candidate}')
    return response_valid, nt
|
146 |
-
|
147 |
-
def check_and_redo_JSON(response, logger, version):
    """Parse ``response.result`` as JSON, asking PaLM to repair it if needed.

    Steps, stopping at the first success:
      1. ``json.loads`` on the raw result.
      2. ``json.loads`` after stripping backticks and a leading ``json`` tag.
      3. A redo prompt sent with ``PALM_SETTINGS``.
      4. The same redo prompt sent with ``PALM_SETTINGS_REDO``.

    Args:
        response: Object whose ``result`` attribute holds the model's text.
        logger: Logger used for progress messages.
        version: ``'v1'``, ``'v2'`` or ``'custom'``; selects the redo prompt.

    Returns:
        The parsed JSON value, or ``None`` when every step failed or
        ``version`` is not recognised.
    """
    raw = response.result
    try:
        response_valid = json.loads(raw)
    except JSONDecodeError:
        pass
    else:
        logger.info(f'Response --- First call passed')
        return response_valid

    try:
        # Remove a Markdown code fence (``` ... ```) and its "json" tag.
        response_valid = json.loads(raw.strip('```').replace('json\n', '', 1).replace('json', '', 1))
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError; TypeError covers a bytes result,
        # on which the str-based strip/replace calls cannot operate.
        logger.info(f'Response --- First call failed. Redo...')
    else:
        logger.info(f'Response --- Manual removal of ```json succeeded')
        return response_valid

    Prompt = PromptCatalog()
    if version == 'v1':
        prompt_redo = Prompt.prompt_palm_redo_v1(raw)
    elif version == 'v2':
        prompt_redo = Prompt.prompt_palm_redo_v2(raw)
    elif version == 'custom':
        prompt_redo = Prompt.prompt_v2_custom_redo(raw, is_palm=True)
    else:
        # Without a redo prompt there is nothing left to try.
        logger.info(f'Response --- Unknown prompt version {version!r}, cannot build a redo prompt')
        return None

    try:
        response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS)
        response_valid = json.loads(response.result)
    except (JSONDecodeError, TypeError):
        # TypeError: the redo call returned no text (result is None).
        logger.info(f'Response --- Second call failed. Final redo. Temperature changed to 0.05')
    else:
        logger.info(f'Response --- Second call passed')
        return response_valid

    try:
        response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS_REDO)
        response_valid = json.loads(response.result)
    except (JSONDecodeError, TypeError):
        return None
    logger.info(f'Response --- Third call passed')
    return response_valid
|
184 |
-
|
185 |
-
|
186 |
-
def create_OCR_analog_for_input(domain_knowledge_example):
    """Turn example records into (scrambled text, JSON string) pairs.

    Args:
        domain_knowledge_example: Iterable of dictionaries.

    Returns:
        ``(in_list, out_list)``. ``out_list[i]`` is ``json.dumps`` of record
        ``i``; ``in_list[i]`` is that record's values, stringified, split on
        ``'||'``, shuffled in random order and joined with single spaces.
    """
    scrambled_inputs, json_outputs = [], []
    for record in domain_knowledge_example:
        json_outputs.append(json.dumps(record))

        # Join then split on '||' so a value that itself contains '||'
        # is broken into separate pieces before shuffling.
        pieces = '||'.join(map(str, record.values())).split('||')
        random.shuffle(pieces)
        scrambled_inputs.append(' '.join(pieces))
    return scrambled_inputs, json_outputs
|
206 |
-
|
207 |
-
|
208 |
-
def strip_problematic_chars(s):
    """Return ``s`` with every character that is not printable removed."""
    return ''.join(filter(str.isprintable, s))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vouchervision/LLM_chatGPT_3_5.py
DELETED
@@ -1,427 +0,0 @@
|
|
1 |
-
import openai
|
2 |
-
import os, json, sys, inspect, time, requests
|
3 |
-
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
4 |
-
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
|
5 |
-
from langchain.llms import OpenAI
|
6 |
-
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
|
7 |
-
from langchain.schema import HumanMessage
|
8 |
-
from general_utils import num_tokens_from_string
|
9 |
-
|
10 |
-
currentdir = os.path.dirname(os.path.abspath(
|
11 |
-
inspect.getfile(inspect.currentframe())))
|
12 |
-
parentdir = os.path.dirname(currentdir)
|
13 |
-
sys.path.append(parentdir)
|
14 |
-
|
15 |
-
from prompts import PROMPT_UMICH_skeleton_all_asia, PROMPT_OCR_Organized, PROMPT_UMICH_skeleton_all_asia_GPT4, PROMPT_OCR_Organized_GPT4, PROMPT_JSON
|
16 |
-
from prompt_catalog import PromptCatalog
|
17 |
-
|
18 |
-
# Seconds to wait before retrying (one second more than a full minute).
RETRY_DELAY = 61
# Maximum number of retries.
MAX_RETRIES = 5
|
20 |
-
|
21 |
-
|
22 |
-
def azure_call(model, messages):
    """Call ``model`` with ``messages`` as a keyword argument and return the result."""
    return model(messages=messages)
|
25 |
-
|
26 |
-
def OCR_to_dict(is_azure, logger, MODEL, prompt, llm, prompt_version):
    """Send ``prompt`` to the model and return its structured answer.

    Makes up to ``MAX_RETRIES`` attempts, sleeping ``RETRY_DELAY`` seconds
    after each failed attempt except the last, whose exception is re-raised.

    Args:
        is_azure: Use the Azure client ``llm`` instead of the OpenAI API.
        logger: Logger used for progress and error messages.
        MODEL: Model name.
        prompt: Prompt text.
        llm: Azure chat model; only used when ``is_azure`` is true.
        prompt_version: Forwarded to ``structured_output_parser``.

    Returns:
        The ``'Dictionary'`` entry of the parsed response, or ``None`` when
        the structured parser returned ``None``. The direct (currently
        disabled) path returns the raw message content instead.
    """
    # Hard-coded switch: the StructuredOutputParser path is always taken.
    do_use_SOP = True

    for i in range(MAX_RETRIES):
        try:
            if do_use_SOP:
                logger.info(f'Waiting for {MODEL} API call --- Using StructuredOutputParser')
                response = structured_output_parser(is_azure, MODEL, llm, prompt, logger, prompt_version)
                if response is None:
                    return None
                return response['Dictionary']

            ### Direct GPT ###
            logger.info(f'Waiting for {MODEL} API call')
            if is_azure:
                msg = HumanMessage(content=prompt)
                response = azure_call(llm, [msg])
                return response.content

            response = openai.ChatCompletion.create(
                model=MODEL,
                temperature=0,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant acting as a transcription expert and your job is to transcribe herbarium specimen labels based on OCR data and reformat it to meet Darwin Core Archive Standards into a Python dictionary based on certain rules."},
                    {"role": "user", "content": prompt},
                ],
                # The key must be the string "type"; the bare name ``type``
                # would put the builtin class in as the dictionary key.
                response_format={"type": "json_object"},
                max_tokens=4096,
            )
            return response.choices[0].message['content']
        except Exception as e:
            logger.error(f'{e}')
            if i < MAX_RETRIES - 1:  # No delay needed after the last try
                time.sleep(RETRY_DELAY)
            else:
                raise
|
67 |
-
|
68 |
-
# def OCR_to_dict(logger, MODEL, prompt, OCR, BASE_URL, HEADERS):
|
69 |
-
# for i in range(MAX_RETRIES):
|
70 |
-
# try:
|
71 |
-
# do_use_SOP = False
|
72 |
-
|
73 |
-
# if do_use_SOP:
|
74 |
-
# logger.info(f'Waiting for {MODEL} API call --- Using StructuredOutputParser -- Content')
|
75 |
-
# response = structured_output_parser(MODEL, OCR, prompt, logger)
|
76 |
-
# if response is None:
|
77 |
-
# return None
|
78 |
-
# else:
|
79 |
-
# return response['Dictionary']
|
80 |
-
|
81 |
-
# else:
|
82 |
-
# ### Direct GPT through Azure ###
|
83 |
-
# logger.info(f'Waiting for {MODEL} API call')
|
84 |
-
# response = azure_gpt_request(prompt, BASE_URL, HEADERS, model_name=MODEL)
|
85 |
-
|
86 |
-
# # Handle the response data. Note: You might need to adjust the following line based on the exact response format of the Azure API.
|
87 |
-
# content = response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
88 |
-
# return content
|
89 |
-
# except requests.exceptions.RequestException as e: # Replace openai.error.APIError with requests exception.
|
90 |
-
# # Handle HTTP exceptions. You can adjust this based on the Azure API's error responses.
|
91 |
-
# if e.response.status_code == 502:
|
92 |
-
# logger.info(f' *** 502 error was encountered, wait and try again ***')
|
93 |
-
# if i < MAX_RETRIES - 1:
|
94 |
-
# time.sleep(RETRY_DELAY)
|
95 |
-
# else:
|
96 |
-
# raise
|
97 |
-
|
98 |
-
|
99 |
-
def OCR_to_dict_16k(is_azure, logger, MODEL, prompt, llm, prompt_version):
    """Call the chat API with the function schema and return the parsed dictionary.

    Makes up to ``MAX_RETRIES`` attempts, sleeping ``RETRY_DELAY`` seconds
    after each failed attempt except the last, whose exception is re-raised.
    If the reply is not valid JSON, or no message content can be read from
    the response, ``structured_output_parser_for_function_calls_fail`` is
    asked to repair it.

    Args:
        is_azure: Forwarded to the JSON fixer.
        logger: Logger used for progress and error messages.
        MODEL: Model name.
        prompt: Prompt text.
        llm: Forwarded to the JSON fixer.
        prompt_version: Forwarded to the JSON fixer.

    Returns:
        The ``'Dictionary'`` entry of the parsed reply, or ``None`` when the
        fixer could not produce one.
    """
    for i in range(MAX_RETRIES):
        try:
            fs = FunctionSchema()
            response = openai.ChatCompletion.create(
                model=MODEL,
                temperature=0,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant acting as a transcription expert and your job is to transcribe herbarium specimen labels based on OCR data and reformat it to meet Darwin Core Archive Standards into a Python dictionary based on certain rules."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=8000,
                function_call="none",
                functions=fs.format_C21_AA_V1(),
            )

            try:
                response_string = response.choices[0].message['content']
                call_failed = False
            except (AttributeError, IndexError, KeyError, TypeError):
                # No readable message content: hand the prompt to the fixer.
                call_failed = True
                response_string = prompt

            if not call_failed:
                try:
                    # Try to parse the response into JSON
                    response_dict = json.loads(response_string)
                    return response_dict['Dictionary']
                except json.JSONDecodeError:
                    logger.info(f'Invalid JSON response, calling structured_output_parser_for_function_calls_fail function')
                    logger.info(f'Waiting for {MODEL} API call --- Using StructuredOutputParser --- JSON Fixer')
                    response_sop = structured_output_parser_for_function_calls_fail(is_azure, MODEL, response_string, logger, llm, prompt_version, is_helper=False)
                    if response_sop is None:
                        return None
                    return response_sop['Dictionary']
            else:
                try:
                    logger.info(f'Call Failed. Attempting fallback JSON parse without guidance')
                    logger.info(f'Waiting for {MODEL} API call --- Using StructuredOutputParser --- JSON Fixer')
                    response_sop = structured_output_parser_for_function_calls_fail(is_azure, MODEL, response_string, logger, llm, prompt_version, is_helper=False)
                    if response_sop is None:
                        return None
                    return response_sop['Dictionary']
                except Exception as fallback_error:
                    # Best effort: the fallback gives up quietly with None.
                    logger.error(f'Fallback JSON parse failed: {fallback_error}')
                    return None
        except Exception as e:
            # Report the real error; this handler sees every failure, not
            # only HTTP 401 responses.
            logger.error(f' *** API call failed on attempt {i + 1} of {MAX_RETRIES}: {e} ***')
            if i < MAX_RETRIES - 1:  # No delay needed after the last try
                time.sleep(RETRY_DELAY)
            else:
                raise
|
157 |
-
|
158 |
-
def structured_output_parser(is_azure, MODEL, llm, prompt_template, logger, prompt_version, is_helper=False):
    """Ask the model for a structured answer and parse it.

    Args:
        is_azure: Call the Azure model ``llm`` instead of building a
            ``ChatOpenAI`` client for ``MODEL``.
        MODEL: Model name for the non-Azure path.
        llm: Azure chat model; only used when ``is_azure`` is true.
        prompt_template: Text inserted into the prompt as ``question``.
        logger: Logger used for progress and error messages.
        prompt_version: Forwarded to the JSON fixer on a parse error.
        is_helper: Selects the ``Dictionary``/``Summary`` schema instead of
            ``SpeciesName``/``Dictionary``.

    Returns:
        The parsed response, the JSON fixer's result when parsing fails, or
        ``None`` when the reply is not of type ``'ai'`` or an unexpected
        error occurs.
    """
    if is_helper:
        response_schemas = [
            ResponseSchema(name="Dictionary", description='Formatted JSON object'),
            ResponseSchema(name="Summary", description="A one sentence summary of the content"),
        ]
    else:
        response_schemas = [
            ResponseSchema(name="SpeciesName", description="Taxonomic determination, genus_species"),
            ResponseSchema(name="Dictionary", description='Formatted JSON object'),
        ]

    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    format_instructions = output_parser.get_format_instructions()

    prompt = ChatPromptTemplate(
        messages=[
            HumanMessagePromptTemplate.from_template("Parse the OCR text into the correct structured format.\n{format_instructions}\n{question}")
        ],
        input_variables=["question"],
        partial_variables={"format_instructions": format_instructions}
    )
    _input = prompt.format_prompt(question=prompt_template)

    # Handle Azure vs OpenAI implementation
    if is_azure:
        msg = HumanMessage(content=_input.to_string())
        output = azure_call(llm, [msg])
    else:
        chat_model = ChatOpenAI(temperature=0, model=MODEL)
        output = chat_model(_input.to_messages())

    # Token counting is informational only; a failure must not abort the call.
    try:
        nt = num_tokens_from_string(_input.to_string(), "cl100k_base")
        logger.info(f'Prompt token length --- {nt}')
    except Exception as count_error:
        logger.info(f'Prompt token length unavailable: {count_error}')

    try:
        if output.type != 'ai':
            logger.error('Output type is not "ai". Unable to parse.')
            return None

        parsed_content = output.content
        logger.info(f'Formatted JSON\n{parsed_content}')

        # Drop newlines, tabs and pipe characters before parsing.
        parsed_content = parsed_content.replace('\n', "").replace('\t', "").replace('|', "")

        try:
            return output_parser.parse(parsed_content)
        except Exception as parse_error:
            # Hand the cleaned text to the JSON fixer.
            logger.error(f'Parsing Error: {parse_error}')
            return structured_output_parser_for_function_calls_fail(is_azure, MODEL, parsed_content, logger, llm, prompt_version, is_helper)

    except Exception as e:
        # Handle any other exceptions that might occur
        logger.error(f'Unexpected Error: {e}')
        return None
|
224 |
-
|
225 |
-
def structured_output_parser_for_function_calls_fail(is_azure, MODEL, failed_response, logger, llm, prompt_version, is_helper=False, try_ind=0, original_failed_response=None):
    """Ask the model to repair a broken JSON reply, retrying recursively.

    Args:
        is_azure: Call the Azure model ``llm`` instead of building a
            ``ChatOpenAI`` client for ``MODEL``.
        MODEL: Model name for the non-Azure path.
        failed_response: Text to repair on this attempt.
        logger: Logger used for progress messages.
        llm: Azure chat model; only used when ``is_azure`` is true.
        prompt_version: Selects which redo prompt is built.
        is_helper: Not used here beyond being forwarded to retries.
        try_ind: Number of attempts already made; gives up once above 5.
        original_failed_response: The text from the very first attempt,
            carried through retries. Defaults to ``failed_response``.

    Returns:
        The parsed response, or ``None`` once ``try_ind`` exceeds 5.
    """
    # Callers start a repair without this argument; retries pass it along so
    # every attempt can refer back to the first broken reply.
    if original_failed_response is None:
        original_failed_response = failed_response
    if try_ind > 5:
        return None

    Prompt = PromptCatalog()
    if prompt_version in ['prompt_v1_verbose', 'prompt_v1_verbose_noDomainKnowledge']:
        prompt_redo = Prompt.prompt_gpt_redo_v1(failed_response)
    elif prompt_version in ['prompt_v2_json_rules']:
        prompt_redo = Prompt.prompt_gpt_redo_v2(failed_response)
    else:
        prompt_redo = Prompt.prompt_v2_custom_redo(failed_response, is_palm=False)

    response_schemas = [
        ResponseSchema(name="Summary", description="A one sentence summary of the content"),
        ResponseSchema(name="Dictionary", description='Formatted JSON object')
    ]

    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    format_instructions = output_parser.get_format_instructions()

    prompt = ChatPromptTemplate(
        messages=[
            HumanMessagePromptTemplate.from_template("The following text contains JSON formatted text, but there is an error that you need to correct.\n{format_instructions}\n{question}")
        ],
        input_variables=["question"],
        partial_variables={"format_instructions": format_instructions}
    )

    _input = prompt.format_prompt(question=prompt_redo)

    # Token counting is informational only; a failure must not abort the call.
    try:
        nt = num_tokens_from_string(_input.to_string(), "cl100k_base")
        logger.info(f'Prompt Redo token length --- {nt}')
    except Exception as count_error:
        logger.info(f'Prompt Redo token length unavailable: {count_error}')

    if is_azure:
        msg = HumanMessage(content=_input.to_string())
        output = azure_call(llm, [msg])
    else:
        chat_model = ChatOpenAI(temperature=0, model=MODEL)
        output = chat_model(_input.to_messages())

    try:
        refined_response = output_parser.parse(output.content)
    except json.decoder.JSONDecodeError as e:
        # Retry with the decoder's error message attached to the original text.
        try_ind += 1
        error_message = str(e)
        redo_content = f'The error messsage is: {error_message}\nThe broken JSON object is: {original_failed_response}'
        logger.info(f'[Failed JSON Object]\n{original_failed_response}')
        refined_response = structured_output_parser_for_function_calls_fail(
            is_azure, MODEL, redo_content, logger, llm, prompt_version, is_helper, try_ind, original_failed_response
        )
    except Exception:
        # Any other parse failure: retry from the original text unchanged.
        try_ind += 1
        logger.info(f'[Failed JSON Object]\n{original_failed_response}')
        refined_response = structured_output_parser_for_function_calls_fail(
            is_azure, MODEL, original_failed_response, logger, llm, prompt_version, is_helper, try_ind, original_failed_response
        )

    return refined_response
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
class FunctionSchema:
    """Builders for the function-description lists used with function calling."""

    def __init__(self):
        pass

    @staticmethod
    def _string_arrays(field_names):
        """Map each field name to a new array-of-strings schema, in order."""
        return {
            field: {"type": "array", "items": {"type": "string"}}
            for field in field_names
        }

    @staticmethod
    def _object(properties):
        """Wrap ``properties`` in an object schema."""
        return {"type": "object", "properties": properties}

    @staticmethod
    def _describe(name, description, output_properties):
        """Return the one-element list holding a function description."""
        return [
            {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {},  # specify parameters here if your function requires any
                    "required": []  # list of required parameters
                },
                "output_type": "json",
                "output_schema": {
                    "type": "object",
                    "properties": output_properties,
                },
            }
        ]

    def format_C21_AA_V1(self):
        """Return the flat schema: a ``Dictionary`` of label fields plus ``SpeciesName``."""
        label_fields = (
            "Catalog Number", "Genus", "Species", "subspecies", "variety", "forma",
            "Country", "State", "County", "Locality Name",
            "Min Elevation", "Max Elevation", "Elevation Units",
            "Verbatim Coordinates", "Datum", "Cultivated", "Habitat",
            "Collectors", "Collector Number", "Verbatim Date", "Date", "End Date",
        )
        return self._describe(
            "format_C21_AA_V1",
            "Format the given data into a specific dictionary",
            {
                "Dictionary": self._object(self._string_arrays(label_fields)),
                "SpeciesName": self._object(self._string_arrays(("taxonomy",))),
            },
        )

    def format_C21_AA_V1_helper(self):
        """Return the grouped schema: a sectioned ``Dictionary`` plus ``Summary``."""
        sections = (
            ("TAXONOMY", ("Order", "Family", "Genus", "Species", "Subspecies", "Variety", "Forma")),
            ("GEOGRAPHY", ("Country", "State", "Prefecture", "Province", "District", "County", "City", "Administrative Division")),
            ("LOCALITY", ("Landscape", "Nearby Places")),
            ("COLLECTING", ("Collector", "Collector's Number", "Verbatim Date", "Formatted Date", "Cultivation Status", "Habitat Description")),
            ("MISCELLANEOUS", ("Additional Information",)),
        )
        grouped = {
            section: self._object(self._string_arrays(fields))
            for section, fields in sections
        }
        return self._describe(
            "format_C21_AA_V1_helper",
            "Helper function for format_C21_AA_V1 to further format the given data",
            {
                "Dictionary": self._object(grouped),
                "Summary": self._object(self._string_arrays(("Content Summary",))),
            },
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vouchervision/LLM_local_MistralAI.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json, torch, transformers, gc
|
2 |
+
from transformers import BitsAndBytesConfig
|
3 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4 |
+
from langchain.prompts import PromptTemplate
|
5 |
+
from langchain_core.output_parsers import JsonOutputParser
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
8 |
+
|
9 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
10 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
11 |
+
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
12 |
+
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
13 |
+
|
14 |
+
'''
|
15 |
+
https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
|
16 |
+
'''
|
17 |
+
|
18 |
+
class LocalMistralHandler:
    """Query a locally loaded, 4-bit quantized Mistral model and return validated JSON.

    The handler builds a ``transformers`` text-generation pipeline, wraps it in a
    LangChain ``HuggingFacePipeline``, and retries a request with a higher
    sampling temperature whenever the reply cannot be turned into JSON that
    matches ``JSON_dict_structure``.
    """

    RETRY_DELAY = 2  # Wait 2 seconds before retrying
    MAX_RETRIES = 5  # Maximum number of retries
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    MAX_GPU_MONITORING_INTERVAL = 2  # seconds

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Download the model config, build the prompt template and load the model.

        Args:
            logger: object exposing ``info`` and ``error``; used for all reporting.
            model_name: repository name under the ``mistralai/`` Hugging Face org.
            JSON_dict_structure: template handed to
                ``validate_and_align_JSON_keys_with_template`` for every reply.
        """
        self.logger = logger
        self.has_GPU = torch.cuda.is_available()
        self.monitor = SystemLoadMonitor(logger)

        self.model_name = model_name
        self.model_id = f"mistralai/{self.model_name}"

        # Fetching config.json also fails early if the repo id is wrong or unreachable.
        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json")

        self.JSON_dict_structure = JSON_dict_structure
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries a JSON dictionary as specified by the user."
        # Written as an explicit literal so that no source indentation leaks into the prompt.
        template = "\n<s>[INST]{}[/INST]</s>\n\n[INST]{}[/INST]\n".format(system_prompt, "{query}")

        # Create a prompt from the template so we can use it with Langchain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])

        # Set up a parser
        self.parser = JsonOutputParser()

        self._set_config()

    def _set_config(self):
        """Create the generation/quantization settings and (re)build the model chain."""
        self.config = {
            'max_new_tokens': 1024,
            'temperature': self.starting_temp,
            'seed': 2023,
            'top_p': 1,
            'top_k': 40,
            'do_sample': True,
            'n_ctx': 4096,

            # Activate 4-bit precision base model loading
            'use_4bit': True,
            # Compute dtype for 4-bit base models
            'bnb_4bit_compute_dtype': "float16",
            # Quantization type (fp4 or nf4)
            'bnb_4bit_quant_type': "nf4",
            # Activate nested quantization for 4-bit base models (double quantization)
            'use_nested_quant': False,
        }

        compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))

        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.config.get('use_4bit'),
            bnb_4bit_quant_type=self.config.get('bnb_4bit_quant_type'),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
        )

        # Always define the dtype so _build_model_chain_parser never sees a missing
        # attribute, and only query the device when CUDA is actually present
        # (get_device_capability raises without a CUDA device).
        self.b_float_opt = torch.float16
        if self.has_GPU and compute_dtype == torch.float16 and self.config.get('use_4bit'):
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                # Compute capability >= 8 is treated as bfloat16-capable.
                self.b_float_opt = torch.bfloat16

        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the sampling temperature by one increment for the next attempt."""
        new_temp = self.adjust_temp + self.temp_increment
        self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.adjust_temp += self.temp_increment

    def _reset_config(self):
        """Restore the sampling temperature to its starting value."""
        self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.adjust_temp = self.starting_temp

    def _build_model_chain_parser(self):
        """Load the pipeline and wire up the LangChain model, retry parser and chain."""
        # 4-bit loading is requested through quantization_config alone; also passing
        # load_in_4bit is redundant and is rejected by newer transformers releases.
        self.local_model_pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            max_new_tokens=self.config.get('max_new_tokens'),
            top_k=self.config.get('top_k'),
            top_p=self.config.get('top_p'),
            do_sample=self.config.get('do_sample'),
            model_kwargs={"torch_dtype": self.b_float_opt,
                          "quantization_config": self.bnb_config})
        self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
        # Set up the retry parser with the runnable
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
        # Create an llm chain with LLM and prompt
        self.chain = self.prompt | self.local_model  # LCEL

    def call_llm_local_MistralAI(self, prompt_template, json_report):
        """Send one prompt to the model and return the validated JSON reply.

        Args:
            prompt_template: the fully rendered user query (a string).
            json_report: progress reporter exposing ``set_text(text_main=...)``.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` after ``MAX_RETRIES`` failed attempts.
        """
        self.json_report = json_report
        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()

        nt_in = 0
        nt_out = 0

        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                # Dynamically set the temperature for this specific request. It has to be
                # bound to the LLM as a pipeline kwarg: extra keys in the prompt input
                # dict are ignored by PromptTemplate and never reach the model.
                chain = self.prompt | self.local_model.bind(pipeline_kwargs={"temperature": self.adjust_temp})

                # Invoke the chain to generate prompt text
                results = chain.invoke({"query": prompt_template})

                # Use retry_parser to parse the response with retry logic. It needs a
                # PromptValue (it calls .to_string()), not a bare string.
                prompt_value = self.prompt.format_prompt(query=prompt_template)
                output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_value)

                if output is None:
                    self.logger.error(f'Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)

                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                json_report.set_text(text_main=f'Working on WFO and Geolocation')
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)  # Make configurable if needed
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)  # Make configurable if needed

                self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")

                self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()

                json_report.set_text(text_main=f'LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record
            except Exception as e:
                # Any failure (generation, parsing, validation) counts as one attempt.
                self.logger.error(f'{e}')
                self._adjust_config()

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

        self.monitor.stop_monitoring_report_usage()
        json_report.set_text(text_main=f'LLM call failed')

        self._reset_config()
        return None, nt_in, nt_out, None, None
|
211 |
+
|
vouchervision/LLM_local_MistralAI_batch.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json, torch, transformers, gc
|
2 |
+
from transformers import BitsAndBytesConfig
|
3 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4 |
+
from langchain.prompts import PromptTemplate
|
5 |
+
from langchain_core.output_parsers import JsonOutputParser
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
8 |
+
|
9 |
+
from utils_LLM import validate_and_align_JSON_keys_with_template, count_tokens, validate_taxonomy_WFO, validate_coordinates_here, remove_colons_and_double_apostrophes, SystemLoadMonitor
|
10 |
+
|
11 |
+
'''
|
12 |
+
https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
|
13 |
+
'''
|
14 |
+
|
15 |
+
from torch.utils.data import Dataset, DataLoader
|
16 |
+
# Dataset for handling prompts
|
17 |
+
class PromptDataset(Dataset):
    """Expose a sequence of prompts through the torch ``Dataset`` protocol."""

    def __init__(self, prompts):
        # The sequence is stored by reference; it is not copied.
        self.prompts = prompts

    def __getitem__(self, idx):
        """Return the prompt stored at position ``idx``."""
        prompt = self.prompts[idx]
        return prompt

    def __len__(self):
        """Return how many prompts are available."""
        n_prompts = len(self.prompts)
        return n_prompts
|
26 |
+
|
27 |
+
class LocalMistralHandler:
    """Run batches of prompts through a locally loaded, 4-bit quantized Mistral model.

    Each prompt is retried up to ``MAX_RETRIES`` times; after a failed attempt the
    model is reloaded with a higher sampling temperature.
    """

    RETRY_DELAY = 2  # Wait 2 seconds before retrying
    MAX_RETRIES = 5  # Maximum number of retries
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    MAX_GPU_MONITORING_INTERVAL = 2  # seconds

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Download the model config, build the prompt template and load the model.

        Args:
            logger: object exposing ``info`` and ``error``; used for all reporting.
            model_name: repository name under the ``mistralai/`` Hugging Face org.
            JSON_dict_structure: template handed to
                ``validate_and_align_JSON_keys_with_template`` for every reply.
        """
        self.logger = logger
        self.has_GPU = torch.cuda.is_available()
        self.monitor = SystemLoadMonitor(logger)

        self.model_name = model_name
        self.model_id = f"mistralai/{self.model_name}"

        # Fetching config.json also fails early if the repo id is wrong or unreachable.
        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json")

        self.JSON_dict_structure = JSON_dict_structure
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries a JSON dictionary as specified by the user."
        # Written as an explicit literal so that no source indentation leaks into the prompt.
        template = "\n<s>[INST]{}[/INST]</s>\n\n[INST]{}[/INST]\n".format(system_prompt, "{query}")

        # Create a prompt from the template so we can use it with Langchain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])

        # Set up a parser
        self.parser = JsonOutputParser()

        self._set_config()

    def _clear_VRAM(self):
        """Drop every reference to the loaded model so its memory can be reclaimed."""
        # The chain and the retry parser both hold the HuggingFacePipeline; if they
        # are kept alive the old weights stay resident while the new copy loads.
        self.chain = None
        self.retry_parser = None
        self.local_model = None
        self.local_model_pipeline = None
        gc.collect()  # Explicitly invoke garbage collector
        if self.has_GPU:
            # Clear CUDA cache if it's being used
            torch.cuda.empty_cache()

    def _set_config(self):
        """Reset generation/quantization settings to their defaults and reload the model."""
        self._clear_VRAM()
        self.config = {
            'max_new_tokens': 1024,
            'temperature': self.starting_temp,
            'seed': 2023,
            'top_p': 1,
            'top_k': 40,
            'do_sample': True,
            'n_ctx': 4096,

            # Activate 4-bit precision base model loading
            'use_4bit': True,
            # Compute dtype for 4-bit base models
            'bnb_4bit_compute_dtype': "float16",
            # Quantization type (fp4 or nf4)
            'bnb_4bit_quant_type': "nf4",
            # Activate nested quantization for 4-bit base models (double quantization)
            'use_nested_quant': False,
        }

        compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))

        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.config.get('use_4bit'),
            bnb_4bit_quant_type=self.config.get('bnb_4bit_quant_type'),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
        )

        # Always define the dtype, and only query the device when CUDA is present
        # (get_device_capability raises without a CUDA device).
        self.b_float_opt = torch.float16
        if self.has_GPU and compute_dtype == torch.float16 and self.config.get('use_4bit'):
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                # Compute capability >= 8 is treated as bfloat16-capable.
                self.b_float_opt = torch.bfloat16

        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the temperature by one increment and reload the model with it."""
        self.logger.info(f'Incrementing temperature and reloading model')
        self._clear_VRAM()
        self.adjust_temp += self.temp_increment
        self.config['temperature'] = self.adjust_temp
        self._build_model_chain_parser()

    def _build_model_chain_parser(self):
        """Load the pipeline and wire up the LangChain model, retry parser and chain."""
        # 4-bit loading is requested through quantization_config alone; also passing
        # load_in_4bit is redundant and is rejected by newer transformers releases.
        self.local_model_pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            max_new_tokens=self.config.get('max_new_tokens'),
            temperature=self.config.get('temperature'),
            top_k=self.config.get('top_k'),
            top_p=self.config.get('top_p'),
            do_sample=self.config.get('do_sample'),
            model_kwargs={"torch_dtype": self.b_float_opt,
                          "quantization_config": self.bnb_config})
        self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
        # Set up the retry parser with the runnable
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
        # Create an llm chain with LLM and prompt
        self.chain = self.prompt | self.local_model  # LCEL

    def call_llm_local_MistralAI(self, prompts, batch_size=4):
        """Process ``prompts`` in batches and return one result dict per success.

        Args:
            prompts: sequence of rendered prompt strings.
            batch_size: number of prompts handed to ``_process_batch`` at a time.

        Returns:
            A list of dicts with keys ``output``, ``nt_in``, ``nt_out``,
            ``WFO_record`` and ``GEO_record``. Prompts that never produced valid
            JSON are omitted, so the list may be shorter than ``prompts``.
        """
        self.monitor.start_monitoring_usage()

        dataset = PromptDataset(prompts)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        all_results = []
        try:
            for batch_prompts in data_loader:
                batch_results = self._process_batch(batch_prompts)
                all_results.extend(batch_results)
        finally:
            # Stop the monitor even if a batch raises, so it is never left running.
            self.monitor.stop_monitoring_report_usage()

        if self.adjust_temp != self.starting_temp:
            # Failed attempts raised the temperature; restore it for the next call.
            self.adjust_temp = self.starting_temp
            self._set_config()

        return all_results

    def _process_batch(self, batch_prompts):
        """Run every prompt of one batch sequentially; keep only successful results."""
        batch_results = []
        for prompt in batch_prompts:
            output, nt_in, nt_out, WFO_record, GEO_record = self._process_single_prompt(prompt)
            if output is not None:
                batch_results.append({
                    "output": output,
                    "nt_in": nt_in,
                    "nt_out": nt_out,
                    "WFO_record": WFO_record,
                    "GEO_record": GEO_record
                })
        return batch_results

    def _process_single_prompt(self, prompt_template):
        """Generate, parse and validate the reply for a single prompt.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` after ``MAX_RETRIES`` failed attempts.
        """
        nt_in = nt_out = 0
        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                results = self.chain.invoke({"query": prompt_template})
                # The retry parser needs a PromptValue (it calls .to_string()),
                # not a bare string.
                prompt_value = self.prompt.format_prompt(query=prompt_template)
                output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_value)

                if output is not None:
                    nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                    nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)
                    output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)

                if output is None:
                    self.logger.error(f'Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                self.logger.info(f"Formatted JSON:\n{json.dumps(output, indent=4)}")
                return output, nt_in, nt_out, WFO_record, GEO_record
            except Exception as e:
                # A parser/validation exception costs one attempt instead of
                # aborting the whole batch.
                self.logger.error(f'{e}')
                self._adjust_config()

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        return None, nt_in, nt_out, None, None
|
vouchervision/LLM_local_MistralAI_batch_async.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json, torch, transformers, gc
|
2 |
+
from transformers import BitsAndBytesConfig
|
3 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
4 |
+
from langchain.prompts import PromptTemplate
|
5 |
+
from langchain_core.output_parsers import JsonOutputParser
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
8 |
+
import asyncio
|
9 |
+
|
10 |
+
from utils_LLM import validate_and_align_JSON_keys_with_template, count_tokens, validate_taxonomy_WFO, validate_coordinates_here, remove_colons_and_double_apostrophes, SystemLoadMonitor
|
11 |
+
|
12 |
+
'''
|
13 |
+
https://python.langchain.com/docs/integrations/llms/huggingface_pipelines
|
14 |
+
'''
|
15 |
+
|
16 |
+
from torch.utils.data import Dataset, DataLoader
|
17 |
+
# Dataset for handling prompts
|
18 |
+
class PromptDataset(Dataset):
    """Expose a sequence of prompts through the torch ``Dataset`` protocol."""

    def __init__(self, prompts):
        # The sequence is stored by reference; it is not copied.
        self.prompts = prompts

    def __getitem__(self, idx):
        """Return the prompt stored at position ``idx``."""
        prompt = self.prompts[idx]
        return prompt

    def __len__(self):
        """Return how many prompts are available."""
        n_prompts = len(self.prompts)
        return n_prompts
|
27 |
+
|
28 |
+
class LocalMistralHandler:
    """Run batches of prompts through a local 4-bit Mistral model via asyncio.

    Each prompt is retried up to ``MAX_RETRIES`` times; after a failed attempt the
    model is reloaded with a higher sampling temperature.
    """

    RETRY_DELAY = 2  # Wait 2 seconds before retrying
    MAX_RETRIES = 5  # Maximum number of retries
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    MAX_GPU_MONITORING_INTERVAL = 2  # seconds

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Download the model config, build the prompt template and load the model.

        Args:
            logger: object exposing ``info`` and ``error``; used for all reporting.
            model_name: repository name under the ``mistralai/`` Hugging Face org.
            JSON_dict_structure: template handed to
                ``validate_and_align_JSON_keys_with_template`` for every reply.
        """
        self.logger = logger
        self.has_GPU = torch.cuda.is_available()
        self.monitor = SystemLoadMonitor(logger)

        self.model_name = model_name
        self.model_id = f"mistralai/{self.model_name}"

        # Fetching config.json also fails early if the repo id is wrong or unreachable.
        self.model_path = hf_hub_download(repo_id=self.model_id, repo_type="model", filename="config.json")

        self.JSON_dict_structure = JSON_dict_structure
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries a JSON dictionary as specified by the user."
        # Written as an explicit literal so that no source indentation leaks into the prompt.
        template = "\n<s>[INST]{}[/INST]</s>\n\n[INST]{}[/INST]\n".format(system_prompt, "{query}")

        # Create a prompt from the template so we can use it with Langchain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])

        # Set up a parser
        self.parser = JsonOutputParser()

        self._set_config()

    def _clear_VRAM(self):
        """Drop every reference to the loaded model so its memory can be reclaimed."""
        # The chain and the retry parser both hold the HuggingFacePipeline; if they
        # are kept alive the old weights stay resident while the new copy loads.
        self.chain = None
        self.retry_parser = None
        self.local_model = None
        self.local_model_pipeline = None
        gc.collect()  # Explicitly invoke garbage collector
        if self.has_GPU:
            # Clear CUDA cache if it's being used
            torch.cuda.empty_cache()

    def _set_config(self):
        """Reset generation/quantization settings to their defaults and reload the model."""
        self._clear_VRAM()
        self.config = {
            'max_new_tokens': 1024,
            'temperature': self.starting_temp,
            'seed': 2023,
            'top_p': 1,
            'top_k': 40,
            'do_sample': True,
            'n_ctx': 4096,

            # Activate 4-bit precision base model loading
            'use_4bit': True,
            # Compute dtype for 4-bit base models
            'bnb_4bit_compute_dtype': "float16",
            # Quantization type (fp4 or nf4)
            'bnb_4bit_quant_type': "nf4",
            # Activate nested quantization for 4-bit base models (double quantization)
            'use_nested_quant': False,
        }

        compute_dtype = getattr(torch, self.config.get('bnb_4bit_compute_dtype'))

        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.config.get('use_4bit'),
            bnb_4bit_quant_type=self.config.get('bnb_4bit_quant_type'),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=self.config.get('use_nested_quant'),
        )

        # Always define the dtype, and only query the device when CUDA is present
        # (get_device_capability raises without a CUDA device).
        self.b_float_opt = torch.float16
        if self.has_GPU and compute_dtype == torch.float16 and self.config.get('use_4bit'):
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                # Compute capability >= 8 is treated as bfloat16-capable.
                self.b_float_opt = torch.bfloat16

        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the temperature by one increment and reload the model with it."""
        self.logger.info(f'Incrementing temperature and reloading model')
        self._clear_VRAM()
        self.adjust_temp += self.temp_increment
        self.config['temperature'] = self.adjust_temp
        self._build_model_chain_parser()

    def _build_model_chain_parser(self):
        """Load the pipeline and wire up the LangChain model, retry parser and chain."""
        # 4-bit loading is requested through quantization_config alone; also passing
        # load_in_4bit is redundant and is rejected by newer transformers releases.
        self.local_model_pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            max_new_tokens=self.config.get('max_new_tokens'),
            temperature=self.config.get('temperature'),
            top_k=self.config.get('top_k'),
            top_p=self.config.get('top_p'),
            do_sample=self.config.get('do_sample'),
            model_kwargs={"torch_dtype": self.b_float_opt,
                          "quantization_config": self.bnb_config})
        self.local_model = HuggingFacePipeline(pipeline=self.local_model_pipeline)
        # Set up the retry parser with the runnable
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
        # Create an llm chain with LLM and prompt
        self.chain = self.prompt | self.local_model

    def call_llm_local_MistralAI(self, prompts, batch_size=2):
        """Process ``prompts`` in batches and return one result tuple per prompt.

        Args:
            prompts: sequence of rendered prompt strings.
            batch_size: number of prompts gathered per batch.

        Returns:
            A list of ``(output, nt_in, nt_out, WFO_record, GEO_record)`` tuples in
            prompt order; ``output`` is ``None`` for prompts that never produced
            valid JSON.
        """
        dataset = PromptDataset(prompts)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        # Wrap the async call with asyncio.run
        all_results = asyncio.run(self._process_all_batches(data_loader))

        if self.adjust_temp != self.starting_temp:
            # Failed attempts raised the temperature; restore it for the next call.
            self.adjust_temp = self.starting_temp
            self._set_config()

        return all_results

    async def _process_batch(self, batch_prompts):
        """Gather the per-prompt coroutines of one batch.

        NOTE(review): ``_process_single_prompt`` contains no ``await``, so the
        gathered coroutines run one after another rather than concurrently.
        """
        # Create and manage async tasks for each prompt in the batch
        tasks = [self._process_single_prompt(prompt) for prompt in batch_prompts]
        return await asyncio.gather(*tasks)

    async def _process_all_batches(self, data_loader):
        """Process the batches one at a time and concatenate their results."""
        results = []
        for batch_prompts in data_loader:
            batch_results = await self._process_batch(batch_prompts)
            results.extend(batch_results)
        return results

    async def _process_single_prompt(self, prompt_template):
        """Generate, parse and validate the reply for a single prompt.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` after ``MAX_RETRIES`` failed attempts.
        """
        self.monitor.start_monitoring_usage()
        nt_in = nt_out = 0
        ind = 0
        try:
            while ind < self.MAX_RETRIES:
                ind += 1
                try:
                    results = self.chain.invoke({"query": prompt_template})
                    # The retry parser needs a PromptValue (it calls .to_string()),
                    # not a bare string.
                    prompt_value = self.prompt.format_prompt(query=prompt_template)
                    output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_value)

                    if output is not None:
                        nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                        nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)
                        output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)

                    if output is None:
                        self.logger.error(f'Failed to extract JSON from:\n{results}')
                        self._adjust_config()
                        continue

                    output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False)
                    output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False)

                    self.logger.info(f"Formatted JSON:\n{json.dumps(output, indent=4)}")
                    return output, nt_in, nt_out, WFO_record, GEO_record
                except Exception as e:
                    # A parser/validation exception costs one attempt instead of
                    # aborting every gathered prompt.
                    self.logger.error(f'{e}')
                    self._adjust_config()
        finally:
            # Runs on success, exhaustion and unexpected errors alike.
            self.monitor.stop_monitoring_report_usage()

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        return None, nt_in, nt_out, None, None
|
vouchervision/LLM_local_cpu_MistralAI.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, json, gc
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import transformers
|
5 |
+
import random
|
6 |
+
from transformers import BitsAndBytesConfig#, AutoModelForCausalLM, AutoTokenizer
|
7 |
+
from langchain.output_parsers import RetryWithErrorOutputParser
|
8 |
+
from langchain.prompts import PromptTemplate
|
9 |
+
from langchain_core.output_parsers import JsonOutputParser
|
10 |
+
from langchain_experimental.llms import JsonFormer
|
11 |
+
from langchain.tools import tool
|
12 |
+
# from langchain_community.llms import CTransformers
|
13 |
+
# from ctransformers import AutoModelForCausalLM, AutoConfig, Config
|
14 |
+
|
15 |
+
from langchain_community.llms import LlamaCpp
|
16 |
+
# from langchain.callbacks.manager import CallbackManager
|
17 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
18 |
+
from huggingface_hub import hf_hub_download
|
19 |
+
|
20 |
+
|
21 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
22 |
+
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
23 |
+
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
24 |
+
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
25 |
+
|
26 |
+
class LocalCPUMistralHandler:
    """Query a local GGUF build of Mistral-7B-Instruct and return validated JSON.

    The model file is fetched with ``hf_hub_download`` and loaded through
    LangChain's ``LlamaCpp`` wrapper. A request is attempted up to
    ``MAX_RETRIES`` times; after every failed attempt the tracked temperature
    is raised by ``temp_increment``.
    """
    RETRY_DELAY = 2 # Wait 2 seconds before retrying (not referenced inside this class)
    MAX_RETRIES = 5 # Maximum number of retries
    STARTING_TEMP = 0.1
    TOKENIZER_NAME = None
    VENDOR = 'mistral'
    SEED = 2023

    def __init__(self, logger, model_name, JSON_dict_structure):
        """Download the model weights and build the prompt, parser and chain.

        Args:
            logger: Logger used for progress and error messages.
            model_name: Accepted for interface compatibility; it is currently
                overridden by the hard-coded GGUF model name below.
            JSON_dict_structure: Template handed to
                ``validate_and_align_JSON_keys_with_template``.

        Raises:
            ValueError: If the selected model name has no known GGUF file.
        """
        self.logger = logger
        self.monitor = SystemLoadMonitor(logger)
        self.has_GPU = torch.cuda.is_available()
        self.JSON_dict_structure = JSON_dict_structure

        self.model_file = None

        # https://medium.com/@scholarly360/mistral-7b-complete-guide-on-colab-129fa5e9a04d
        self.model_name = "Mistral-7B-Instruct-v0.2-GGUF" #huggingface-cli download TheBloke/Mistral-7B-Instruct-v0.2-GGUF mistral-7b-instruct-v0.2.Q4_K_M.gguf --local-dir /home/brlab/.cache --local-dir-use-symlinks False
        self.model_id = f"TheBloke/{self.model_name}"

        if self.model_name == "Mistral-7B-Instruct-v0.2-GGUF":
            self.model_file = 'mistral-7b-instruct-v0.2.Q4_K_M.gguf'
            self.model_path = hf_hub_download(repo_id=self.model_id,
                                              filename=self.model_file,
                                              repo_type="model")
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError(f"Unsupported GGUF model name: {self.model_name}")

        # self.model_id = f"mistralai/{self.model_name}"
        self.gpu_usage = {'max_load': 0, 'max_memory_usage': 0, 'monitoring': True}

        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        system_prompt = "You are a helpful AI assistant who answers queries with JSON objects and no explanations."
        # NOTE(review): whitespace inside this literal is part of the prompt;
        # confirm the line indentation against the repository copy.
        template = """
<s>[INST]{}[/INST]</s>

[INST]{}[/INST]
""".format(system_prompt, "{query}")

        # Create a prompt from the template so we can use it with Langchain
        self.prompt = PromptTemplate(template=template, input_variables=["query"])

        # Set up a parser
        self.parser = JsonOutputParser()

        self._set_config()

    def _set_config(self):
        """Store the default generation settings and build the model chain."""
        self.config = {'max_new_tokens': 1024,
                       'temperature': self.starting_temp,
                       'seed': self.SEED,
                       'top_p': 1,
                       'top_k': 40,
                       'n_ctx': 4096,
                       'do_sample': True,
                       }
        self._build_model_chain_parser()

    def _adjust_config(self):
        """Raise the tracked temperature by one increment and report the change.

        Requires ``self.json_report``, which is assigned at the start of
        ``call_llm_local_cpu_MistralAI``.
        """
        new_temp = self.adjust_temp + self.temp_increment
        self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.adjust_temp += self.temp_increment
        self.config['temperature'] = self.adjust_temp

    def _reset_config(self):
        """Restore the tracked temperature to its starting value and report it."""
        self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _build_model_chain_parser(self):
        """Create the LlamaCpp model, the retrying JSON parser and the prompt chain."""
        # NOTE(review): only max_new_tokens, top_p and n_ctx are forwarded here;
        # 'temperature', 'seed' and 'top_k' from self.config are not passed to
        # LlamaCpp, so temperature changes may not affect generation - confirm.
        self.local_model = LlamaCpp(
            model_path=self.model_path,
            max_tokens=self.config.get('max_new_tokens'),
            top_p=self.config.get('top_p'),
            # callback_manager=callback_manager,
            # n_gpu_layers=1,
            # n_batch=512,
            n_ctx=self.config.get('n_ctx'),
            stop=["[INST]"],
            verbose=False,
            streaming=False,
        )
        # Set up the retry parser with the runnable
        self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.local_model, max_retries=self.MAX_RETRIES)
        # Create an llm chain with LLM and prompt
        self.chain = self.prompt | self.local_model

    def call_llm_local_cpu_MistralAI(self, prompt_template, json_report):
        """Run the prompt through the local model and return validated JSON.

        Args:
            prompt_template: Prompt text substituted for ``{query}``.
            json_report: Progress reporter exposing ``set_text(text_main=...)``.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record)`` on success, or
            ``(None, nt_in, nt_out, None, None)`` once every attempt has failed.
        """
        self.json_report = json_report
        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()

        nt_in = 0
        nt_out = 0

        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                model_kwargs = {"temperature": self.adjust_temp}

                # Invoke the chain to generate prompt text
                results = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})

                # Use retry_parser to parse the response with retry logic
                output = self.retry_parser.parse_with_prompt(results, prompt_value=prompt_template)

                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(results, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
                    self._adjust_config()
                    continue

                json_report.set_text(text_main='Working on WFO and Geolocation')
                output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) # Make configurable if needed
                output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) # Make configurable if needed

                self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")

                self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()
                json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record

            except Exception as e:
                # Any failure in generation, parsing or validation counts as a
                # failed attempt; log it and retry at a higher temperature.
                self.logger.error(f'{e}')
                self._adjust_config()

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

        self.monitor.stop_monitoring_report_usage()
        self._reset_config()

        json_report.set_text(text_main='LLM call failed')
        return None, nt_in, nt_out, None, None
|
204 |
+
|
205 |
+
|
vouchervision/LM2_logger.py
CHANGED
@@ -1,9 +1,20 @@
|
|
1 |
import logging, os, psutil, torch, platform, cpuinfo, yaml #py-cpuinfo
|
2 |
from vouchervision.general_utils import get_datetime, print_main_warn, print_main_info
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def start_logging(Dirs, cfg):
|
5 |
run_name = cfg['leafmachine']['project']['run_name']
|
6 |
-
path_log = os.path.join(Dirs.path_log, '__'.join(['LM2-log',str(get_datetime()), run_name])+'.log')
|
7 |
|
8 |
# Disable default StreamHandler
|
9 |
logging.getLogger().handlers = []
|
@@ -12,9 +23,9 @@ def start_logging(Dirs, cfg):
|
|
12 |
logger = logging.getLogger('Hardware Components')
|
13 |
logger.setLevel(logging.DEBUG)
|
14 |
|
15 |
-
# create file handler and set level to debug
|
16 |
-
|
17 |
-
|
18 |
|
19 |
# create console handler and set level to debug
|
20 |
ch = logging.StreamHandler()
|
@@ -24,11 +35,11 @@ def start_logging(Dirs, cfg):
|
|
24 |
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
|
25 |
|
26 |
# add formatter to handlers
|
27 |
-
|
28 |
ch.setFormatter(formatter)
|
29 |
|
30 |
# add handlers to logger
|
31 |
-
logger.addHandler(
|
32 |
logger.addHandler(ch)
|
33 |
|
34 |
# Create a logger for the file handler
|
|
|
1 |
import logging, os, psutil, torch, platform, cpuinfo, yaml #py-cpuinfo
|
2 |
from vouchervision.general_utils import get_datetime, print_main_warn, print_main_info
|
3 |
|
4 |
+
class SanitizingFileHandler(logging.FileHandler):
    """File handler that makes each message UTF-8 encodable before writing it.

    Characters that cannot be encoded as UTF-8 (for example lone surrogates)
    are replaced, so one bad character cannot abort the write.
    """

    def __init__(self, filename, mode='a', encoding=None, delay=False):
        """Forward all arguments unchanged to ``logging.FileHandler``."""
        super().__init__(filename, mode, encoding, delay)

    def emit(self, record):
        """Sanitize ``record.msg`` in place, then write the record.

        Non-string messages (dicts, exceptions, ...) are converted with
        ``str()`` first, mirroring what ``LogRecord.getMessage`` does, so
        they are logged instead of being replaced by an error placeholder.
        """
        try:
            message = record.msg
            if not isinstance(message, str):
                message = str(message)
            record.msg = message.encode('utf-8', 'replace').decode('utf-8')
        except Exception as e:
            record.msg = f'[Error encoding text: {e}]'
            # The placeholder has no format fields; drop the args so the
            # later %-formatting of the record cannot fail.
            record.args = None
        super().emit(record)
|
14 |
+
|
15 |
def start_logging(Dirs, cfg):
|
16 |
run_name = cfg['leafmachine']['project']['run_name']
|
17 |
+
path_log = os.path.join(Dirs.path_log, '__'.join(['LM2-log', str(get_datetime()), run_name]) + '.log')
|
18 |
|
19 |
# Disable default StreamHandler
|
20 |
logging.getLogger().handlers = []
|
|
|
23 |
logger = logging.getLogger('Hardware Components')
|
24 |
logger.setLevel(logging.DEBUG)
|
25 |
|
26 |
+
# create custom sanitizing file handler and set level to debug
|
27 |
+
sanitizing_fh = SanitizingFileHandler(path_log, encoding='utf-8')
|
28 |
+
sanitizing_fh.setLevel(logging.DEBUG)
|
29 |
|
30 |
# create console handler and set level to debug
|
31 |
ch = logging.StreamHandler()
|
|
|
35 |
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
|
36 |
|
37 |
# add formatter to handlers
|
38 |
+
sanitizing_fh.setFormatter(formatter)
|
39 |
ch.setFormatter(formatter)
|
40 |
|
41 |
# add handlers to logger
|
42 |
+
logger.addHandler(sanitizing_fh)
|
43 |
logger.addHandler(ch)
|
44 |
|
45 |
# Create a logger for the file handler
|
vouchervision/OCR_google_cloud_vision.py
CHANGED
@@ -1,12 +1,576 @@
|
|
1 |
-
import os, io, sys, inspect
|
2 |
-
from
|
3 |
-
from
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
currentdir = os.path.dirname(os.path.abspath(
|
6 |
inspect.getfile(inspect.currentframe())))
|
7 |
parentdir = os.path.dirname(currentdir)
|
8 |
sys.path.append(parentdir)
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def draw_boxes(image, bounds, color):
|
11 |
if bounds:
|
12 |
draw = ImageDraw.Draw(image)
|
@@ -26,8 +590,8 @@ def draw_boxes(image, bounds, color):
|
|
26 |
)
|
27 |
return image
|
28 |
|
29 |
-
def detect_text(path
|
30 |
-
|
31 |
with io.open(path, 'rb') as image_file:
|
32 |
content = image_file.read()
|
33 |
image = vision.Image(content=content)
|
@@ -60,34 +624,127 @@ def detect_text(path, client):
|
|
60 |
else:
|
61 |
return '', None, None
|
62 |
|
63 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
if do_create_OCR_helper_image:
|
65 |
image = Image.open(path)
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
image = Image.open(path)
|
70 |
-
return image
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
|
|
|
|
83 |
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
|
|
|
|
86 |
|
|
|
|
|
|
|
|
|
87 |
|
|
|
|
|
88 |
|
|
|
|
|
|
|
89 |
|
|
|
|
|
|
|
|
|
|
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
# ''' Google Vision'''
|
93 |
# def detect_text(path):
|
|
|
1 |
+
import os, io, sys, inspect, statistics
|
2 |
+
from statistics import mean
|
3 |
+
# from google.cloud import vision, storage
|
4 |
+
from google.cloud import vision
|
5 |
+
from google.cloud import vision_v1p3beta1 as vision_beta
|
6 |
+
from PIL import Image, ImageDraw, ImageFont
|
7 |
+
import colorsys
|
8 |
+
from tqdm import tqdm
|
9 |
|
10 |
currentdir = os.path.dirname(os.path.abspath(
|
11 |
inspect.getfile(inspect.currentframe())))
|
12 |
parentdir = os.path.dirname(currentdir)
|
13 |
sys.path.append(parentdir)
|
14 |
|
15 |
+
|
16 |
+
'''
|
17 |
+
@misc{li2021trocr,
|
18 |
+
title={TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models},
|
19 |
+
author={Minghao Li and Tengchao Lv and Lei Cui and Yijuan Lu and Dinei Florencio and Cha Zhang and Zhoujun Li and Furu Wei},
|
20 |
+
year={2021},
|
21 |
+
eprint={2109.10282},
|
22 |
+
archivePrefix={arXiv},
|
23 |
+
primaryClass={cs.CL}
|
24 |
+
}
|
25 |
+
'''
|
26 |
+
|
27 |
+
class OCRGoogle:
|
28 |
+
|
29 |
+
BBOX_COLOR = "black"
|
30 |
+
|
31 |
+
def __init__(self, path, cfg, trOCR_model_version, trOCR_model, trOCR_processor, device):
    """Store the image path, configuration and trOCR components.

    Every OCR result attribute (``hand_*``, ``normal_*`` and ``trOCR_*``)
    starts as ``None`` and is filled in later by the detection methods.
    """
    self.path = path
    self.cfg = cfg
    project_cfg = self.cfg['leafmachine']['project']
    self.do_use_trOCR = project_cfg['do_use_trOCR']
    self.OCR_option = project_cfg['OCR_option']

    # Initialize TrOCR components
    self.trOCR_model_version = trOCR_model_version
    self.trOCR_processor = trOCR_processor
    self.trOCR_model = trOCR_model
    self.device = device

    # Result slots shared by the handwriting and standard Google OCR passes.
    google_fields = (
        'cleaned_text', 'organized_text', 'bounds', 'bounds_word',
        'bounds_flat', 'text_to_box_mapping', 'height', 'confidences',
        'characters',
    )
    for prefix in ('hand', 'normal'):
        for field_name in google_fields:
            setattr(self, f'{prefix}_{field_name}', None)

    # Result slots for the trOCR pass.
    trocr_fields = (
        'texts', 'text_to_box_mapping', 'bounds_flat', 'height',
        'confidences', 'characters',
    )
    for field_name in trocr_fields:
        setattr(self, f'trOCR_{field_name}', None)
|
69 |
+
|
70 |
+
def detect_text_with_trOCR_using_google_bboxes(self, do_use_trOCR, logger):
    """Build the labelled OCR text, optionally re-reading each word with trOCR.

    Without trOCR the already-computed Google OCR text is labelled and
    returned. With trOCR every Google word box is cropped from the image,
    transcribed by the trOCR model, and the joined transcription is appended
    to the Google text. ``self.OCR_JSON_to_file`` is filled with the same
    pieces under ``OCR_printed`` / ``OCR_handwritten`` / ``OCR_trOCR``.

    Args:
        do_use_trOCR: When true, run trOCR on the Google word boxes.
        logger: Logger that receives the assembled text.

    Returns:
        The labelled text, or ``None`` when trOCR is off and
        ``self.OCR_option`` is not 'normal', 'hand' or 'both'.

    Raises:
        ValueError: If trOCR is requested with an unsupported ``OCR_option``.
    """
    CONFIDENCES = 0.80  # fixed confidence recorded for every trOCR word
    MAX_NEW_TOKENS = 50

    option = self.OCR_option
    printed_text = f"Google_OCR_Standard:\n{self.normal_organized_text}"
    hand_text = f"Google_OCR_Handwriting:\n{self.hand_organized_text}"

    self.OCR_JSON_to_file = {}

    if not do_use_trOCR:
        if option == 'normal':
            self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
            combined = printed_text
        elif option == 'hand':
            self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
            combined = hand_text
        elif option == 'both':
            # Record both texts, matching what the trOCR branch stores.
            self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
            self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
            combined = f"{printed_text}\n\n{hand_text}"
        else:
            return None
        logger.info(combined)
        return combined

    logger.info(f'Supplementing with trOCR')

    if option not in ('normal', 'hand', 'both'):
        raise ValueError(f"Unsupported OCR_option for trOCR: {option!r}")
    # 'both' reuses the handwriting word boxes.
    available_bounds = self.normal_bounds_word if option == 'normal' else self.hand_bounds_word

    self.trOCR_texts = []
    # convert() returns an independent image, so the file can be closed here.
    with Image.open(self.path) as source_image:
        original_image = source_image.convert("RGB")

    text_to_box_mapping = []
    characters = []
    height = []
    confidences = []
    for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
        vertices = bound["vertices"]

        x_values = [v["x"] for v in vertices]
        y_values = [v["y"] for v in vertices]
        left, right = min(x_values), max(x_values)
        top, bottom = min(y_values), max(y_values)

        # Crop image based on Google's bounding box
        cropped_image = original_image.crop((left, top, right, bottom))
        pixel_values = self.trOCR_processor(cropped_image, return_tensors="pt").pixel_values

        # Move pixel values to the appropriate device
        pixel_values = pixel_values.to(self.device)

        generated_ids = self.trOCR_model.generate(pixel_values, max_new_tokens=MAX_NEW_TOKENS)
        extracted_text = self.trOCR_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        self.trOCR_texts.append(extracted_text)

        # For plotting: per-character width plus 10% of the box height.
        num_symbols = len(extracted_text)
        char_width = (right - left) / num_symbols if num_symbols > 0 else 0
        height.append(int(char_width + ((bottom - top) * 0.1)))

        text_to_box_mapping.append({
            "vertices": vertices,
            "text": extracted_text  # Use the text extracted by trOCR
        })

        characters.append(extracted_text)
        confidences.append(CONFIDENCES)

    median_height = statistics.median(height) if height else 0
    median_heights = [median_height * 1.5] * len(characters)

    self.trOCR_texts = ' '.join(self.trOCR_texts)

    self.trOCR_text_to_box_mapping = text_to_box_mapping
    self.trOCR_bounds_flat = available_bounds
    self.trOCR_height = median_heights
    self.trOCR_confidences = confidences
    self.trOCR_characters = characters

    trocr_text = f"trOCR:\n{self.trOCR_texts}"
    if option == 'normal':
        self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
        combined = f"{printed_text}\n\n{trocr_text}"
    elif option == 'hand':
        self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
        combined = f"{hand_text}\n\n{trocr_text}"
    else:
        self.OCR_JSON_to_file['OCR_printed'] = self.normal_organized_text
        self.OCR_JSON_to_file['OCR_handwritten'] = self.hand_organized_text
        combined = f"{printed_text}\n\n{hand_text}\n\n{trocr_text}"
    self.OCR_JSON_to_file['OCR_trOCR'] = self.trOCR_texts
    logger.info(combined)
    return combined
|
178 |
+
|
179 |
+
@staticmethod
|
180 |
+
def confidence_to_color(confidence):
|
181 |
+
hue = (confidence - 0.5) * 120 / 0.5
|
182 |
+
r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 1)
|
183 |
+
return (int(r*255), int(g*255), int(b*255))
|
184 |
+
|
185 |
+
|
186 |
+
def render_text_on_black_image(self, option):
    """Draw the recognised text for ``option`` onto a black canvas.

    The canvas has the same size as the image at ``self.path``. Each entry of
    ``<option>_characters`` is drawn at the first vertex of its flat bound,
    shifted up by its height. trOCR text uses one fixed colour; other options
    are coloured by confidence.

    Args:
        option: Attribute prefix to render, e.g. 'normal', 'hand' or 'trOCR'.

    Returns:
        The new black RGB image with the text drawn on it.
    """
    # The attributes are preset to None in __init__, so a getattr default
    # alone would never apply; fall back to empty lists explicitly.
    bounds_flat = getattr(self, f'{option}_bounds_flat', None) or []
    heights = getattr(self, f'{option}_height', None) or []
    confidences = getattr(self, f'{option}_confidences', None) or []
    characters = getattr(self, f'{option}_characters', None) or []

    # Only the size is needed, so release the file handle straight away.
    with Image.open(self.path) as original_image:
        width, height = original_image.size
    black_image = Image.new("RGB", (width, height), "black")
    draw = ImageDraw.Draw(black_image)

    base_font = ImageFont.load_default()
    for bound, confidence, char_height, character in zip(bounds_flat, confidences, heights, characters):
        font = base_font.font_variant(size=int(char_height))
        if option == 'trOCR':
            color = (0, 170, 255)
        else:
            color = OCRGoogle.confidence_to_color(confidence)
        anchor_vertex = bound["vertices"][0]
        position = (anchor_vertex["x"], anchor_vertex["y"] - char_height)
        draw.text(position, character, fill=color, font=font)

    return black_image
|
208 |
+
|
209 |
+
|
210 |
+
def merge_images(self, image1, image2):
    """Return a new RGB image with ``image1`` on the left and ``image2`` on the right.

    The canvas is as wide as both images together and as tall as the taller
    one; both images are pasted at the top edge.
    """
    left_width, left_height = image1.size
    right_width, right_height = image2.size
    canvas_size = (left_width + right_width, max([left_height, right_height]))
    canvas = Image.new("RGB", canvas_size)
    canvas.paste(image1, (0, 0))
    canvas.paste(image2, (left_width, 0))
    return canvas
|
217 |
+
|
218 |
+
|
219 |
+
def draw_boxes(self, option):
    """Outline each word and underline each character on ``self.image``.

    Word boxes are drawn as thin polygons in ``OCRGoogle.BBOX_COLOR``. Each
    character gets a thick line along the bottom edge of its box, coloured
    by its confidence.

    Args:
        option: Attribute prefix to draw, e.g. 'normal' or 'hand'.

    Returns:
        ``self.image``, modified in place.
    """
    # The attributes are preset to None in __init__, so a getattr default
    # alone would never apply; fall back to empty lists explicitly.
    bounds = getattr(self, f'{option}_bounds', None) or []
    bounds_word = getattr(self, f'{option}_bounds_word', None) or []
    confidences = getattr(self, f'{option}_confidences', None) or []

    draw = ImageDraw.Draw(self.image)
    width, height = self.image.size
    # Very large images get a proportionally thinner character underline.
    thickness_ratio = 0.0025 if min([width, height]) > 4000 else 0.005
    line_width_thick = int((width + height) / 2 * thickness_ratio)  # Adjust line width for character level
    line_width_thin = 1

    for bound in bounds_word:
        corners = bound["vertices"]
        outline_points = []
        for corner_index in range(4):
            outline_points.append(corners[corner_index]["x"])
            outline_points.append(corners[corner_index]["y"])
        draw.polygon(
            outline_points,
            outline=OCRGoogle.BBOX_COLOR,
            width=line_width_thin
        )

    # Draw a line segment at the bottom of each handwritten character
    for bound, confidence in zip(bounds, confidences):
        color = OCRGoogle.confidence_to_color(confidence)
        # Use the bottom two vertices of the bounding box for the line
        corners = bound["vertices"]
        bottom_left = (corners[3]["x"], corners[3]["y"] + line_width_thick)
        bottom_right = (corners[2]["x"], corners[2]["y"] + line_width_thick)
        draw.line([bottom_left, bottom_right], fill=color, width=line_width_thick)

    return self.image
|
254 |
+
|
255 |
+
|
256 |
+
def detect_text(self):
    """Run Google Vision document text detection and store the results.

    Sends the image at ``self.path`` to ``document_text_detection`` and fills
    the ``normal_*`` attributes: the full text, the text rebuilt
    word-by-word, symbol-level and word-level boxes, evenly spaced "flat"
    symbol positions, a per-symbol height and confidence, and the symbols.

    Raises:
        Exception: If the Vision response carries an error message.
    """
    def extent(vertices):
        # Bounding extent of a vertex list: (min_x, max_x, min_y, max_y).
        xs = [vertex.x for vertex in vertices]
        ys = [vertex.y for vertex in vertices]
        return min(xs), max(xs), min(ys), max(ys)

    client = vision.ImageAnnotatorClient()
    with io.open(self.path, 'rb') as image_file:
        content = image_file.read()
    image = vision.Image(content=content)
    response = client.document_text_detection(image=image)
    texts = response.text_annotations

    if response.error.message:
        raise Exception(
            '{}\nFor more info on error messages, check: '
            'https://cloud.google.com/apis/design/errors'.format(
                response.error.message))

    bounds = []
    bounds_word = []
    bounds_flat = []
    height_flat = []
    confidences = []
    characters = []
    paragraph_texts = []

    # The first annotation is the whole text; the rest carry their own boxes.
    text_to_box_mapping = [
        {
            "vertices": [{"x": vertex.x, "y": vertex.y} for vertex in text.bounding_poly.vertices],
            "text": text.description
        }
        for text in texts[1:]
    ]

    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:

                # First pass: estimate one character height for the paragraph.
                avg_H_list = []
                for word in paragraph.words:
                    min_x, max_x, min_y, max_y = extent(word.bounding_box.vertices)
                    num_symbols = len(word.symbols)
                    box_height = max_y - min_y
                    if num_symbols <= 3:
                        H = int(box_height)
                    else:
                        # Width per symbol plus 10% of the box height.
                        H = int((max_x - min_x) / num_symbols + (box_height * 0.1))
                    avg_H_list.append(H)
                # mean() raises on an empty list, so guard a word-less paragraph.
                avg_H = int(mean(avg_H_list)) if avg_H_list else 0

                # Second pass: collect boxes, positions and text.
                words_in_para = []
                for word in paragraph.words:
                    # Get word-level bounding box
                    bounds_word.append({
                        "vertices": [
                            {"x": vertex.x, "y": vertex.y} for vertex in word.bounding_box.vertices
                        ]
                    })

                    word_x_start, word_x_end, _, Y = extent(word.bounding_box.vertices)
                    num_symbols = len(word.symbols)
                    symbol_width = (word_x_end - word_x_start) / num_symbols if num_symbols > 0 else 0

                    current_x_position = word_x_start

                    characters_ind = []
                    for symbol in word.symbols:
                        bounds.append({
                            "vertices": [
                                {"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
                            ]
                        })

                        # Create flat bounds with adjusted x position
                        bounds_flat.append({
                            "vertices": [
                                {"x": current_x_position, "y": Y},
                                {"x": current_x_position + symbol_width, "y": Y}
                            ]
                        })
                        current_x_position += symbol_width

                        height_flat.append(avg_H)
                        confidences.append(round(symbol.confidence, 4))

                        characters_ind.append(symbol.text)
                        characters.append(symbol.text)

                    words_in_para.append(''.join(characters_ind))
                paragraph_texts.append(' '.join(words_in_para))  # Join words in paragraph

    # Every paragraph is followed by one space, including the last.
    organized_text = ''.join(paragraph_text + ' ' for paragraph_text in paragraph_texts)

    self.normal_cleaned_text = texts[0].description if texts else ''
    self.normal_organized_text = organized_text
    self.normal_bounds = bounds
    self.normal_bounds_word = bounds_word
    self.normal_text_to_box_mapping = text_to_box_mapping
    self.normal_bounds_flat = bounds_flat
    self.normal_height = height_flat
    self.normal_confidences = confidences
    self.normal_characters = characters
|
369 |
+
|
370 |
+
|
371 |
+
def detect_handwritten_ocr(self):
    """Run Google Cloud Vision handwriting OCR on ``self.path`` and cache the results.

    The image bytes are sent to ``document_text_detection`` with the
    ``en-t-i0-handwrit`` language hint. Results are stored on the instance:

    * ``hand_cleaned_text`` - description of the first text annotation, or
      ``''`` when the response has no annotations.
    * ``hand_organized_text`` - symbols joined into words, words joined by a
      space, every paragraph followed by one trailing space.
    * ``hand_bounds`` - one ``{"vertices": [...]}`` dict per symbol.
    * ``hand_bounds_word`` - one ``{"vertices": [...]}`` dict per word.
    * ``hand_bounds_flat`` - per symbol, a two-vertex segment on the word's
      bottom edge; the word's width is divided evenly among its symbols.
    * ``hand_text_to_box_mapping`` - vertices and text of every annotation
      after the first.
    * ``hand_height`` - per symbol, the averaged height estimate of the
      paragraph it belongs to.
    * ``hand_confidences`` - per symbol confidence rounded to 4 decimals.
    * ``hand_characters`` - per symbol text.

    Raises:
        Exception: if the API response carries an error message.
    """
    def box_vertices(box):
        # Plain-dict copy of a bounding polygon's vertices.
        return [{"x": vertex.x, "y": vertex.y} for vertex in box.vertices]

    client = vision_beta.ImageAnnotatorClient()
    with open(self.path, "rb") as image_file:
        content = image_file.read()

    image = vision_beta.Image(content=content)
    image_context = vision_beta.ImageContext(language_hints=["en-t-i0-handwrit"])
    response = client.document_text_detection(image=image, image_context=image_context)

    if response.error.message:
        raise Exception(
            "{}\nFor more info on error messages, check: "
            "https://cloud.google.com/apis/design/errors".format(response.error.message)
        )

    texts = response.text_annotations

    bounds = []
    bounds_word = []
    bounds_flat = []
    height_flat = []
    confidences = []
    characters = []
    paragraph_texts = []

    # The first annotation is skipped here; its description is used as the
    # cleaned text below.
    text_to_box_mapping = [
        {"vertices": box_vertices(text.bounding_poly), "text": text.description}
        for text in texts[1:]
    ]

    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                # Pass 1: estimate one height per word, then average them so
                # that every symbol in the paragraph shares the same height.
                word_heights = []
                for word in paragraph.words:
                    xs = [vertex.x for vertex in word.bounding_box.vertices]
                    ys = [vertex.y for vertex in word.bounding_box.vertices]
                    box_height = max(ys) - min(ys)
                    num_symbols = len(word.symbols)
                    if num_symbols <= 3:
                        # Short words: use the bounding-box height directly.
                        word_height = int(box_height)
                    else:
                        # Longer words: average symbol width plus 10% of the
                        # bounding-box height.
                        word_height = int((max(xs) - min(xs)) / num_symbols + box_height * 0.1)
                    word_heights.append(word_height)

                # A paragraph without words would make mean() fail on an
                # empty sequence; the value is unused in that case, so fall
                # back to 0 instead of crashing.
                avg_H = int(mean(word_heights)) if word_heights else 0

                # Pass 2: collect boxes, confidences and text.
                words_in_para = []
                for word in paragraph.words:
                    bounds_word.append({"vertices": box_vertices(word.bounding_box)})

                    xs = [vertex.x for vertex in word.bounding_box.vertices]
                    Y = max(vertex.y for vertex in word.bounding_box.vertices)
                    word_x_start = min(xs)
                    word_x_end = max(xs)
                    num_symbols = len(word.symbols)
                    symbol_width = (word_x_end - word_x_start) / num_symbols if num_symbols > 0 else 0

                    current_x_position = word_x_start
                    characters_ind = []
                    for symbol in word.symbols:
                        bounds.append({"vertices": box_vertices(symbol.bounding_box)})

                        # Flat bounds: evenly spaced segments on the word's bottom edge.
                        bounds_flat.append({
                            "vertices": [
                                {"x": current_x_position, "y": Y},
                                {"x": current_x_position + symbol_width, "y": Y},
                            ]
                        })
                        current_x_position += symbol_width

                        height_flat.append(avg_H)
                        confidences.append(round(symbol.confidence, 4))

                        characters_ind.append(symbol.text)
                        characters.append(symbol.text)

                    words_in_para.append(''.join(characters_ind))

                paragraph_texts.append(' '.join(words_in_para))

    self.hand_cleaned_text = texts[0].description if texts else ''
    # Each paragraph is followed by a single space, including the last one.
    self.hand_organized_text = ''.join(text + ' ' for text in paragraph_texts)
    self.hand_bounds = bounds
    self.hand_bounds_word = bounds_word
    self.hand_bounds_flat = bounds_flat
    self.hand_text_to_box_mapping = text_to_box_mapping
    self.hand_height = height_flat
    self.hand_confidences = confidences
    self.hand_characters = characters
|
486 |
+
|
487 |
+
|
488 |
+
def process_image(self, do_create_OCR_helper_image, logger):
    """Run the configured OCR passes and optionally build the helper overlay image.

    ``self.OCR_option`` selects the passes: ``'normal'`` runs
    ``detect_text``, ``'hand'`` runs ``detect_handwritten_ocr`` and
    ``'both'`` runs the two in that order. Any other value is replaced by
    ``'both'``. The result of
    ``detect_text_with_trOCR_using_google_bboxes`` is stored in ``self.OCR``.

    Args:
        do_create_OCR_helper_image: when truthy, build the merged box/text
            images and ``self.overlay_image``; otherwise the merged images
            are set to ``None`` and ``self.overlay_image`` is the unmodified
            image opened from ``self.path``.
        logger: forwarded to the trOCR step.
    """
    # Unknown options fall back to running both OCR passes.
    if self.OCR_option not in ['normal', 'hand', 'both']:
        self.OCR_option = 'both'

    if self.OCR_option in ['normal', 'both']:
        self.detect_text()
    if self.OCR_option in ['hand', 'both']:
        self.detect_handwritten_ocr()

    ### Optionally add trOCR to the self.OCR for additional context
    self.OCR = self.detect_text_with_trOCR_using_google_bboxes(self.do_use_trOCR, logger)

    if not do_create_OCR_helper_image:
        self.merged_image_normal = None
        self.merged_image_hand = None
        self.overlay_image = Image.open(self.path)
        return

    self.image = Image.open(self.path)

    if self.OCR_option in ['normal', 'both']:
        boxes_normal = self.draw_boxes('normal')
        text_normal = self.render_text_on_black_image('normal')
        self.merged_image_normal = self.merge_images(boxes_normal, text_normal)

    if self.OCR_option in ['hand', 'both']:
        boxes_hand = self.draw_boxes('hand')
        text_hand = self.render_text_on_black_image('hand')
        self.merged_image_hand = self.merge_images(boxes_hand, text_hand)

    if self.do_use_trOCR:
        text_image_trOCR = self.render_text_on_black_image('trOCR')

    ### Merge final overlay image
    if self.OCR_option == 'normal':
        ### [original, normal bboxes, normal text]
        ocr_panels = self.merged_image_normal
        self.overlay_image = self.merge_images(Image.open(self.path), ocr_panels)
    elif self.OCR_option == 'hand':
        ### [original, hand bboxes, hand text]
        ocr_panels = self.merged_image_hand
        self.overlay_image = self.merge_images(Image.open(self.path), ocr_panels)
    else:
        ### [original, normal bboxes, normal text, hand bboxes, hand text]
        self.overlay_image = self.merge_images(
            Image.open(self.path),
            self.merge_images(self.merged_image_normal, self.merged_image_hand),
        )

    # The trOCR text panel, when enabled, is appended last.
    if self.do_use_trOCR:
        self.overlay_image = self.merge_images(self.overlay_image, text_image_trOCR)
|
535 |
+
|
536 |
+
|
537 |
+
'''
|
538 |
+
BBOX_COLOR = "black" # green cyan
|
539 |
+
|
540 |
+
def render_text_on_black_image(image_path, handwritten_char_bounds_flat, handwritten_char_confidences, handwritten_char_heights, characters):
|
541 |
+
# Load the original image to get its dimensions
|
542 |
+
original_image = Image.open(image_path)
|
543 |
+
width, height = original_image.size
|
544 |
+
|
545 |
+
# Create a black image of the same size
|
546 |
+
black_image = Image.new("RGB", (width, height), "black")
|
547 |
+
draw = ImageDraw.Draw(black_image)
|
548 |
+
|
549 |
+
# Loop through each character
|
550 |
+
for bound, confidence, char_height, character in zip(handwritten_char_bounds_flat, handwritten_char_confidences, handwritten_char_heights, characters):
|
551 |
+
# Determine the font size based on the height of the character
|
552 |
+
font_size = int(char_height)
|
553 |
+
font = ImageFont.load_default().font_variant(size=font_size)
|
554 |
+
|
555 |
+
# Color of the character
|
556 |
+
color = confidence_to_color(confidence)
|
557 |
+
|
558 |
+
# Position of the text (using the bottom-left corner of the bounding box)
|
559 |
+
position = (bound["vertices"][0]["x"], bound["vertices"][0]["y"] - char_height)
|
560 |
+
|
561 |
+
# Draw the character
|
562 |
+
draw.text(position, character, fill=color, font=font)
|
563 |
+
|
564 |
+
return black_image
|
565 |
+
|
566 |
+
def merge_images(image1, image2):
|
567 |
+
# Assuming both images are of the same size
|
568 |
+
width, height = image1.size
|
569 |
+
merged_image = Image.new("RGB", (width * 2, height))
|
570 |
+
merged_image.paste(image1, (0, 0))
|
571 |
+
merged_image.paste(image2, (width, 0))
|
572 |
+
return merged_image
|
573 |
+
|
574 |
def draw_boxes(image, bounds, color):
|
575 |
if bounds:
|
576 |
draw = ImageDraw.Draw(image)
|
|
|
590 |
)
|
591 |
return image
|
592 |
|
593 |
+
def detect_text(path):
|
594 |
+
client = vision.ImageAnnotatorClient()
|
595 |
with io.open(path, 'rb') as image_file:
|
596 |
content = image_file.read()
|
597 |
image = vision.Image(content=content)
|
|
|
624 |
else:
|
625 |
return '', None, None
|
626 |
|
627 |
+
def confidence_to_color(confidence):
|
628 |
+
"""Convert confidence level to a color ranging from red (low confidence) to green (high confidence)."""
|
629 |
+
# Using HSL color space, where Hue varies from red to green
|
630 |
+
hue = (confidence - 0.5) * 120 / 0.5 # Scale confidence to range 0-120 (red to green in HSL)
|
631 |
+
r, g, b = colorsys.hls_to_rgb(hue/360, 0.5, 1) # Convert to RGB
|
632 |
+
return (int(r*255), int(g*255), int(b*255))
|
633 |
+
|
634 |
+
def overlay_boxes_on_image(path, typed_bounds, handwritten_char_bounds, handwritten_char_confidences, do_create_OCR_helper_image):
|
635 |
if do_create_OCR_helper_image:
|
636 |
image = Image.open(path)
|
637 |
+
draw = ImageDraw.Draw(image)
|
638 |
+
width, height = image.size
|
639 |
+
line_width = int((width + height) / 2 * 0.005) # Adjust line width for character level
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
640 |
|
641 |
+
# Draw boxes for typed text
|
642 |
+
for bound in typed_bounds:
|
643 |
+
draw.polygon(
|
644 |
+
[
|
645 |
+
bound["vertices"][0]["x"], bound["vertices"][0]["y"],
|
646 |
+
bound["vertices"][1]["x"], bound["vertices"][1]["y"],
|
647 |
+
bound["vertices"][2]["x"], bound["vertices"][2]["y"],
|
648 |
+
bound["vertices"][3]["x"], bound["vertices"][3]["y"],
|
649 |
+
],
|
650 |
+
outline=BBOX_COLOR,
|
651 |
+
width=1
|
652 |
+
)
|
653 |
|
654 |
+
# Draw a line segment at the bottom of each handwritten character
|
655 |
+
for bound, confidence in zip(handwritten_char_bounds, handwritten_char_confidences):
|
656 |
+
color = confidence_to_color(confidence)
|
657 |
+
# Use the bottom two vertices of the bounding box for the line
|
658 |
+
bottom_left = (bound["vertices"][3]["x"], bound["vertices"][3]["y"] + line_width)
|
659 |
+
bottom_right = (bound["vertices"][2]["x"], bound["vertices"][2]["y"] + line_width)
|
660 |
+
draw.line([bottom_left, bottom_right], fill=color, width=line_width)
|
661 |
|
662 |
+
text_image = render_text_on_black_image(path, handwritten_char_bounds, handwritten_char_confidences)
|
663 |
+
merged_image = merge_images(image, text_image) # Assuming 'overlayed_image' is the image with lines
|
664 |
|
665 |
|
666 |
+
return merged_image
|
667 |
+
else:
|
668 |
+
return Image.open(path)
|
669 |
+
|
670 |
+
def detect_handwritten_ocr(path):
|
671 |
+
"""Detects handwritten characters in a local image and returns their bounding boxes and confidence levels.
|
672 |
|
673 |
+
Args:
|
674 |
+
path: The path to the local file.
|
675 |
|
676 |
+
Returns:
|
677 |
+
A tuple of (text, bounding_boxes, confidences)
|
678 |
+
"""
|
679 |
+
client = vision_beta.ImageAnnotatorClient()
|
680 |
|
681 |
+
with open(path, "rb") as image_file:
|
682 |
+
content = image_file.read()
|
683 |
|
684 |
+
image = vision_beta.Image(content=content)
|
685 |
+
image_context = vision_beta.ImageContext(language_hints=["en-t-i0-handwrit"])
|
686 |
+
response = client.document_text_detection(image=image, image_context=image_context)
|
687 |
|
688 |
+
if response.error.message:
|
689 |
+
raise Exception(
|
690 |
+
"{}\nFor more info on error messages, check: "
|
691 |
+
"https://cloud.google.com/apis/design/errors".format(response.error.message)
|
692 |
+
)
|
693 |
|
694 |
+
bounds = []
|
695 |
+
bounds_flat = []
|
696 |
+
height_flat = []
|
697 |
+
confidences = []
|
698 |
+
character = []
|
699 |
+
for page in response.full_text_annotation.pages:
|
700 |
+
for block in page.blocks:
|
701 |
+
for paragraph in block.paragraphs:
|
702 |
+
for word in paragraph.words:
|
703 |
+
# Get the bottom Y-location (max Y) for the whole word
|
704 |
+
Y = max(vertex.y for vertex in word.bounding_box.vertices)
|
705 |
+
|
706 |
+
# Get the height of the word's bounding box
|
707 |
+
H = Y - min(vertex.y for vertex in word.bounding_box.vertices)
|
708 |
+
|
709 |
+
for symbol in word.symbols:
|
710 |
+
# Collecting bounding box for each symbol
|
711 |
+
bound_dict = {
|
712 |
+
"vertices": [
|
713 |
+
{"x": vertex.x, "y": vertex.y} for vertex in symbol.bounding_box.vertices
|
714 |
+
]
|
715 |
+
}
|
716 |
+
bounds.append(bound_dict)
|
717 |
+
|
718 |
+
# Bounds with same bottom y height
|
719 |
+
bounds_flat_dict = {
|
720 |
+
"vertices": [
|
721 |
+
{"x": vertex.x, "y": Y} for vertex in symbol.bounding_box.vertices
|
722 |
+
]
|
723 |
+
}
|
724 |
+
bounds_flat.append(bounds_flat_dict)
|
725 |
+
|
726 |
+
# Add the word's height
|
727 |
+
height_flat.append(H)
|
728 |
+
|
729 |
+
# Collecting confidence for each symbol
|
730 |
+
symbol_confidence = round(symbol.confidence, 4)
|
731 |
+
confidences.append(symbol_confidence)
|
732 |
+
character.append(symbol.text)
|
733 |
+
|
734 |
+
cleaned_text = response.full_text_annotation.text
|
735 |
+
|
736 |
+
return cleaned_text, bounds, bounds_flat, height_flat, confidences, character
|
737 |
+
|
738 |
+
|
739 |
+
|
740 |
+
def process_image(path, do_create_OCR_helper_image):
|
741 |
+
typed_text, typed_bounds, _ = detect_text(path)
|
742 |
+
handwritten_text, handwritten_bounds, _ = detect_handwritten_ocr(path)
|
743 |
+
|
744 |
+
overlayed_image = overlay_boxes_on_image(path, typed_bounds, handwritten_bounds, do_create_OCR_helper_image)
|
745 |
+
return typed_text, handwritten_text, overlayed_image
|
746 |
+
|
747 |
+
'''
|
748 |
|
749 |
# ''' Google Vision'''
|
750 |
# def detect_text(path):
|
vouchervision/OCR_trOCR.py
ADDED
File without changes
|
vouchervision/VoucherVision_Config_Builder.py
CHANGED
@@ -4,58 +4,103 @@ from vouchervision.general_utils import validate_dir, print_main_fail
|
|
4 |
from vouchervision.vouchervision_main import voucher_vision
|
5 |
from general_utils import get_cfg_from_full_path
|
6 |
|
7 |
-
def build_VV_config():
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
dir_images_local = os.path.join(dir_home,'demo','demo_images')
|
22 |
-
|
23 |
-
# The default output location is the computer's "Downloads" folder
|
24 |
-
# You can set dir_output directly by typing the folder path,
|
25 |
-
# OR you can uncomment the line "dir_output = default_output_folder"
|
26 |
-
# to have VoucherVision save to the Downloads folder by default
|
27 |
-
default_output_folder = get_default_download_folder()
|
28 |
-
dir_output = default_output_folder
|
29 |
-
# dir_output = 'D:/D_Desktop/LM2'
|
30 |
-
|
31 |
-
prefix_removal = '' #'MICH-V-'
|
32 |
-
suffix_removal = ''
|
33 |
-
catalog_numerical_only = False
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
use_LeafMachine2_collage_images = False # Use LeafMachine2 collage images
|
38 |
-
do_create_OCR_helper_image = False
|
39 |
|
40 |
-
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
#############################################
|
46 |
-
#############################################
|
47 |
-
########## DO NOT EDIT BELOW HERE ###########
|
48 |
-
#############################################
|
49 |
-
#############################################
|
50 |
-
return assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
51 |
-
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,
|
52 |
-
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
53 |
-
prompt_version, do_create_OCR_helper_image, use_domain_knowledge=False)
|
54 |
|
55 |
def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
56 |
-
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,
|
57 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
58 |
-
prompt_version, do_create_OCR_helper_image_user,
|
|
|
59 |
|
60 |
|
61 |
# Initialize the base structure
|
@@ -65,7 +110,7 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
|
65 |
|
66 |
# Modular sections to be added to 'leafmachine'
|
67 |
do_section = {
|
68 |
-
'check_for_illegal_filenames':
|
69 |
'check_for_corrupt_images_make_vertical': True,
|
70 |
}
|
71 |
|
@@ -84,7 +129,7 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
|
84 |
'run_name': run_name,
|
85 |
'image_location': 'local',
|
86 |
'batch_size': batch_size,
|
87 |
-
'num_workers':
|
88 |
'dir_images_local': dir_images_local,
|
89 |
'continue_run_from_partial_xlsx': '',
|
90 |
'prefix_removal': prefix_removal,
|
@@ -97,6 +142,8 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
|
97 |
'prompt_version': prompt_version,
|
98 |
'delete_all_temps': False,
|
99 |
'delete_temps_keep_VVE': False,
|
|
|
|
|
100 |
}
|
101 |
|
102 |
modules_section = {
|
@@ -109,7 +156,7 @@ def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
|
109 |
|
110 |
cropped_components_section = {
|
111 |
'do_save_cropped_annotations': True,
|
112 |
-
'save_cropped_annotations':
|
113 |
'save_per_image': False,
|
114 |
'save_per_annotation_class': True,
|
115 |
'binarize_labels': False,
|
@@ -238,7 +285,7 @@ def build_api_tests(api):
|
|
238 |
config_data, dir_home = assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
239 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,
|
240 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
241 |
-
prompt_version,
|
242 |
|
243 |
write_config_file(config_data, os.path.join(dir_home,'demo','demo_configs'),filename=filename)
|
244 |
|
@@ -264,7 +311,6 @@ def build_demo_tests(llm_version):
|
|
264 |
batch_size = 500
|
265 |
do_create_OCR_helper_image = False
|
266 |
|
267 |
-
|
268 |
# ### Option 1: "GPT 4" of ["GPT 4", "GPT 3.5", "Azure GPT 4", "Azure GPT 3.5", "PaLM 2"]
|
269 |
# LLM_version_user = 'Azure GPT 4'
|
270 |
|
@@ -340,7 +386,7 @@ def build_demo_tests(llm_version):
|
|
340 |
return dir_home, path_to_configs, test_results
|
341 |
|
342 |
class TestOptionsGPT:
|
343 |
-
OPT1 = ["GPT 4", "GPT 3.5", "Azure GPT 4", "Azure GPT 3.5"]
|
344 |
OPT2 = [False, True]
|
345 |
OPT3 = ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
|
346 |
|
@@ -473,7 +519,7 @@ def run_demo_tests_Palm(progress_report):
|
|
473 |
|
474 |
if check_API_key(dir_home, api_version) and check_API_key(dir_home, 'google-vision-ocr') :
|
475 |
try:
|
476 |
-
last_JSON_response, total_cost = voucher_vision(cfg_file_path, dir_home, cfg_test=None, progress_report=progress_report, test_ind=int(test_ind))
|
477 |
test_results[cfg] = True
|
478 |
JSON_results[ind] = last_JSON_response
|
479 |
except Exception as e:
|
@@ -518,7 +564,7 @@ def run_api_tests(api):
|
|
518 |
|
519 |
if check_API_key(dir_home, api) and check_API_key(dir_home, 'google-vision-ocr') :
|
520 |
try:
|
521 |
-
last_JSON_response, total_cost = voucher_vision(cfg_file_path, dir_home, None, cfg_test=None, progress_report=None, test_ind=int(test_ind))
|
522 |
test_results[cfg] = True
|
523 |
JSON_results[ind] = last_JSON_response
|
524 |
return True
|
@@ -532,76 +578,50 @@ def run_api_tests(api):
|
|
532 |
print(e)
|
533 |
return False
|
534 |
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
def has_API_key(key_name):
|
541 |
-
# Check if the environment variable by key_name is not None
|
542 |
-
return os.getenv(key_name) is not None
|
543 |
-
|
544 |
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
|
550 |
-
|
551 |
|
552 |
-
|
553 |
|
554 |
-
|
555 |
|
556 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
557 |
|
558 |
-
# if has_key_google_OCR and (has_key_azure_openai or has_key_openai or has_key_palm2):
|
559 |
-
# return True
|
560 |
-
# else:
|
561 |
-
# return False
|
562 |
-
def check_if_usable():
|
563 |
-
has_key_openai = os.getenv('OPENAI_API_KEY') is not None
|
564 |
-
has_key_palm2 = os.getenv('PALM_API_KEY') is not None
|
565 |
-
has_key_google_OCR = os.getenv('GOOGLE_APPLICATION_CREDENTIALS') is not None
|
566 |
-
|
567 |
-
return has_key_google_OCR and (has_key_openai or has_key_palm2)
|
568 |
-
|
569 |
-
# def check_API_key(dir_home, api_version):
|
570 |
-
# dir_home = os.path.dirname(os.path.dirname(__file__))
|
571 |
-
# path_cfg_private = os.path.join(dir_home, 'PRIVATE_DATA.yaml')
|
572 |
-
# cfg_private = get_cfg_from_full_path(path_cfg_private)
|
573 |
-
|
574 |
-
# has_key_openai = has_API_key(cfg_private['openai']['OPENAI_API_KEY'])
|
575 |
-
|
576 |
-
# has_key_azure_openai = has_API_key(cfg_private['openai_azure']['api_version'])
|
577 |
-
|
578 |
-
# has_key_palm2 = has_API_key(cfg_private['google_palm']['google_palm_api'])
|
579 |
-
|
580 |
-
# has_key_google_OCR = has_API_key(cfg_private['google_cloud']['path_json_file'])
|
581 |
-
|
582 |
-
# if api_version == 'palm' and has_key_palm2:
|
583 |
-
# return True
|
584 |
-
# elif api_version in ['gpt','openai'] and has_key_openai:
|
585 |
-
# return True
|
586 |
-
# elif api_version in ['gpt-azure', 'azure_openai'] and has_key_azure_openai:
|
587 |
-
# return True
|
588 |
-
# elif api_version == 'google-vision-ocr' and has_key_google_OCR:
|
589 |
-
# return True
|
590 |
-
# else:
|
591 |
-
# return False
|
592 |
-
def check_API_key(api_version):
|
593 |
-
# The API keys are assumed to be set in the environment variables
|
594 |
-
has_key_openai = os.getenv('OPENAI_API_KEY') is not None
|
595 |
-
has_key_palm2 = os.getenv('PALM') is not None
|
596 |
-
has_key_google_OCR = os.getenv('GOOGLE_APPLICATION_CREDENTIALS') is not None
|
597 |
-
|
598 |
-
# Depending on the api_version, check if the corresponding key is present
|
599 |
if api_version == 'palm' and has_key_palm2:
|
600 |
return True
|
601 |
-
elif api_version in ['gpt',
|
|
|
|
|
602 |
return True
|
603 |
elif api_version == 'google-vision-ocr' and has_key_google_OCR:
|
604 |
return True
|
605 |
else:
|
606 |
return False
|
607 |
-
|
|
|
4 |
from vouchervision.vouchervision_main import voucher_vision
|
5 |
from general_utils import get_cfg_from_full_path
|
6 |
|
7 |
+
def build_VV_config(loaded_cfg=None):
    """Assemble a VoucherVision config from built-in defaults or a loaded config.

    Args:
        loaded_cfg: ``None`` to use the defaults defined below, otherwise a
            nested mapping whose ``['leafmachine']`` section supplies the
            values read in the second branch.

    Returns:
        Whatever ``assemble_config`` returns for the collected settings.
        ``use_domain_knowledge`` is always passed as ``False``.
    """
    dir_home = os.path.dirname(os.path.dirname(__file__))

    if loaded_cfg is None:
        #############################################
        ############ Set common defaults ############
        #############################################
        # Changing the values below will set new
        # default values each time you open the
        # VoucherVision user interface
        #############################################
        run_name = 'test'
        dir_images_local = os.path.join(dir_home,'demo','demo_images')

        # The default output location is the computer's "Downloads" folder.
        # To save somewhere else, assign a folder path to dir_output instead.
        dir_output = get_default_download_folder()

        prefix_removal = '' #'MICH-V-'
        suffix_removal = ''
        catalog_numerical_only = False

        save_cropped_annotations = ['label','barcode']

        do_use_trOCR = False
        OCR_option = 'hand'
        check_for_illegal_filenames = False

        LLM_version_user = 'Azure GPT 4 Turbo 1106-preview'
        prompt_version = 'version_5.yaml'
        use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
        do_create_OCR_helper_image = True

        batch_size = 500
        num_workers = 8

        path_domain_knowledge = os.path.join(dir_home,'domain_knowledge','SLTP_UM_AllAsiaMinimalInRegion.xlsx')
        embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]
        #############################################
        ########## DO NOT EDIT BELOW HERE ###########
        #############################################
    else:
        leafmachine_cfg = loaded_cfg['leafmachine']
        project_cfg = leafmachine_cfg['project']

        run_name = project_cfg['run_name']
        dir_images_local = project_cfg['dir_images_local']
        dir_output = project_cfg['dir_output']

        prefix_removal = project_cfg['prefix_removal']
        suffix_removal = project_cfg['suffix_removal']
        catalog_numerical_only = project_cfg['catalog_numerical_only']

        do_use_trOCR = project_cfg['do_use_trOCR']
        OCR_option = project_cfg['OCR_option']

        LLM_version_user = leafmachine_cfg['LLM_version']
        prompt_version = project_cfg['prompt_version']
        use_LeafMachine2_collage_images = leafmachine_cfg['use_RGB_label_images']
        do_create_OCR_helper_image = leafmachine_cfg['do_create_OCR_helper_image']

        batch_size = project_cfg['batch_size']
        num_workers = project_cfg['num_workers']

        path_domain_knowledge = project_cfg['path_to_domain_knowledge_xlsx']
        embeddings_database_name = os.path.splitext(os.path.basename(path_domain_knowledge))[0]

        save_cropped_annotations = leafmachine_cfg['cropped_components']['save_cropped_annotations']
        check_for_illegal_filenames = leafmachine_cfg['do']['check_for_illegal_filenames']

    # Both branches feed the same assembler with the same argument order.
    return assemble_config(dir_home, run_name, dir_images_local,dir_output,
                           prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
                           path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
                           prompt_version, do_create_OCR_helper_image, do_use_trOCR, OCR_option, save_cropped_annotations,
                           check_for_illegal_filenames, use_domain_knowledge=False)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
def assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
100 |
+
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,num_workers,
|
101 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
102 |
+
prompt_version, do_create_OCR_helper_image_user, do_use_trOCR, OCR_option, save_cropped_annotations,
|
103 |
+
check_for_illegal_filenames, use_domain_knowledge=False):
|
104 |
|
105 |
|
106 |
# Initialize the base structure
|
|
|
110 |
|
111 |
# Modular sections to be added to 'leafmachine'
|
112 |
do_section = {
|
113 |
+
'check_for_illegal_filenames': check_for_illegal_filenames,
|
114 |
'check_for_corrupt_images_make_vertical': True,
|
115 |
}
|
116 |
|
|
|
129 |
'run_name': run_name,
|
130 |
'image_location': 'local',
|
131 |
'batch_size': batch_size,
|
132 |
+
'num_workers': num_workers,
|
133 |
'dir_images_local': dir_images_local,
|
134 |
'continue_run_from_partial_xlsx': '',
|
135 |
'prefix_removal': prefix_removal,
|
|
|
142 |
'prompt_version': prompt_version,
|
143 |
'delete_all_temps': False,
|
144 |
'delete_temps_keep_VVE': False,
|
145 |
+
'do_use_trOCR': do_use_trOCR,
|
146 |
+
'OCR_option': OCR_option,
|
147 |
}
|
148 |
|
149 |
modules_section = {
|
|
|
156 |
|
157 |
cropped_components_section = {
|
158 |
'do_save_cropped_annotations': True,
|
159 |
+
'save_cropped_annotations': save_cropped_annotations,
|
160 |
'save_per_image': False,
|
161 |
'save_per_annotation_class': True,
|
162 |
'binarize_labels': False,
|
|
|
285 |
config_data, dir_home = assemble_config(dir_home, run_name, dir_images_local,dir_output,
|
286 |
prefix_removal,suffix_removal,catalog_numerical_only,LLM_version_user,batch_size,
|
287 |
path_domain_knowledge,embeddings_database_name,use_LeafMachine2_collage_images,
|
288 |
+
prompt_version,do_create_OCR_helper_image)
|
289 |
|
290 |
write_config_file(config_data, os.path.join(dir_home,'demo','demo_configs'),filename=filename)
|
291 |
|
|
|
311 |
batch_size = 500
|
312 |
do_create_OCR_helper_image = False
|
313 |
|
|
|
314 |
# ### Option 1: "GPT 4" of ["GPT 4", "GPT 3.5", "Azure GPT 4", "Azure GPT 3.5", "PaLM 2"]
|
315 |
# LLM_version_user = 'Azure GPT 4'
|
316 |
|
|
|
386 |
return dir_home, path_to_configs, test_results
|
387 |
|
388 |
class TestOptionsGPT:
|
389 |
+
OPT1 = ["gpt-4-1106-preview","GPT 4", "GPT 3.5", "Azure GPT 4", "Azure GPT 3.5"]
|
390 |
OPT2 = [False, True]
|
391 |
OPT3 = ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
|
392 |
|
|
|
519 |
|
520 |
if check_API_key(dir_home, api_version) and check_API_key(dir_home, 'google-vision-ocr') :
|
521 |
try:
|
522 |
+
last_JSON_response, total_cost = voucher_vision(cfg_file_path, dir_home, cfg_test=None, path_custom_prompts=None, progress_report=progress_report, test_ind=int(test_ind))
|
523 |
test_results[cfg] = True
|
524 |
JSON_results[ind] = last_JSON_response
|
525 |
except Exception as e:
|
|
|
564 |
|
565 |
if check_API_key(dir_home, api) and check_API_key(dir_home, 'google-vision-ocr') :
|
566 |
try:
|
567 |
+
last_JSON_response, total_cost = voucher_vision(cfg_file_path, dir_home, None,path_custom_prompts=None , cfg_test=None, progress_report=None, test_ind=int(test_ind))
|
568 |
test_results[cfg] = True
|
569 |
JSON_results[ind] = last_JSON_response
|
570 |
return True
|
|
|
578 |
print(e)
|
579 |
return False
|
580 |
|
581 |
+
def has_API_key(val):
    """Return True when *val* holds a configured (non-empty) setting.

    Both the empty string and ``None`` count as "not configured". A YAML
    entry with no value loads as ``None``, which would otherwise compare
    unequal to ``''`` and be reported as a present key.

    Args:
        val: the raw value read from the private config.

    Returns:
        bool: False for ``''`` and ``None``, True for anything else.
    """
    return val is not None and val != ''
|
|
|
|
|
|
|
|
|
586 |
|
587 |
+
def check_if_usable():
    """Tell whether the private config holds enough credentials to run.

    Loads ``PRIVATE_DATA.yaml`` from the directory two levels above this
    module and inspects four entries with ``has_API_key``.

    Returns:
        bool: True only when the Google OCR entry is populated AND at least
        one of the OpenAI, Azure OpenAI or PaLM 2 entries is populated.
    """
    project_root = os.path.dirname(os.path.dirname(__file__))
    private_cfg = get_cfg_from_full_path(os.path.join(project_root, 'PRIVATE_DATA.yaml'))

    # Every entry is read eagerly, in this order, so a missing section raises
    # before any decision is made.
    openai_ready = has_API_key(private_cfg['openai']['OPENAI_API_KEY'])
    # The Azure check looks at the 'api_version' field of the Azure section.
    azure_ready = has_API_key(private_cfg['openai_azure']['api_version'])
    palm_ready = has_API_key(private_cfg['google_palm']['google_palm_api'])
    ocr_ready = has_API_key(private_cfg['google_cloud']['path_json_file'])

    any_llm_ready = azure_ready or openai_ready or palm_ready
    return bool(ocr_ready and any_llm_ready)
|
604 |
+
|
605 |
+
def check_API_key(dir_home, api_version):
    """Tell whether the credential needed for *api_version* is configured.

    Args:
        dir_home: Accepted for interface compatibility. NOTE: the value passed
            in is overwritten below with the directory two levels above this
            module, so it has no effect on the result.
        api_version: Service selector. Recognised values are ``'palm'``,
            ``'gpt'`` / ``'openai'``, ``'gpt-azure'`` / ``'azure_openai'`` and
            ``'google-vision-ocr'``.

    Returns:
        bool: True when *api_version* is recognised and its entry in
        ``PRIVATE_DATA.yaml`` is populated; False otherwise.
    """
    dir_home = os.path.dirname(os.path.dirname(__file__))
    private_cfg = get_cfg_from_full_path(os.path.join(dir_home, 'PRIVATE_DATA.yaml'))

    # All four entries are read up front (same order as before), so a missing
    # section raises regardless of which service was asked about.
    openai_ready = has_API_key(private_cfg['openai']['OPENAI_API_KEY'])
    azure_ready = has_API_key(private_cfg['openai_azure']['api_version'])
    palm_ready = has_API_key(private_cfg['google_palm']['google_palm_api'])
    ocr_ready = has_API_key(private_cfg['google_cloud']['path_json_file'])

    # The selector groups are mutually exclusive, so each one can answer
    # directly with the matching credential flag.
    if api_version == 'palm':
        return palm_ready
    if api_version in ('gpt', 'openai'):
        return openai_ready
    if api_version in ('gpt-azure', 'azure_openai'):
        return azure_ready
    if api_version == 'google-vision-ocr':
        return ocr_ready
    return False
|
|
vouchervision/VoucherVision_GUI.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
vouchervision/embed_occ.py
CHANGED
@@ -15,7 +15,7 @@ parentdir = os.path.dirname(currentdir)
|
|
15 |
sys.path.append(parentdir)
|
16 |
from vouchervision.general_utils import get_cfg_from_full_path
|
17 |
from prompts import PROMPT_UMICH_skeleton_all_asia
|
18 |
-
from
|
19 |
|
20 |
'''
|
21 |
This generates OpenAI embedding. These are no longer used by VoucherVision.
|
|
|
15 |
sys.path.append(parentdir)
|
16 |
from vouchervision.general_utils import get_cfg_from_full_path
|
17 |
from prompts import PROMPT_UMICH_skeleton_all_asia
|
18 |
+
from vouchervision.LLM_OpenAI import num_tokens_from_string, OCR_to_dict
|
19 |
|
20 |
'''
|
21 |
This generates OpenAI embedding. These are no longer used by VoucherVision.
|
vouchervision/embeddings_db.py
CHANGED
@@ -7,7 +7,7 @@ import chromadb
|
|
7 |
from chromadb.config import Settings
|
8 |
from chromadb.utils import embedding_functions
|
9 |
from InstructorEmbedding import INSTRUCTOR
|
10 |
-
from
|
11 |
'''
|
12 |
If there is a transformers install error:
|
13 |
pip install transformers==4.29.2
|
|
|
7 |
from chromadb.config import Settings
|
8 |
from chromadb.utils import embedding_functions
|
9 |
from InstructorEmbedding import INSTRUCTOR
|
10 |
+
from langchain_community.vectorstores import Chroma
|
11 |
'''
|
12 |
If there is a transformers install error:
|
13 |
pip install transformers==4.29.2
|
vouchervision/general_utils.py
CHANGED
@@ -9,6 +9,8 @@ import concurrent.futures
|
|
9 |
from time import perf_counter
|
10 |
import torch
|
11 |
|
|
|
|
|
12 |
'''
|
13 |
TIFF --> DNG
|
14 |
Install
|
@@ -21,10 +23,6 @@ https://helpx.adobe.com/content/dam/help/en/photoshop/pdf/dng_commandline.pdf
|
|
21 |
|
22 |
# https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal
|
23 |
|
24 |
-
def make_zipfile(source_dir, output_filename):
|
25 |
-
shutil.make_archive(output_filename, 'zip', source_dir)
|
26 |
-
return output_filename + '.zip'
|
27 |
-
|
28 |
def validate_dir(dir):
|
29 |
if not os.path.exists(dir):
|
30 |
os.makedirs(dir, exist_ok=True)
|
@@ -71,42 +69,42 @@ def add_to_expense_report(dir_home, data):
|
|
71 |
# Write the data row
|
72 |
writer.writerow(data)
|
73 |
|
74 |
-
def save_token_info_as_csv(Dirs, LLM_version0, path_api_cost, total_tokens_in, total_tokens_out, n_images):
|
75 |
-
|
76 |
-
|
77 |
-
'GPT 3.5': 'GPT_3_5',
|
78 |
-
'Azure GPT 3.5': 'GPT_3_5',
|
79 |
-
'Azure GPT 4': 'GPT_4',
|
80 |
-
'PaLM 2': 'PALM2'
|
81 |
-
}
|
82 |
-
LLM_version = version_mapping[LLM_version0]
|
83 |
-
# Define the CSV file path
|
84 |
-
csv_file_path = os.path.join(Dirs.path_cost, Dirs.run_name + '.csv')
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
# Open the file in write mode
|
92 |
-
with open(csv_file_path, mode='w', newline='') as file:
|
93 |
-
writer = csv.writer(file)
|
94 |
|
95 |
-
#
|
96 |
-
|
97 |
|
98 |
-
#
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
def summarize_expense_report(path_expense_report):
|
112 |
# Initialize counters and sums
|
@@ -275,7 +273,7 @@ def create_google_ocr_yaml_config(output_file, dir_images_local, dir_output):
|
|
275 |
}
|
276 |
# Generate the YAML string from the data structure
|
277 |
validate_dir(os.path.dirname(output_file))
|
278 |
-
yaml_str = yaml.dump(config)
|
279 |
|
280 |
# Write the YAML string to a file
|
281 |
with open(output_file, 'w') as file:
|
@@ -429,7 +427,7 @@ def save_config_file(cfg, logger, Dirs):
|
|
429 |
|
430 |
def write_yaml(cfg, path_cfg):
|
431 |
with open(path_cfg, 'w') as file:
|
432 |
-
yaml.dump(cfg, file)
|
433 |
|
434 |
def split_into_batches(Project, logger, cfg):
|
435 |
logger.name = 'Creating Batches'
|
|
|
9 |
from time import perf_counter
|
10 |
import torch
|
11 |
|
12 |
+
from vouchervision.model_maps import ModelMaps
|
13 |
+
|
14 |
'''
|
15 |
TIFF --> DNG
|
16 |
Install
|
|
|
23 |
|
24 |
# https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal
|
25 |
|
|
|
|
|
|
|
|
|
26 |
def validate_dir(dir):
|
27 |
if not os.path.exists(dir):
|
28 |
os.makedirs(dir, exist_ok=True)
|
|
|
69 |
# Write the data row
|
70 |
writer.writerow(data)
|
71 |
|
72 |
+
def save_token_info_as_csv(Dirs, LLM_version0, path_api_cost, total_tokens_in, total_tokens_out, n_images, dir_home, logger):
    """Write the token/cost record for one run to CSV and log a summary.

    Args:
        Dirs: Object exposing ``path_cost`` (output folder) and ``run_name``.
        LLM_version0: Model name as shown to the user; translated to a cost
            key through ``ModelMaps.get_version_mapping_cost``.
        path_api_cost: Path handed to ``calculate_cost``. When falsy, nothing
            is written and None is returned.
        total_tokens_in: Number of prompt tokens used in the run.
        total_tokens_out: Number of completion tokens used in the run.
        n_images: Number of images processed in the run.
        dir_home: Passed through to ``add_to_expense_report``.
        logger: Logger that receives the cost summary via ``info``.

    Returns:
        The total cost reported by ``calculate_cost``, or None when
        *path_api_cost* is falsy.
    """
    if not path_api_cost:
        return None  #TODO add config tests to expense_report

    cost_key = ModelMaps.get_version_mapping_cost(LLM_version0)

    # One CSV per run, named after the run, inside the cost folder.
    report_path = os.path.join(Dirs.path_cost, Dirs.run_name + '.csv')

    cost_in, cost_out, total_cost, rate_in, rate_out = calculate_cost(
        cost_key, path_api_cost, total_tokens_in, total_tokens_out)

    header = ['run', 'date', 'api_version', 'total_cost', 'n_images', 'tokens_in',
              'tokens_out', 'rate_in', 'rate_out', 'cost_in', 'cost_out']
    # Column order of the record mirrors the header above.
    data = [Dirs.run_name, get_datetime(), cost_key, total_cost, n_images,
            total_tokens_in, total_tokens_out, rate_in, rate_out, cost_in, cost_out]

    # mode='w' replaces any earlier report for the same run name.
    with open(report_path, mode='w', newline='') as file:
        csv.writer(file).writerows([header, data])

    cost_summary = (f"Cost Summary for {Dirs.run_name}:\n"
                    f" API Cost In: ${rate_in} per 1000 Tokens\n"
                    f" API Cost Out: ${rate_out} per 1000 Tokens\n"
                    f" Tokens In: {total_tokens_in} - Cost: ${cost_in:.4f}\n"
                    f" Tokens Out: {total_tokens_out} - Cost: ${cost_out:.4f}\n"
                    f" Images Processed: {n_images}\n"
                    f" Total Cost: ${total_cost:.4f}")

    # The same record is also handed to the shared expense report.
    add_to_expense_report(dir_home, data)
    logger.info(cost_summary)
    return total_cost
|
108 |
|
109 |
def summarize_expense_report(path_expense_report):
|
110 |
# Initialize counters and sums
|
|
|
273 |
}
|
274 |
# Generate the YAML string from the data structure
|
275 |
validate_dir(os.path.dirname(output_file))
|
276 |
+
yaml_str = yaml.dump(config, sort_keys=False)
|
277 |
|
278 |
# Write the YAML string to a file
|
279 |
with open(output_file, 'w') as file:
|
|
|
427 |
|
428 |
def write_yaml(cfg, path_cfg):
    """Dump *cfg* as YAML into the file at *path_cfg*.

    The file is opened in write mode (existing content is replaced) and
    ``sort_keys=False`` keeps the keys in the order they appear in *cfg*.
    """
    with open(path_cfg, 'w') as stream:
        yaml.dump(cfg, stream, sort_keys=False)
|
431 |
|
432 |
def split_into_batches(Project, logger, cfg):
|
433 |
logger.name = 'Creating Batches'
|
vouchervision/model_maps.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ModelMaps:
    """Lookup tables connecting user-facing model names, cost keys and API model ids.

    Three kinds of identifiers appear here:

    * display names, e.g. ``"GPT 4 Turbo 1106-preview"`` (the ``MODELS_*`` lists);
    * cost keys, e.g. ``"GPT_4_TURBO"`` (values of ``version_mapping_cost`` and
      keys of ``COLORS_EXPENSE_REPORT``);
    * API model ids, e.g. ``"gpt-4-1106-preview"`` (returned by ``get_API_name``).

    All lookups are classmethods; the class is never instantiated here.
    """

    PROMPTS_THAT_NEED_DOMAIN_KNOWLEDGE = ["Version 1", "Version 1 PaLM 2"]

    # Hex colour per cost key.
    COLORS_EXPENSE_REPORT = {
        'GPT_4': '#32CD32',  # Lime Green
        'GPT_3_5': '#008000',  # Green
        'GPT_3_5_INSTRUCT': '#3CB371',  # Medium Sea Green
        'GPT_4_TURBO': '#228B22',  # Forest Green
        'GPT_4_32K': '#006400',  # Dark Green

        'PALM2_TB_1': '#87CEEB',  # Sky Blue
        'PALM2_TB_2': '#1E90FF',  # Dodger Blue
        'PALM2_TU_1': '#0000FF',  # Blue
        'GEMINI_PRO': '#1E00FF',  # Blue with a slight violet tint

        'AZURE_GPT_4': '#800080',  # Purple
        'AZURE_GPT_4_TURBO': '#9370DB',  # Medium Purple
        'AZURE_GPT_4_32K': '#8A2BE2',  # Blue Violet
        'AZURE_GPT_3_5_INSTRUCT': '#9400D3',  # Dark Violet
        'AZURE_GPT_3_5': '#9932CC',  # Dark Orchid

        'MISTRAL_TINY': '#FFA07A',  # Light Salmon
        'MISTRAL_SMALL': '#FF8C00',  # Dark Orange
        'MISTRAL_MEDIUM': '#FF4500',  # Orange Red

        'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01': '#000000',  # Black
        'LOCAL_MISTRAL_7B_INSTRUCT_V02': '#4a4a4a',  # Dark Gray

        'LOCAL_CPU_MISTRAL_7B_INSTRUCT_V02_GGUF': '#bababa',  # Light Gray
    }

    MODELS_OPENAI = ["GPT 4",
                     "GPT 4 32k",
                     "GPT 4 Turbo 1106-preview",
                     "GPT 3.5",
                     "GPT 3.5 Instruct",

                     "Azure GPT 4",
                     "Azure GPT 4 32k",
                     "Azure GPT 4 Turbo 1106-preview",
                     "Azure GPT 3.5",
                     "Azure GPT 3.5 Instruct",]

    MODELS_GOOGLE = ["PaLM 2 text-bison@001",
                     "PaLM 2 text-bison@002",
                     "PaLM 2 text-unicorn@001",
                     "Gemini Pro"]

    MODELS_MISTRAL = ["Mistral Tiny",
                      "Mistral Small",
                      "Mistral Medium",]

    MODELS_LOCAL = ["LOCAL Mixtral 8x7B Instruct v0.1",
                    "LOCAL Mistral 7B Instruct v0.2",
                    "LOCAL CPU Mistral 7B Instruct v0.2 GGUF",]

    MODELS_GUI_DEFAULT = "Azure GPT 3.5 Instruct"  # "GPT 4 Turbo 1106-preview"

    # Display name -> cost key.
    version_mapping_cost = {
        'GPT 4 32k': 'GPT_4_32K',
        'GPT 4': 'GPT_4',
        'GPT 4 Turbo 1106-preview': 'GPT_4_TURBO',
        'GPT 3.5 Instruct': 'GPT_3_5_INSTRUCT',
        'GPT 3.5': 'GPT_3_5',

        'Azure GPT 4 32k': 'AZURE_GPT_4_32K',
        'Azure GPT 4': 'AZURE_GPT_4',
        'Azure GPT 4 Turbo 1106-preview': 'AZURE_GPT_4_TURBO',
        'Azure GPT 3.5 Instruct': 'AZURE_GPT_3_5_INSTRUCT',
        'Azure GPT 3.5': 'AZURE_GPT_3_5',

        'Gemini Pro': 'GEMINI_PRO',
        'PaLM 2 text-unicorn@001': 'PALM2_TU_1',
        'PaLM 2 text-bison@001': 'PALM2_TB_1',
        'PaLM 2 text-bison@002': 'PALM2_TB_2',

        'Mistral Medium': 'MISTRAL_MEDIUM',
        'Mistral Small': 'MISTRAL_SMALL',
        'Mistral Tiny': 'MISTRAL_TINY',

        'LOCAL Mixtral 8x7B Instruct v0.1': 'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01',
        'LOCAL Mistral 7B Instruct v0.2': 'LOCAL_MISTRAL_7B_INSTRUCT_V02',

        'LOCAL CPU Mistral 7B Instruct v0.2 GGUF': 'LOCAL_CPU_MISTRAL_7B_INSTRUCT_V02_GGUF',
    }

    # Display name -> whether the model is an Azure deployment. Kept as a
    # class-level constant so it is built once instead of on every lookup.
    _VERSION_IS_AZURE = {
        "GPT 4 Turbo 1106-preview": False,
        'GPT 4': False,
        'GPT 4 32k': False,
        'GPT 3.5': False,
        'GPT 3.5 Instruct': False,

        'Azure GPT 3.5': True,
        'Azure GPT 3.5 Instruct': True,
        'Azure GPT 4': True,
        'Azure GPT 4 Turbo 1106-preview': True,
        'Azure GPT 4 32k': True,

        'PaLM 2 text-bison@001': False,
        'PaLM 2 text-bison@002': False,
        'PaLM 2 text-unicorn@001': False,
        'Gemini Pro': False,

        'Mistral Tiny': False,
        'Mistral Small': False,
        'Mistral Medium': False,

        'LOCAL Mixtral 8x7B Instruct v0.1': False,
        'LOCAL Mistral 7B Instruct v0.2': False,

        'LOCAL CPU Mistral 7B Instruct v0.2 GGUF': False,
    }

    # Cost key -> model id string returned by get_API_name.
    _API_NAMES = {
        ### OpenAI
        'GPT_3_5': 'gpt-3.5-turbo-1106',
        'GPT_3_5_INSTRUCT': 'gpt-3.5-turbo-instruct',
        'GPT_4': 'gpt-4',
        'GPT_4_32K': 'gpt-4-32k',
        'GPT_4_TURBO': 'gpt-4-1106-preview',

        ### Azure
        'AZURE_GPT_3_5': 'gpt-35-turbo-1106',
        'AZURE_GPT_3_5_INSTRUCT': 'gpt-35-turbo-instruct',
        'AZURE_GPT_4': "gpt-4",
        'AZURE_GPT_4_TURBO': "gpt-4-1106-preview",
        'AZURE_GPT_4_32K': "gpt-4-32k",

        ### Google
        'PALM2_TB_1': "text-bison@001",
        'PALM2_TB_2': "text-bison@002",
        'PALM2_TU_1': "text-unicorn@001",
        'GEMINI_PRO': "gemini-pro",

        ### Mistral
        'MISTRAL_TINY': "mistral-tiny",
        'MISTRAL_SMALL': 'mistral-small',
        'MISTRAL_MEDIUM': 'mistral-medium',

        ### Mistral LOCAL
        'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01': 'Mixtral-8x7B-Instruct-v0.1',
        'LOCAL_MISTRAL_7B_INSTRUCT_V02': 'Mistral-7B-Instruct-v0.2',

        ### Mistral LOCAL CPU
        'LOCAL_CPU_MISTRAL_7B_INSTRUCT_V02_GGUF': 'Mistral-7B-Instruct-v0.2-GGUF',
    }

    @classmethod
    def get_version_has_key(cls, key, has_key_openai, has_key_azure_openai, has_key_palm2, has_key_mistral):
        """Return the credential flag that applies to the model named *key*.

        Args:
            key: Display name of the model.
            has_key_openai: Flag returned for the OpenAI models.
            has_key_azure_openai: Flag returned for the Azure OpenAI models.
            has_key_palm2: Flag returned for the PaLM 2 models and Gemini Pro.
            has_key_mistral: Flag returned for the hosted Mistral models.

        Returns:
            The matching flag; True for the ``LOCAL`` models (no flag is
            consulted for them); None when *key* is not a known display name.
        """
        # Built per call because the values are the caller-supplied flags.
        version_has_key = {
            'GPT 4 Turbo 1106-preview': has_key_openai,
            'GPT 4': has_key_openai,
            'GPT 4 32k': has_key_openai,
            'GPT 3.5': has_key_openai,
            'GPT 3.5 Instruct': has_key_openai,

            'Azure GPT 3.5': has_key_azure_openai,
            'Azure GPT 3.5 Instruct': has_key_azure_openai,
            'Azure GPT 4': has_key_azure_openai,
            'Azure GPT 4 Turbo 1106-preview': has_key_azure_openai,
            'Azure GPT 4 32k': has_key_azure_openai,

            'PaLM 2 text-bison@001': has_key_palm2,
            'PaLM 2 text-bison@002': has_key_palm2,
            'PaLM 2 text-unicorn@001': has_key_palm2,
            'Gemini Pro': has_key_palm2,

            'Mistral Tiny': has_key_mistral,
            'Mistral Small': has_key_mistral,
            'Mistral Medium': has_key_mistral,

            'LOCAL Mixtral 8x7B Instruct v0.1': True,
            'LOCAL Mistral 7B Instruct v0.2': True,

            'LOCAL CPU Mistral 7B Instruct v0.2 GGUF': True,
        }
        return version_has_key.get(key)

    @classmethod
    def get_version_mapping_is_azure(cls, key):
        """Return True/False for a known display name, or None for an unknown one."""
        return cls._VERSION_IS_AZURE.get(key)

    @classmethod
    def get_API_name(cls, key):
        """Translate a cost key (e.g. ``'GPT_4_TURBO'``) into its API model id.

        Raises:
            ValueError: If *key* is not one of the known cost keys.
        """
        try:
            return cls._API_NAMES[key]
        except (KeyError, TypeError):
            # TypeError covers unhashable keys, which the previous chain of
            # equality tests also rejected with ValueError.
            raise ValueError(f"Invalid model name {key}. See model_maps.py") from None

    @classmethod
    def get_models_gui_list(cls):
        """Return a new list of all display names: local, Google, OpenAI/Azure, then Mistral."""
        return cls.MODELS_LOCAL + cls.MODELS_GOOGLE + cls.MODELS_OPENAI + cls.MODELS_MISTRAL

    @classmethod
    def get_version_mapping_cost(cls, key):
        """Return the cost key for display name *key*, or None when it is unknown."""
        return cls.version_mapping_cost.get(key, None)

    @classmethod
    def get_all_mapping_cost(cls):
        """Return the display-name -> cost-key dict itself (not a copy)."""
        return cls.version_mapping_cost
|
vouchervision/prompt_catalog.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
from dataclasses import dataclass
|
|
|
2 |
import yaml, json
|
3 |
|
4 |
-
|
5 |
-
# catalog = PromptCatalog(OCR="Sample OCR text", domain_knowledge_example="Sample domain knowledge", similarity="0.9")
|
6 |
-
|
7 |
@dataclass
|
8 |
class PromptCatalog:
|
9 |
domain_knowledge_example: str = ""
|
@@ -11,565 +9,6 @@ class PromptCatalog:
|
|
11 |
OCR: str = ""
|
12 |
n_fields: int = 0
|
13 |
|
14 |
-
# def PROMPT_UMICH_skeleton_all_asia(self, OCR=None, domain_knowledge_example=None, similarity=None):
|
15 |
-
def prompt_v1_verbose(self, OCR=None, domain_knowledge_example=None, similarity=None):
|
16 |
-
self.OCR = OCR or self.OCR
|
17 |
-
self.domain_knowledge_example = domain_knowledge_example or self.domain_knowledge_example
|
18 |
-
self.similarity = similarity or self.similarity
|
19 |
-
self.n_fields = 22 or self.n_fields
|
20 |
-
|
21 |
-
set_rules = """
|
22 |
-
Please note that your task is to generate a dictionary, following the below rules:
|
23 |
-
1. Refactor the unstructured OCR text into a dictionary based on the reference dictionary structure (ref_dict).
|
24 |
-
2. Each field of OCR corresponds to a column of the ref_dict. You should correctly map the values from OCR to the respective fields in ref_dict.
|
25 |
-
3. If the OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None".
|
26 |
-
4. If there is a field in the ref_dict that does not have a corresponding value in the OCR text, fill it based on your knowledge but don't generate new information.
|
27 |
-
5. Do not use any text from the ref_dict values in the new dict, but you must use the headers from ref_dict.
|
28 |
-
6. Duplicate dictionary fields are not allowed.
|
29 |
-
7. Only return the new dictionary. You should not explain your answer.
|
30 |
-
8. Your output should be a Python dictionary represented as a JSON string.
|
31 |
-
"""
|
32 |
-
|
33 |
-
umich_all_asia_rules = """{
|
34 |
-
"Catalog Number": {
|
35 |
-
"format": "[Catalog Number]",
|
36 |
-
"null_value": "",
|
37 |
-
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits"
|
38 |
-
},
|
39 |
-
"Genus": {
|
40 |
-
"format": "[Genus] or '[Family] indet' if no genus",
|
41 |
-
"null_value": "",
|
42 |
-
"description": "Taxonomic determination to genus, do capitalize genus"
|
43 |
-
},
|
44 |
-
"Species": {
|
45 |
-
"format": "[species] or 'indet' if no species",
|
46 |
-
"null_value": "",
|
47 |
-
"description": "Taxonomic determination to species, do not capitalize species"
|
48 |
-
},
|
49 |
-
"subspecies": {
|
50 |
-
"format": "[subspecies]",
|
51 |
-
"null_value": "",
|
52 |
-
"description": "Taxonomic determination to subspecies (subsp.)"
|
53 |
-
},
|
54 |
-
"variety": {
|
55 |
-
"format": "[variety]",
|
56 |
-
"null_value": "",
|
57 |
-
"description": "Taxonomic determination to variety (var)"
|
58 |
-
},
|
59 |
-
"forma": {
|
60 |
-
"format": "[form]",
|
61 |
-
"null_value": "",
|
62 |
-
"description": "Taxonomic determination to form (f.)"
|
63 |
-
},
|
64 |
-
"Country": {
|
65 |
-
"format": "[Country]",
|
66 |
-
"null_value": "",
|
67 |
-
"description": "Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abbreviation is given"
|
68 |
-
},
|
69 |
-
"State": {
|
70 |
-
"format": "[Adm. Division 1]",
|
71 |
-
"null_value": "",
|
72 |
-
"description": "Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word"
|
73 |
-
},
|
74 |
-
"County": {
|
75 |
-
"format": "[Adm. Division 2]",
|
76 |
-
"null_value": "",
|
77 |
-
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word"
|
78 |
-
},
|
79 |
-
"Locality Name": {
|
80 |
-
"format": "verbatim, if no geographic info: 'no data provided on label of catalog no: [######]', or if illegible: 'locality present but illegible/not translated for catalog no: #######', or if no named locality: 'no named locality for catalog no: #######'",
|
81 |
-
"description": "Description of geographic location or landscape"
|
82 |
-
},
|
83 |
-
"Min Elevation": {
|
84 |
-
"format": "elevation integer",
|
85 |
-
"null_value": "",
|
86 |
-
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given"
|
87 |
-
},
|
88 |
-
"Max Elevation": {
|
89 |
-
"format": "elevation integer",
|
90 |
-
"null_value": "",
|
91 |
-
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise"
|
92 |
-
},
|
93 |
-
"Elevation Units": {
|
94 |
-
"format": "m",
|
95 |
-
"null_value": "",
|
96 |
-
"description": "'m' only if an elevation is present"
|
97 |
-
},
|
98 |
-
"Verbatim Coordinates": {
|
99 |
-
"format": "[Lat, Long | UTM | TRS]",
|
100 |
-
"null_value": "",
|
101 |
-
"description": "Verbatim coordinates as they appear on the label, fix typos to match standardized GPS coordinate format"
|
102 |
-
},
|
103 |
-
"Datum": {
|
104 |
-
"format": "[WGS84, NAD23 etc.]",
|
105 |
-
"null_value": "",
|
106 |
-
"description": "GPS Datum of coordinates on label; empty string "" if GPS coordinates are not in OCR"
|
107 |
-
},
|
108 |
-
"Cultivated": {
|
109 |
-
"format": "yes",
|
110 |
-
"null_value": "",
|
111 |
-
"description": "Indicates if specimen was grown in cultivation"
|
112 |
-
},
|
113 |
-
"Habitat": {
|
114 |
-
"format": "verbatim",
|
115 |
-
"null_value": "",
|
116 |
-
"description": "Description of habitat or location where specimen was collected, ignore descriptions of the plant itself"
|
117 |
-
},
|
118 |
-
"Collectors": {
|
119 |
-
"format": "[Collector]",
|
120 |
-
"null_value": "not present",
|
121 |
-
"description": "Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas"
|
122 |
-
},
|
123 |
-
"Collector Number": {
|
124 |
-
"format": "[Collector No.]",
|
125 |
-
"null_value": "s.n.",
|
126 |
-
"description": "Sequential number assigned to collection, associated with the collector"
|
127 |
-
},
|
128 |
-
"Verbatim Date": {
|
129 |
-
"format": "verbatim",
|
130 |
-
"null_value": "s.d.",
|
131 |
-
"description": "Date of collection exactly as it appears on the label"
|
132 |
-
},
|
133 |
-
"Date": {
|
134 |
-
"format": "[yyyy-mm-dd]",
|
135 |
-
"null_value": "",
|
136 |
-
"description": "Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e., 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day"
|
137 |
-
},
|
138 |
-
"End Date": {
|
139 |
-
"format": "[yyyy-mm-dd]",
|
140 |
-
"null_value": "",
|
141 |
-
"description": "If date range is listed, later date of collection range"
|
142 |
-
}
|
143 |
-
}"""
|
144 |
-
|
145 |
-
structure = """{"Dictionary":
|
146 |
-
{
|
147 |
-
"Catalog Number": [Catalog Number],
|
148 |
-
"Genus": [Genus],
|
149 |
-
"Species": [species],
|
150 |
-
"subspecies": [subspecies],
|
151 |
-
"variety": [variety],
|
152 |
-
"forma": [forma],
|
153 |
-
"Country": [Country],
|
154 |
-
"State": [State],
|
155 |
-
"County": [County],
|
156 |
-
"Locality Name": [Locality Name],
|
157 |
-
"Min Elevation": [Min Elevation],
|
158 |
-
"Max Elevation": [Max Elevation],
|
159 |
-
"Elevation Units": [Elevation Units],
|
160 |
-
"Verbatim Coordinates": [Verbatim Coordinates],
|
161 |
-
"Datum": [Datum],
|
162 |
-
"Cultivated": [Cultivated],
|
163 |
-
"Habitat": [Habitat],
|
164 |
-
"Collectors": [Collectors],
|
165 |
-
"Collector Number": [Collector Number],
|
166 |
-
"Verbatim Date": [Verbatim Date],
|
167 |
-
"Date": [Date],
|
168 |
-
"End Date": [End Date]
|
169 |
-
},
|
170 |
-
"SpeciesName": {"taxonomy": [Genus_species]}}"""
|
171 |
-
|
172 |
-
prompt = f"""I'm providing you with a set of rules, an unstructured OCR text, and a reference dictionary (domain knowledge). Your task is to convert the OCR text into a structured dictionary that matches the structure of the reference dictionary. Please follow the rules strictly.
|
173 |
-
The rules are as follows:
|
174 |
-
{set_rules}
|
175 |
-
The unstructured OCR text is:
|
176 |
-
{self.OCR}
|
177 |
-
The reference dictionary, which provides an example of the output structure and has an embedding distance of {self.similarity} to the OCR, is:
|
178 |
-
{self.domain_knowledge_example}
|
179 |
-
Some dictionary fields have special requirements. These requirements specify the format for each field, and are given below:
|
180 |
-
{umich_all_asia_rules}
|
181 |
-
Please refactor the OCR text into a dictionary, following the rules and the reference structure:
|
182 |
-
{structure}
|
183 |
-
"""
|
184 |
-
|
185 |
-
xlsx_headers = ["Catalog Number","Genus","Species","subspecies","variety","forma","Country","State","County","Locality Name","Min Elevation","Max Elevation","Elevation Units","Verbatim Coordinates","Datum","Cultivated","Habitat","Collectors","Collector Number","Verbatim Date","Date","End Date"]
|
186 |
-
|
187 |
-
|
188 |
-
return prompt, self.n_fields, xlsx_headers
|
189 |
-
|
190 |
-
def prompt_v1_verbose_noDomainKnowledge(self, OCR=None):
|
191 |
-
self.OCR = OCR or self.OCR
|
192 |
-
self.n_fields = 22 or self.n_fields
|
193 |
-
|
194 |
-
set_rules = """
|
195 |
-
Please note that your task is to generate a dictionary, following the below rules:
|
196 |
-
1. Refactor the unstructured OCR text into a dictionary based on the reference dictionary structure (ref_dict).
|
197 |
-
2. Each field of OCR corresponds to a column of the ref_dict. You should correctly map the values from OCR to the respective fields in ref_dict.
|
198 |
-
3. If the OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None".
|
199 |
-
4. If there is a field in the ref_dict that does not have a corresponding value in the OCR text, fill it based on your knowledge but don't generate new information.
|
200 |
-
5. Do not use any text from the ref_dict values in the new dict, but you must use the headers from ref_dict.
|
201 |
-
6. Duplicate dictionary fields are not allowed.
|
202 |
-
7. Only return the new dictionary. You should not explain your answer.
|
203 |
-
8. Your output should be a Python dictionary represented as a JSON string.
|
204 |
-
"""
|
205 |
-
|
206 |
-
umich_all_asia_rules = """{
|
207 |
-
"Catalog Number": {
|
208 |
-
"format": "[Catalog Number]",
|
209 |
-
"null_value": "",
|
210 |
-
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits"
|
211 |
-
},
|
212 |
-
"Genus": {
|
213 |
-
"format": "[Genus] or '[Family] indet' if no genus",
|
214 |
-
"null_value": "",
|
215 |
-
"description": "Taxonomic determination to genus, do capitalize genus"
|
216 |
-
},
|
217 |
-
"Species": {
|
218 |
-
"format": "[species] or 'indet' if no species",
|
219 |
-
"null_value": "",
|
220 |
-
"description": "Taxonomic determination to species, do not capitalize species"
|
221 |
-
},
|
222 |
-
"subspecies": {
|
223 |
-
"format": "[subspecies]",
|
224 |
-
"null_value": "",
|
225 |
-
"description": "Taxonomic determination to subspecies (subsp.)"
|
226 |
-
},
|
227 |
-
"variety": {
|
228 |
-
"format": "[variety]",
|
229 |
-
"null_value": "",
|
230 |
-
"description": "Taxonomic determination to variety (var)"
|
231 |
-
},
|
232 |
-
"forma": {
|
233 |
-
"format": "[form]",
|
234 |
-
"null_value": "",
|
235 |
-
"description": "Taxonomic determination to form (f.)"
|
236 |
-
},
|
237 |
-
"Country": {
|
238 |
-
"format": "[Country]",
|
239 |
-
"null_value": "",
|
240 |
-
"description": "Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abbreviation is given"
|
241 |
-
},
|
242 |
-
"State": {
|
243 |
-
"format": "[Adm. Division 1]",
|
244 |
-
"null_value": "",
|
245 |
-
"description": "Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word"
|
246 |
-
},
|
247 |
-
"County": {
|
248 |
-
"format": "[Adm. Division 2]",
|
249 |
-
"null_value": "",
|
250 |
-
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word"
|
251 |
-
},
|
252 |
-
"Locality Name": {
|
253 |
-
"format": "verbatim, if no geographic info: 'no data provided on label of catalog no: [######]', or if illegible: 'locality present but illegible/not translated for catalog no: #######', or if no named locality: 'no named locality for catalog no: #######'",
|
254 |
-
"description": "Description of geographic location or landscape"
|
255 |
-
},
|
256 |
-
"Min Elevation": {
|
257 |
-
"format": "elevation integer",
|
258 |
-
"null_value": "",
|
259 |
-
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given"
|
260 |
-
},
|
261 |
-
"Max Elevation": {
|
262 |
-
"format": "elevation integer",
|
263 |
-
"null_value": "",
|
264 |
-
"description": "Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise"
|
265 |
-
},
|
266 |
-
"Elevation Units": {
|
267 |
-
"format": "m",
|
268 |
-
"null_value": "",
|
269 |
-
"description": "'m' only if an elevation is present"
|
270 |
-
},
|
271 |
-
"Verbatim Coordinates": {
|
272 |
-
"format": "[Lat, Long | UTM | TRS]",
|
273 |
-
"null_value": "",
|
274 |
-
"description": "Verbatim coordinates as they appear on the label, fix typos to match standardized GPS coordinate format"
|
275 |
-
},
|
276 |
-
"Datum": {
|
277 |
-
"format": "[WGS84, NAD23 etc.]",
|
278 |
-
"null_value": "",
|
279 |
-
"description": "GPS Datum of coordinates on label; empty string "" if GPS coordinates are not in OCR"
|
280 |
-
},
|
281 |
-
"Cultivated": {
|
282 |
-
"format": "yes",
|
283 |
-
"null_value": "",
|
284 |
-
"description": "Indicates if specimen was grown in cultivation"
|
285 |
-
},
|
286 |
-
"Habitat": {
|
287 |
-
"format": "verbatim",
|
288 |
-
"null_value": "",
|
289 |
-
"description": "Description of habitat or location where specimen was collected, ignore descriptions of the plant itself"
|
290 |
-
},
|
291 |
-
"Collectors": {
|
292 |
-
"format": "[Collector]",
|
293 |
-
"null_value": "not present",
|
294 |
-
"description": "Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas"
|
295 |
-
},
|
296 |
-
"Collector Number": {
|
297 |
-
"format": "[Collector No.]",
|
298 |
-
"null_value": "s.n.",
|
299 |
-
"description": "Sequential number assigned to collection, associated with the collector"
|
300 |
-
},
|
301 |
-
"Verbatim Date": {
|
302 |
-
"format": "verbatim",
|
303 |
-
"null_value": "s.d.",
|
304 |
-
"description": "Date of collection exactly as it appears on the label"
|
305 |
-
},
|
306 |
-
"Date": {
|
307 |
-
"format": "[yyyy-mm-dd]",
|
308 |
-
"null_value": "",
|
309 |
-
"description": "Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e., 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day"
|
310 |
-
},
|
311 |
-
"End Date": {
|
312 |
-
"format": "[yyyy-mm-dd]",
|
313 |
-
"null_value": "",
|
314 |
-
"description": "If date range is listed, later date of collection range"
|
315 |
-
}
|
316 |
-
}"""
|
317 |
-
|
318 |
-
structure = """{"Dictionary":
|
319 |
-
{
|
320 |
-
"Catalog Number": [Catalog Number],
|
321 |
-
"Genus": [Genus],
|
322 |
-
"Species": [species],
|
323 |
-
"subspecies": [subspecies],
|
324 |
-
"variety": [variety],
|
325 |
-
"forma": [forma],
|
326 |
-
"Country": [Country],
|
327 |
-
"State": [State],
|
328 |
-
"County": [County],
|
329 |
-
"Locality Name": [Locality Name],
|
330 |
-
"Min Elevation": [Min Elevation],
|
331 |
-
"Max Elevation": [Max Elevation],
|
332 |
-
"Elevation Units": [Elevation Units],
|
333 |
-
"Verbatim Coordinates": [Verbatim Coordinates],
|
334 |
-
"Datum": [Datum],
|
335 |
-
"Cultivated": [Cultivated],
|
336 |
-
"Habitat": [Habitat],
|
337 |
-
"Collectors": [Collectors],
|
338 |
-
"Collector Number": [Collector Number],
|
339 |
-
"Verbatim Date": [Verbatim Date],
|
340 |
-
"Date": [Date],
|
341 |
-
"End Date": [End Date]
|
342 |
-
},
|
343 |
-
"SpeciesName": {"taxonomy": [Genus_species]}}"""
|
344 |
-
|
345 |
-
prompt = f"""I'm providing you with a set of rules, an unstructured OCR text, and a reference dictionary (domain knowledge). Your task is to convert the OCR text into a structured dictionary that matches the structure of the reference dictionary. Please follow the rules strictly.
|
346 |
-
The rules are as follows:
|
347 |
-
{set_rules}
|
348 |
-
The unstructured OCR text is:
|
349 |
-
{self.OCR}
|
350 |
-
Some dictionary fields have special requirements. These requirements specify the format for each field, and are given below:
|
351 |
-
{umich_all_asia_rules}
|
352 |
-
Please refactor the OCR text into a dictionary, following the rules and the reference structure:
|
353 |
-
{structure}
|
354 |
-
"""
|
355 |
-
|
356 |
-
xlsx_headers = ["Catalog Number","Genus","Species","subspecies","variety","forma","Country","State","County","Locality Name","Min Elevation","Max Elevation","Elevation Units","Verbatim Coordinates","Datum","Cultivated","Habitat","Collectors","Collector Number","Verbatim Date","Date","End Date"]
|
357 |
-
|
358 |
-
return prompt, self.n_fields, xlsx_headers
|
359 |
-
|
360 |
-
def prompt_v2_json_rules(self, OCR=None):
    """Build the version-2 prompt that asks an LLM to turn OCR text into JSON.

    The prompt is assembled from four fixed text sections (general rules,
    a description of the allowed field formats, a per-key JSON template with
    format/null_value/description, and the empty JSON structure to fill in)
    plus the OCR text itself.

    Args:
        OCR: Unstructured OCR text. When falsy, the value already stored on
            ``self.OCR`` is used instead.

    Returns:
        A tuple ``(prompt, n_fields, xlsx_headers)`` where ``prompt`` is the
        full prompt string, ``n_fields`` is the number of output columns and
        ``xlsx_headers`` is the ordered list of column names.

    Side effects:
        Updates ``self.OCR`` and ``self.n_fields``.
    """
    self.OCR = OCR or self.OCR

    xlsx_headers = ["catalog_number","genus","species","subspecies","variety","forma","country","state","county","locality_name","min_elevation","max_elevation","elevation_units","verbatim_coordinates","decimal_coordinates","datum","cultivated","habitat","plant_description","collectors","collector_number","determined_by","multiple_names","verbatim_date","date","end_date"]

    # The original `26 or self.n_fields` always evaluated to 26; deriving the
    # count from the header list keeps the two values from drifting apart.
    self.n_fields = len(xlsx_headers)

    # Rules are numbered 1-10; the original list used "7." twice.
    set_rules = """
1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
2. You should map the unstructured OCR text to the appropriate JSON key and then populate the field based on its rules.
3. Some JSON key fields are permitted to remain empty if the corresponding information is not found in the unstructured OCR text.
4. Ignore any information in the OCR text that doesn't fit into the defined JSON structure.
5. Duplicate dictionary fields are not allowed.
6. Ensure that all JSON keys are in lowercase.
7. Ensure that new JSON field values follow sentence case capitalization.
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
9. Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
10. Only return a JSON dictionary represented as a string. You should not explain your answer.
"""

    # "if the information ... is not present" (was "of"), which restores the
    # conditional meaning of the null-value instruction.
    dictionary_field_format_descriptions = """
The next section of instructions outlines how to format the JSON dictionary. The keys are the same as those of the final formatted JSON object.
For each key there is a format requirement that specifies how to transcribe the information for that key.
The possible formatting options are:
1. "verbatim transcription" - field is populated with verbatim text from the unformatted OCR.
2. "spell check transcription" - field is populated with spelling corrected text from the unformatted OCR.
3. "boolean yes no" - field is populated with only yes or no.
4. "integer" - field is populated with only an integer.
5. "[list]" - field is populated from one of the values in the list.
6. "yyyy-mm-dd" - field is populated with a date in the format year-month-day.
The desired null value is also given. Populate the field with the null value if the information for that key is not present in the unformatted OCR text.
"""

    # Per-key instructions. The trailing comma that followed the "end_date"
    # entry was removed so the template no longer contradicts rule 9.
    json_template_rules = """
{"Dictionary":{
"catalog_number": {
"format": "verbatim transcription",
"null_value": "",
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits."
},
"genus": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'."
},
"species": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to species, do not capitalize species."
},
"subspecies": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to subspecies (subsp.)."
},
"variety": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to variety (var)."
},
"forma": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to form (f.)."
},
"country": {
"format": "spell check transcription",
"null_value": "",
"description": "Country that corresponds to the current geographic location of collection. Capitalize first letter of each word. If abbreviation is given populate field with the full spelling of the country's name."
},
"state": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 1 that corresponds to the current geographic location of collection. Capitalize first letter of each word. Administrative division 1 is equivalent to a U.S. State."
},
"county": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word. Administrative division 2 is equivalent to a U.S. county, parish, borough."
},
"locality_name": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or site of the specimen."
},
"min_elevation": {
"format": "integer",
"null_value": "",
"description": "Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"max_elevation": {
"format": "integer",
"null_value": "",
"description": "Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"elevation_units": {
"format": "spell check transcription",
"null_value": "",
"description": "Elevation units must be meters. If min_elevation field is populated, then elevation_units: 'm'. Otherwise elevation_units: ''."
},
"verbatim_coordinates": {
"format": "verbatim transcription",
"null_value": "",
"description": "Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types are one of [Lat, Long, UTM, TRS]."
},
"decimal_coordinates": {
"format": "spell check transcription",
"null_value": "",
"description": "Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format."
},
"datum": {
"format": "[WGS84, WGS72, WGS66, WGS60, NAD83, NAD27, OSGB36, ETRS89, ED50, GDA94, JGD2011, Tokyo97, KGD2002, TWD67, TWD97, BJS54, XAS80, GCJ-02, BD-09, PZ-90.11, GTRF, CGCS2000, ITRF88, ITRF89, ITRF90, ITRF91, ITRF92, ITRF93, ITRF94, ITRF96, ITRF97, ITRF2000, ITRF2005, ITRF2008, ITRF2014, Hong Kong Principal Datum, SAD69]",
"null_value": "",
"description": "Datum of location coordinates. Possible values are include in the format list. Leave field blank if unclear."
},
"cultivated": {
"format": "boolean yes no",
"null_value": "",
"description": "Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant."
},
"habitat": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of a plant's habitat or the location where the specimen was collected. Ignore descriptions of the plant itself."
},
"plant_description": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics."
},
"collectors": {
"format": "verbatim transcription",
"null_value": "not present",
"description": "Full name(s) of the individual(s) responsible for collecting the specimen. When multiple collectors are involved, their names should be separated by commas."
},
"collector_number": {
"format": "verbatim transcription",
"null_value": "s.n.",
"description": "Unique identifier or number that denotes the specific collecting event and associated with the collector."
},
"determined_by": {
"format": "verbatim transcription",
"null_value": "",
"description": "Full name of the individual responsible for determining the taxanomic name of the specimen. Sometimes the name will be near to the characters 'det' to denote determination. This name may be isolated from other names in the unformatted OCR text."
},
"multiple_names": {
"format": "boolean yes no",
"null_value": "",
"description": "Indicate whether multiple people or collector names are present in the unformatted OCR text. If you see more than one person's name the value is 'yes'; otherwise the value is 'no'."
},
"verbatim_date": {
"format": "verbatim transcription",
"null_value": "s.d.",
"description": "Date of collection exactly as it appears on the label. Do not change the format or correct typos."
},
"date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "Date the specimen was collected formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire date is unknown, 'YYYY-00-00' if only the year is known, and 'YYYY-MM-00' if year and month are known but day is not."
},
"end_date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "If a date range is provided, this represents the later or ending date of the collection period, formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire end date is unknown, 'YYYY-00-00' if only the year of the end date is known, and 'YYYY-MM-00' if year and month of the end date are known but the day is not."
}
},
"SpeciesName": {
"taxonomy": [Genus_species]}
}"""

    # Empty skeleton the model is asked to populate.
    structure = """{"Dictionary":
{
"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma": "",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date":"" ,
"date": "",
"end_date": ""
},
"SpeciesName": {"taxonomy": ""}}"""

    prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
The rules are:
{set_rules}
The unstructured OCR text is:
{self.OCR}
{dictionary_field_format_descriptions}
This is the JSON template that includes instructions for each key:
{json_template_rules}
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
{structure}
"""

    return prompt, self.n_fields, xlsx_headers
|
573 |
|
574 |
#############################################################################################
|
575 |
#############################################################################################
|
@@ -578,7 +17,7 @@ class PromptCatalog:
|
|
578 |
# These are for dynamically creating your own prompts with n-columns
|
579 |
|
580 |
|
581 |
-
def
|
582 |
self.OCR = OCR
|
583 |
|
584 |
self.rules_config_path = rules_config_path
|
@@ -588,22 +27,26 @@ class PromptCatalog:
|
|
588 |
self.json_formatting_instructions = self.rules_config['json_formatting_instructions']
|
589 |
|
590 |
self.rules_list = self.rules_config['rules']
|
591 |
-
self.n_fields = len(self.
|
592 |
|
593 |
# Set the rules for processing OCR into JSON format
|
594 |
self.rules = self.create_rules(is_palm)
|
595 |
|
596 |
-
self.structure = self.create_structure(is_palm)
|
597 |
|
|
|
|
|
|
|
|
|
598 |
if is_palm:
|
599 |
prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
|
600 |
The rules are:
|
601 |
{self.instructions}
|
602 |
-
The unstructured OCR text is:
|
603 |
-
{self.OCR}
|
604 |
{self.json_formatting_instructions}
|
605 |
This is the JSON template that includes instructions for each key:
|
606 |
{self.rules}
|
|
|
|
|
607 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
608 |
{self.structure}
|
609 |
{self.structure}
|
@@ -613,17 +56,18 @@ class PromptCatalog:
|
|
613 |
prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
|
614 |
The rules are:
|
615 |
{self.instructions}
|
616 |
-
The unstructured OCR text is:
|
617 |
-
{self.OCR}
|
618 |
{self.json_formatting_instructions}
|
619 |
This is the JSON template that includes instructions for each key:
|
620 |
{self.rules}
|
|
|
|
|
621 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
622 |
{self.structure}
|
623 |
"""
|
624 |
-
xlsx_headers = self.generate_xlsx_headers(is_palm)
|
625 |
|
626 |
-
return prompt, self.n_fields, xlsx_headers
|
|
|
627 |
|
628 |
def load_rules_config(self):
|
629 |
with open(self.rules_config_path, 'r') as stream:
|
@@ -634,64 +78,31 @@ class PromptCatalog:
|
|
634 |
return None
|
635 |
|
636 |
def create_rules(self, is_palm=False):
|
637 |
-
|
638 |
-
# Start with a structure for the "Dictionary" section where each key contains its rules
|
639 |
-
dictionary_structure = {
|
640 |
-
key: {
|
641 |
-
'description': value['description'],
|
642 |
-
'format': value['format'],
|
643 |
-
'null_value': value.get('null_value', '')
|
644 |
-
} for key, value in self.rules_list['Dictionary'].items()
|
645 |
-
}
|
646 |
-
|
647 |
-
# Convert the structure to a JSON string without indentation
|
648 |
-
structure_json_str = json.dumps(dictionary_structure, sort_keys=False)
|
649 |
-
return structure_json_str
|
650 |
-
|
651 |
-
else:
|
652 |
-
# Start with a structure for the "Dictionary" section where each key contains its rules
|
653 |
-
dictionary_structure = {
|
654 |
-
key: {
|
655 |
-
'description': value['description'],
|
656 |
-
'format': value['format'],
|
657 |
-
'null_value': value.get('null_value', '')
|
658 |
-
} for key, value in self.rules_list['Dictionary'].items()
|
659 |
-
}
|
660 |
-
|
661 |
-
# Combine both sections into the overall structure
|
662 |
-
full_structure = {
|
663 |
-
"Dictionary": dictionary_structure,
|
664 |
-
"SpeciesName": self.rules_list['SpeciesName']
|
665 |
-
}
|
666 |
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
|
671 |
def create_structure(self, is_palm=False):
|
672 |
-
|
673 |
-
|
674 |
-
dictionary_structure = {key: "" for key in self.rules_list['Dictionary'].keys()}
|
675 |
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
}
|
691 |
|
692 |
-
# Convert the structure to a JSON string with indentation for readability
|
693 |
-
structure_json_str = json.dumps(full_structure, sort_keys=False)
|
694 |
-
return structure_json_str
|
695 |
|
696 |
def generate_xlsx_headers(self, is_palm):
|
697 |
# Extract headers from the 'Dictionary' keys in the JSON template rules
|
@@ -699,617 +110,6 @@ class PromptCatalog:
|
|
699 |
xlsx_headers = list(self.rules_list.keys())
|
700 |
return xlsx_headers
|
701 |
else:
|
702 |
-
xlsx_headers = list(self.rules_list
|
703 |
return xlsx_headers
|
704 |
|
705 |
-
def prompt_v2_custom_redo(self, incorrect_json, is_palm):
    """Build a prompt asking the model to repair JSON that failed json.loads().

    Reloads the rules configuration and regenerates the target structure,
    then wraps ``incorrect_json`` in repair instructions. When ``is_palm`` is
    truthy the "output JSON structure" line is emitted three times; otherwise
    it is emitted once. The rest of the text is the same in both cases.

    Args:
        incorrect_json: The malformed JSON text returned by a previous call.
        is_palm: Selects the PaLM variant of the prompt and is forwarded to
            ``self.create_structure``.

    Returns:
        The repair prompt as a single string.

    Side effects:
        Overwrites ``self.rules_config`` and ``self.structure``.
    """
    # Refresh state from the configuration before building the text.
    self.rules_config = self.load_rules_config()
    self.structure = self.create_structure(is_palm)

    lines = [
        "The incorrectly formatted JSON dictionary below is not valid. It contains an error that prevents it from loading with the Python command json.loads().",
        "The incorrectly formatted JSON dictionary below is the literal string returned by a previous function and the error may be caused by markdown formatting.",
        "You need to return coorect JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
        "Escape all JSON control characters that appear in input including ampersand (&) and other control characters.",
        "Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.",
        "Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.",
        f"The incorrectly formatted JSON dictionary: {incorrect_json}",
    ]

    # The PaLM variant repeats the structure line three times.
    structure_line = f"The output JSON structure: {self.structure}"
    repeat_count = 3 if is_palm else 1
    lines.extend([structure_line] * repeat_count)

    lines.append("Please reformat the incorrectly formatted JSON dictionary given the output JSON structure: ")
    return "\n".join(lines)
|
737 |
-
|
738 |
-
#############################################################################################
|
739 |
-
#############################################################################################
|
740 |
-
#############################################################################################
|
741 |
-
#############################################################################################
|
742 |
-
def prompt_gpt_redo_v1(self, incorrect_json):
    """Build the version-1 repair prompt for JSON that failed json.loads().

    The prompt combines fixed repair instructions, the malformed text, and a
    reference layout that uses the version-1 (title-case) field names with
    bracketed placeholders.

    Args:
        incorrect_json: The malformed JSON text to be repaired.

    Returns:
        The repair prompt as a single string. No instance state is read or
        modified.
    """
    # Reference layout shown to the model; placeholders are left bracketed.
    reference_format = """Below is the correct JSON formatting. Modify the text to conform to the following format, fixing the incorrect JSON:
{"Dictionary":
{
"Catalog Number": [Catalog Number],
"Genus": [Genus],
"Species": [species],
"subspecies": [subspecies],
"variety": [variety],
"forma": [forma],
"Country": [Country],
"State": [State],
"County": [County],
"Locality Name": [Locality Name],
"Min Elevation": [Min Elevation],
"Max Elevation": [Max Elevation],
"Elevation Units": [Elevation Units],
"Verbatim Coordinates": [Verbatim Coordinates],
"Datum": [Datum],
"Cultivated": [Cultivated],
"Habitat": [Habitat],
"Collectors": [Collectors],
"Collector Number": [Collector Number],
"Verbatim Date": [Verbatim Date],
"Date": [Date],
"End Date": [End Date]
},
"SpeciesName": {"taxonomy": [Genus_species]}}"""

    instruction_lines = (
        "This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().",
        "You need to return coorect JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
        "Escape all JSON control characters that appear in input including ampersand (&) and other control characters.",
        "Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.",
        "Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.",
        f"The incorrectly formatted JSON dictionary: {incorrect_json}",
        f"The output JSON structure: {reference_format}",
        "The refactored JSON disctionary: ",
    )
    return "\n".join(instruction_lines)
|
781 |
-
|
782 |
-
def prompt_gpt_redo_v2(self, incorrect_json):
    """Build the version-2 repair prompt for JSON that failed json.loads().

    The prompt combines fixed repair instructions, the malformed text, and
    the target structure using the version-2 (snake_case) field names.

    The target structure is now itself valid JSON. The previous text had a
    period instead of a comma after "species", mismatched quotes on
    "elevation_units" and "multiple_names", a trailing comma after
    "end_date", and an unquoted ``[Genus_species]`` placeholder, so the
    example shown to the model had the same kinds of errors it was being
    asked to fix.

    Args:
        incorrect_json: The malformed JSON text to be repaired.

    Returns:
        The repair prompt as a single string. No instance state is read or
        modified.
    """
    # Keep this parseable by json.loads(): it is presented to the model as
    # the structure its output must match.
    structure = """
{"Dictionary":{
"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma":"",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date": "",
"date": "",
"end_date": ""
},
"SpeciesName": {"taxonomy": ""}}"""

    prompt = f"""This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().
You need to return coorect JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.
Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
Escape all JSON control characters that appear in input including ampersand (&) and other control characters.
Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
The incorrectly formatted JSON dictionary: {incorrect_json}
The output JSON structure: {structure}
The refactored JSON disctionary: """
    return prompt
|
824 |
-
#####################################################################################################################################
|
825 |
-
#####################################################################################################################################
|
826 |
-
def prompt_v1_palm2(self, in_list, out_list, OCR=None):
    """Build the version-1 few-shot prompt used with PaLM 2.

    The prompt contains the general rules, the per-field requirements, a
    reminder about JSON output, three worked input/output examples, and the
    OCR text to convert. Sections are separated by one blank line.

    Args:
        in_list: Example inputs; elements 0, 1 and 2 are used.
        out_list: Example outputs paired with ``in_list``; elements 0, 1 and
            2 are used.
        OCR: Unstructured OCR text. When falsy, the value already stored on
            ``self.OCR`` is used instead.

    Returns:
        The prompt as a single string ending in ``output:``.

    Raises:
        IndexError: If ``in_list`` or ``out_list`` is a sequence with fewer
            than three elements.

    Side effects:
        Updates ``self.OCR``.
    """
    self.OCR = OCR or self.OCR

    general_rules = """1. Your job is to return a new dict based on the structure of the reference dict ref_dict and these are your rules.
2. You must look at ref_dict and refactor the new text called OCR to match the same formatting.
3. OCR contains unstructured text inside of [], use your knowledge to put the OCR text into the correct ref_dict column.
4. If OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None" and skip all other steps.
5. If there is a field that does not have a direct proxy in the OCR text, you can fill it in based on your knowledge, but you cannot generate new information.
6. Never put text from the ref_dict values into the new dict, but you must use the headers from ref_dict.
7. There cannot be duplicate dictionary fields.
8. Only return the new dict, do not explain your answer.
9. Do not include quotation marks in content, only use quotation marks to represent values in dictionaries.
10. For GPS coordinates only use Decimal Degrees (D.D°)
11. "Given the input text, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values."""

    field_requirements = """
"Catalog Number" - {"format": "[barcode]", "null_value": "", "description": the barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits}
"Genus" - {"format": "[Genus]" or "[Family] indet" if no genus", "null_value": "", "description": taxonomic determination to genus, do captalize genus}
"Species"- {"format": "[species]" or "indet" if no species, "null_value": "", "description": taxonomic determination to species, do not captalize species}
"subspecies" - {"format": "[subspecies]", "null_value": "", "description": taxonomic determination to subspecies (subsp.)}
"variety" - {"format": "[variety]", "null_value": "", "description": taxonomic determination to variety (var)}
"forma" - {"format": "[form]", "null_value": "", "description": taxonomic determination to form (f.)}

"Country" - {"format": "[Country]", "null_value": "no data", "description": Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abreviation is given}
"State" - {"format": "[Adm. Division 1]", "null_value": "no data", "description": Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"County" - {"format": "[Adm. Division 2]", "null_value": "no data", "description": Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"Locality Name" - {"format": "verbatim", if no geographic info: "no data provided on label of catalog no: [######]", or if illegible: "locality present but illegible/not translated for catalog no: #######", or if no named locality: "no named locality for catalog no: #######", "description": "Description of geographic location or landscape"}

"Min Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given}
"Max Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise}
"Elevation Units" - {format: "m", "null_value": "","description": "m" only if an elevation is present}

"Verbatim Coordinates" - {"format": "[Lat, Long | UTM | TRS]", "null_value": "", "description": Convert coordinates to Decimal Degrees (D.D°) format, do not use Minutes, Seconds or quotation marks}

"Datum" - {"format": "[WGS84, NAD23 etc.]", "null_value": "not present", "description": Datum of coordinates on label; "" is GPS coordinates are not in OCR}
"Cultivated" - {"format": "yes", "null_value": "", "description": Indicates if specimen was grown in cultivation}
"Habitat" - {"format": "verbatim", "null_value": "", "description": Description of habitat or location where specimen was collected, ignore descriptions of the plant itself}
"Collectors" - {"format": "[Collector]", "null_value": "not present", "description": Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas}
"Collector Number" - {"format": "[Collector No.]", "null_value": "s.n.", "description": Sequential number assigned to collection, associated with the collector}
"Verbatim Date" - {"format": "verbatim", "null_value": "s.d.", "description": Date of collection exactly as it appears on the label}
"Date" - {"format": "[yyyy-mm-dd]", "null_value": "", "description": Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e. 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day}
"End Date" - {"format": "[yyyy-mm-dd]", "null_value": "", "description": If date range is listed, later date of collection range}
"""

    # Fixed preamble: rules, field requirements and the JSON reminder.
    sections = [
        "Given the following set of rules:",
        f"set_rules = {general_rules}",
        "Some dict fields have special requirements listed below. First is the column header. After the - is the format. Do not include the instructions with your response:",
        f"requirements = {field_requirements}",
        "Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.",
    ]

    # Exactly three worked examples, in order.
    for shot in range(3):
        sections.append(f"input: {in_list[shot]}")
        sections.append(f"output: {out_list[shot]}")

    # The text to convert, followed by the open slot for the model's answer.
    sections.append(f"input: {self.OCR}")
    sections.append("output:")

    return "\n\n".join(sections)
|
896 |
-
|
897 |
-
def prompt_v1_palm2_noDomainKnowledge(self, OCR=None):
    """Build the version-1 prompt that carries no worked input/output examples.

    The prompt is assembled from three literal sections: the general rules,
    the per-field formatting requirements, and an empty JSON template that
    is repeated three times before the final cue line.

    Args:
        OCR: Unstructured OCR text to embed in the prompt. When falsy, the
            value already stored on ``self.OCR`` is used instead.

    Returns:
        The complete prompt as a single string.
    """
    # A falsy argument (None or "") keeps whatever text was stored earlier.
    self.OCR = OCR or self.OCR

    set_rules = """1. Your job is to return a new dict based on the structure of the reference dict ref_dict and these are your rules.
2. You must look at ref_dict and refactor the new text called OCR to match the same formatting.
3. OCR contains unstructured text inside of [], use your knowledge to put the OCR text into the correct ref_dict column.
4. If OCR is mostly empty and contains substantially less text than the ref_dict examples, then only return "None" and skip all other steps.
5. If there is a field that does not have a direct proxy in the OCR text, you can fill it in based on your knowledge, but you cannot generate new information.
6. Never put text from the ref_dict values into the new dict, but you must use the headers from ref_dict.
7. There cannot be duplicate dictionary fields.
8. Only return the new dict, do not explain your answer.
9. Do not include quotation marks in content, only use quotation marks to represent values in dictionaries.
10. For GPS coordinates only use Decimal Degrees (D.D°)
11. Given the input text, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values."""

    umich_all_asia_rules = """
"Catalog Number" - {"format": "barcode", "null_value": "", "description": the barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits}
"Genus" - {"format": "Genus" or "Family indet" if no genus", "null_value": "", "description": taxonomic determination to genus, do capitalize genus}
"Species"- {"format": "species" or "indet" if no species, "null_value": "", "description": taxonomic determination to species, do not capitalize species}
"subspecies" - {"format": "subspecies", "null_value": "", "description": taxonomic determination to subspecies (subsp.)}
"variety" - {"format": "variety", "null_value": "", "description": taxonomic determination to variety (var)}
"forma" - {"format": "form", "null_value": "", "description": taxonomic determination to form (f.)}

"Country" - {"format": "Country", "null_value": "no data", "description": Country that corresponds to the current geographic location of collection; capitalize first letter of each word; use the entire location name even if an abbreviation is given}
"State" - {"format": "Adm. Division 1", "null_value": "no data", "description": Administrative division 1 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"County" - {"format": "Adm. Division 2", "null_value": "no data", "description": Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word}
"Locality Name" - {"format": "verbatim", if no geographic info: "no data provided on label of catalog no: ######", or if illegible: "locality present but illegible/not translated for catalog no: #######", or if no named locality: "no named locality for catalog no: #######", "description": "Description of geographic location or landscape"}

"Min Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, default field for elevation if a range is not given}
"Max Elevation" - {format: "elevation integer", "null_value": "","description": Elevation or altitude in meters, convert from feet to meters if 'm' or 'meters' is not in the text and round to integer, maximum elevation if there are two elevations listed but '' otherwise}
"Elevation Units" - {format: "m", "null_value": "","description": "m" only if an elevation is present}

"Verbatim Coordinates" - {"format": "Lat, Long, UTM, TRS", "null_value": "", "description": Convert coordinates to Decimal Degrees (D.D°) format, do not use Minutes, Seconds or quotation marks}

"Datum" - {"format": "WGS84, NAD27 etc.", "null_value": "not present", "description": Datum of coordinates on label; "" if GPS coordinates are not in OCR}
"Cultivated" - {"format": "yes", "null_value": "", "description": Indicates if specimen was grown in cultivation}
"Habitat" - {"format": "verbatim", "null_value": "", "description": Description of habitat or location where specimen was collected, ignore descriptions of the plant itself}
"Collectors" - {"format": "Collector", "null_value": "not present", "description": Full name of person (i.e., agent) who collected the specimen; if more than one person then separate the names with commas}
"Collector Number" - {"format": "Collector No.", "null_value": "s.n.", "description": Sequential number assigned to collection, associated with the collector}
"Verbatim Date" - {"format": "verbatim", "null_value": "s.d.", "description": Date of collection exactly as it appears on the label}
"Date" - {"format": "yyyy-mm-dd", "null_value": "", "description": Date of collection formatted as year, month, and day; zeros may be used for unknown values i.e. 0000-00-00 if no date, YYYY-00-00 if only year, YYYY-MM-00 if no day}
"End Date" - {"format": "yyyy-mm-dd", "null_value": "", "description": If date range is listed, later date of collection range}
"""

    # The template must itself be valid JSON (no trailing comma after the
    # last pair), because the prompt asks the model to return valid JSON.
    structure = """{
"Catalog Number": "",
"Genus": "",
"Species": "",
"subspecies": "",
"variety": "",
"forma": "",
"Country": "",
"State": "",
"County": "",
"Locality Name": "",
"Min Elevation": "",
"Max Elevation": "",
"Elevation Units": "",
"Verbatim Coordinates": "",
"Datum": "",
"Cultivated": "",
"Habitat": "",
"Collectors": "",
"Collector Number": "",
"Verbatim Date": "",
"Date": "",
"End Date": ""
}"""

    # The output template is deliberately repeated three times.
    prompt = f"""Given the following set of rules:
set_rules = {set_rules}
Some dict fields have special requirements listed below. First is the column header. After the - is the format. Do not include the instructions with your response:
requirements = {umich_all_asia_rules}
Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
The input unformatted OCR text: {self.OCR}
The output JSON structure: {structure}
The output JSON structure: {structure}
The output JSON structure: {structure}
The refactored JSON dictionary:"""

    return prompt
|
1000 |
-
|
1001 |
-
def prompt_v2_palm2(self, OCR=None):
    """Build the version-2 prompt with a 26-key lowercase JSON template.

    The prompt contains, in order: the general rules, the OCR text, an
    explanation of the formatting keywords, the per-key instruction
    template, and the empty output template repeated three times.

    Side effects:
        Stores the OCR text on ``self.OCR`` and sets ``self.n_fields`` to 26,
        the number of keys in the template.

    Args:
        OCR: Unstructured OCR text to embed in the prompt. When falsy, the
            value already stored on ``self.OCR`` is used instead.

    Returns:
        The complete prompt as a single string.
    """
    # A falsy argument (None or "") keeps whatever text was stored earlier.
    self.OCR = OCR or self.OCR
    # Number of keys in both templates below.
    self.n_fields = 26

    set_rules = """
1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
2. You should map the unstructured OCR text to the appropriate JSON key and then populate the field based on its rules.
3. Some JSON key fields are permitted to remain empty if the corresponding information is not found in the unstructured OCR text.
4. Ignore any information in the OCR text that doesn't fit into the defined JSON structure.
5. Duplicate dictionary fields are not allowed.
6. Ensure that all JSON keys are in lowercase.
7. Ensure that new JSON field values follow sentence case capitalization.
8. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
9. Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
10. Only return a JSON dictionary represented as a string. You should not explain your answer.
"""

    dictionary_field_format_descriptions = """
The next section of instructions outlines how to format the JSON dictionary. The keys are the same as those of the final formatted JSON object.
For each key there is a format requirement that specifies how to transcribe the information for that key.
The possible formatting options are:
1. "verbatim transcription" - field is populated with verbatim text from the unformatted OCR.
2. "spell check transcription" - field is populated with spelling corrected text from the unformatted OCR.
3. "boolean yes no" - field is populated with only yes or no.
4. "integer" - field is populated with only an integer.
5. "[list]" - field is populated from one of the values in the list.
6. "yyyy-mm-dd" - field is populated with a date in the format year-month-day.
The desired null value is also given. Populate the field with the null value if the information for that key is not present in the unformatted OCR text.
"""

    # Kept as valid JSON so it obeys the same rules the prompt imposes on the
    # model's answer (no trailing commas, quoted keys).
    json_template_rules = """
{
"catalog_number": {
"format": "verbatim transcription",
"null_value": "",
"description": "The barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits."
},
"genus": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to genus. Genus must be capitalized. If genus is not present use the taxonomic family name followed by the word 'indet'."
},
"species": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to species, do not capitalize species."
},
"subspecies": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to subspecies (subsp.)."
},
"variety": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to variety (var)."
},
"forma": {
"format": "verbatim transcription",
"null_value": "",
"description": "Taxonomic determination to form (f.)."
},
"country": {
"format": "spell check transcription",
"null_value": "",
"description": "Country that corresponds to the current geographic location of collection. Capitalize first letter of each word. If abbreviation is given populate field with the full spelling of the country's name. Use sentence-case to capitalize proper nouns."
},
"state": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 1 that corresponds to the current geographic location of collection. Capitalize first letter of each word. Administrative division 1 is equivalent to a U.S. State. Use sentence-case to capitalize proper nouns."
},
"county": {
"format": "spell check transcription",
"null_value": "",
"description": "Administrative division 2 that corresponds to the current geographic location of collection; capitalize first letter of each word. Administrative division 2 is equivalent to a U.S. county, parish, borough. Use sentence-case to capitalize proper nouns."
},
"locality_name": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of geographic location, landscape, landmarks, regional features, nearby places, or any contextual information aiding in pinpointing the exact origin or site of the specimen. Use sentence-case to capitalize proper nouns."
},
"min_elevation": {
"format": "integer",
"null_value": "",
"description": "Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"max_elevation": {
"format": "integer",
"null_value": "",
"description": "Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ('ft' or 'ft.' or 'feet') to meters ('m' or 'm.' or 'meters'). Round to integer."
},
"elevation_units": {
"format": "spell check transcription",
"null_value": "",
"description": "Elevation units must be meters. If min_elevation field is populated, then elevation_units: 'm'. Otherwise elevation_units: ''"
},
"verbatim_coordinates": {
"format": "verbatim transcription",
"null_value": "",
"description": "Verbatim location coordinates as they appear on the label. Do not convert formats. Possible coordinate types are one of [Lat, Long, UTM, TRS]."
},
"decimal_coordinates": {
"format": "spell check transcription",
"null_value": "",
"description": "Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format."
},
"datum": {
"format": "[WGS84, WGS72, WGS66, WGS60, NAD83, NAD27, OSGB36, ETRS89, ED50, GDA94, JGD2011, Tokyo97, KGD2002, TWD67, TWD97, BJS54, XAS80, GCJ-02, BD-09, PZ-90.11, GTRF, CGCS2000, ITRF88, ITRF89, ITRF90, ITRF91, ITRF92, ITRF93, ITRF94, ITRF96, ITRF97, ITRF2000, ITRF2005, ITRF2008, ITRF2014, Hong Kong Principal Datum, SAD69]",
"null_value": "",
"description": "Datum of location coordinates. Possible values are included in the format list. Leave field blank if unclear."
},
"cultivated": {
"format": "boolean yes no",
"null_value": "",
"description": "Cultivated plants are intentionally grown by humans. In text descriptions, look for planting dates, garden locations, ornamental, cultivar names, garden, or farm to indicate cultivated plant."
},
"habitat": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of a plant's habitat or the location where the specimen was collected. Ignore descriptions of the plant itself. Use sentence-case to capitalize proper nouns."
},
"plant_description": {
"format": "verbatim transcription",
"null_value": "",
"description": "Description of plant features such as leaf shape, size, color, stem texture, height, flower structure, scent, fruit or seed characteristics, root system type, overall growth habit and form, any notable aroma or secretions, presence of hairs or bristles, and any other distinguishing morphological or physiological characteristics. Use sentence-case to capitalize proper nouns."
},
"collectors": {
"format": "verbatim transcription",
"null_value": "not present",
"description": "Full name(s) of the individual(s) responsible for collecting the specimen. Use sentence-case to capitalize proper nouns. When multiple collectors are involved, their names should be separated by commas."
},
"collector_number": {
"format": "verbatim transcription",
"null_value": "s.n.",
"description": "Unique identifier or number that denotes the specific collecting event and associated with the collector."
},
"determined_by": {
"format": "verbatim transcription",
"null_value": "",
"description": "Full name of the individual responsible for determining the taxonomic name of the specimen. Use sentence-case to capitalize proper nouns. Sometimes the name will be near to the characters 'det' to denote determination. This name may be isolated from other names in the unformatted OCR text."
},
"multiple_names": {
"format": "boolean yes no",
"null_value": "",
"description": "Indicate whether multiple people or collector names are present in the unformatted OCR text. Use sentence-case to capitalize proper nouns. If you see more than one person's name the value is 'yes'; otherwise the value is 'no'."
},
"verbatim_date": {
"format": "verbatim transcription",
"null_value": "s.d.",
"description": "Date of collection exactly as it appears on the label. Do not change the format or correct typos."
},
"date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "Date the specimen was collected formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire date is unknown, 'YYYY-00-00' if only the year is known, and 'YYYY-MM-00' if year and month are known but day is not."
},
"end_date": {
"format": "yyyy-mm-dd",
"null_value": "",
"description": "If a date range is provided, this represents the later or ending date of the collection period, formatted as year-month-day. If specific components of the date are unknown, they should be replaced with zeros. Examples: '0000-00-00' if the entire end date is unknown, 'YYYY-00-00' if only the year of the end date is known, and 'YYYY-MM-00' if year and month of the end date are known but the day is not."
}
}"""

    # Empty output template; every key maps to "" and the text is valid JSON.
    structure = """{"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma": "",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date": "",
"date": "",
"end_date": ""
}"""

    # The output template is deliberately repeated three times.
    prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
The rules are:
{set_rules}
The unstructured OCR text is:
{self.OCR}
{dictionary_field_format_descriptions}
This is the JSON template that includes instructions for each key:
{json_template_rules}
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
{structure}
{structure}
{structure}
"""

    return prompt
|
1235 |
-
|
1236 |
-
def prompt_palm_redo_v1(self, incorrect_json):
    """Build a repair prompt for a version-1 response that failed to parse.

    The prompt explains that the text could not be loaded with
    ``json.loads()``, embeds the faulty text, and shows the expected
    title-cased key layout three times before the final cue line.

    Args:
        incorrect_json: The malformed JSON text to be repaired; it is
            interpolated into the prompt as-is.

    Returns:
        The complete repair prompt as a single string.
    """
    # Bracketed names are placeholders showing where each value belongs.
    structure = """{
"Catalog Number": [Catalog Number],
"Genus": [Genus],
"Species": [species],
"subspecies": [subspecies],
"variety": [variety],
"forma": [forma],
"Country": [Country],
"State": [State],
"County": [County],
"Locality Name": [Locality Name],
"Min Elevation": [Min Elevation],
"Max Elevation": [Max Elevation],
"Elevation Units": [Elevation Units],
"Verbatim Coordinates": [Verbatim Coordinates],
"Datum": [Datum],
"Cultivated": [Cultivated],
"Habitat": [Habitat],
"Collectors": [Collectors],
"Collector Number": [Collector Number],
"Verbatim Date": [Verbatim Date],
"Date": [Date],
"End Date": [End Date]
}"""

    # The output template is deliberately repeated three times.
    prompt = f"""This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().
You need to return correct JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.
Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
Escape all JSON control characters that appear in input including ampersand (&) and other control characters.
Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
The incorrectly formatted JSON dictionary: {incorrect_json}
The output JSON structure: {structure}
The output JSON structure: {structure}
The output JSON structure: {structure}
The refactored JSON dictionary: """
    return prompt
|
1274 |
-
|
1275 |
-
def prompt_palm_redo_v2(self, incorrect_json):
    """Build a repair prompt for a version-2 response that failed to parse.

    The prompt explains that the text could not be loaded with
    ``json.loads()``, embeds the faulty text, and shows the empty 26-key
    lowercase template three times before the final cue line.

    Args:
        incorrect_json: The malformed JSON text to be repaired; it is
            interpolated into the prompt as-is.

    Returns:
        The complete repair prompt as a single string.
    """
    # The template shown to the model must itself be valid JSON, otherwise it
    # contradicts the "no trailing commas" instruction in the prompt below.
    structure = """{"catalog_number": "",
"genus": "",
"species": "",
"subspecies": "",
"variety": "",
"forma": "",
"country": "",
"state": "",
"county": "",
"locality_name": "",
"min_elevation": "",
"max_elevation": "",
"elevation_units": "",
"verbatim_coordinates": "",
"decimal_coordinates": "",
"datum": "",
"cultivated": "",
"habitat": "",
"plant_description": "",
"collectors": "",
"collector_number": "",
"determined_by": "",
"multiple_names": "",
"verbatim_date": "",
"date": "",
"end_date": ""
}"""

    # The output template is deliberately repeated three times.
    prompt = f"""This text is supposed to be JSON, but it contains an error that prevents it from loading with the Python command json.loads().
You need to return correct JSON for the following dictionary. Most likely, a quotation mark inside of a field value has not been escaped properly with a backslash.
Given the input, please generate a JSON response. Please note that the response should not contain any special characters, including quotation marks (single ' or double \"), within the JSON values.
Escape all JSON control characters that appear in input including ampersand (&) and other control characters.
Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
Ensure the output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
The incorrectly formatted JSON dictionary: {incorrect_json}
The output JSON structure: {structure}
The output JSON structure: {structure}
The output JSON structure: {structure}
The refactored JSON dictionary: """
    return prompt
|
|
|
1 |
from dataclasses import dataclass
|
2 |
+
from langchain_core.pydantic_v1 import Field, create_model
|
3 |
import yaml, json
|
4 |
|
|
|
|
|
|
|
5 |
@dataclass
|
6 |
class PromptCatalog:
|
7 |
domain_knowledge_example: str = ""
|
|
|
9 |
OCR: str = ""
|
10 |
n_fields: int = 0
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
#############################################################################################
|
14 |
#############################################################################################
|
|
|
17 |
# These are for dynamically creating your own prompts with n-columns
|
18 |
|
19 |
|
20 |
+
def prompt_SLTP(self, rules_config_path, OCR=None, is_palm=False):
|
21 |
self.OCR = OCR
|
22 |
|
23 |
self.rules_config_path = rules_config_path
|
|
|
27 |
self.json_formatting_instructions = self.rules_config['json_formatting_instructions']
|
28 |
|
29 |
self.rules_list = self.rules_config['rules']
|
30 |
+
self.n_fields = len(self.rules_config['rules'])
|
31 |
|
32 |
# Set the rules for processing OCR into JSON format
|
33 |
self.rules = self.create_rules(is_palm)
|
34 |
|
35 |
+
self.structure, self.dictionary_structure = self.create_structure(is_palm)
|
36 |
|
37 |
+
''' between instructions and json_formatting_instructions. Made the prompt too long. Better performance without it
|
38 |
+
The unstructured OCR text is:
|
39 |
+
{self.OCR}
|
40 |
+
'''
|
41 |
if is_palm:
|
42 |
prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
|
43 |
The rules are:
|
44 |
{self.instructions}
|
|
|
|
|
45 |
{self.json_formatting_instructions}
|
46 |
This is the JSON template that includes instructions for each key:
|
47 |
{self.rules}
|
48 |
+
The unstructured OCR text is:
|
49 |
+
{self.OCR}
|
50 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
51 |
{self.structure}
|
52 |
{self.structure}
|
|
|
56 |
prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
|
57 |
The rules are:
|
58 |
{self.instructions}
|
|
|
|
|
59 |
{self.json_formatting_instructions}
|
60 |
This is the JSON template that includes instructions for each key:
|
61 |
{self.rules}
|
62 |
+
The unstructured OCR text is:
|
63 |
+
{self.OCR}
|
64 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
65 |
{self.structure}
|
66 |
"""
|
67 |
+
# xlsx_headers = self.generate_xlsx_headers(is_palm)
|
68 |
|
69 |
+
# return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
|
70 |
+
return prompt, self.dictionary_structure
|
71 |
|
72 |
def load_rules_config(self):
|
73 |
with open(self.rules_config_path, 'r') as stream:
|
|
|
78 |
return None
|
79 |
|
80 |
def create_rules(self, is_palm=False):
    """Serialize the per-field rules to a compact JSON string.

    Reads ``self.rules_list`` (a mapping of field name to rule) and dumps a
    shallow copy of it without indentation, preserving insertion order.

    Args:
        is_palm: Accepted for signature compatibility; not used here.

    Returns:
        The rules mapping as a single-line JSON string.
    """
    # Shallow copy so a plain dict is serialized, whatever mapping type
    # rules_list happens to be.
    dictionary_structure = dict(self.rules_list)
    return json.dumps(dictionary_structure, sort_keys=False)
|
86 |
|
87 |
def create_structure(self, is_palm=False):
    """Build the empty JSON template that the LLM is asked to populate.

    A Pydantic model named ``PromptJSONModel`` is created with one ``str``
    field per key of ``self.rules_list``; every field defaults to ``''``.

    Args:
        is_palm: Accepted so existing call sites keep working; not used here.

    Returns:
        tuple: ``(structure_json_str, dictionary_structure)`` where
        ``dictionary_structure`` is the dict dumped from a default instance
        of that model and ``structure_json_str`` is the same dict serialized
        as JSON with a 4-space indent and unsorted keys.
    """
    # Define the structure for the "Dictionary" section: one empty string per rule key.
    dictionary_fields = {key: (str, Field(default='', description="")) for key in self.rules_list.keys()}

    # Dynamically create the "Dictionary" Pydantic model
    PromptJSONModel = create_model('PromptJSONModel', **dictionary_fields)
    template = PromptJSONModel()

    # Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``; use
    # whichever the installed version provides so both majors work.
    dump = getattr(template, 'model_dump', None) or template.dict
    dictionary_structure = dump()

    structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
    return structure_json_str, dictionary_structure
|
|
|
105 |
|
|
|
|
|
|
|
106 |
|
107 |
def generate_xlsx_headers(self, is_palm):
|
108 |
# Extract headers from the 'Dictionary' keys in the JSON template rules
|
|
|
110 |
xlsx_headers = list(self.rules_list.keys())
|
111 |
return xlsx_headers
|
112 |
else:
|
113 |
+
xlsx_headers = list(self.rules_list.keys())
|
114 |
return xlsx_headers
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vouchervision/utils_LLM.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Helper funcs for LLM_XXXXX.py
|
2 |
+
import tiktoken, json, os
|
3 |
+
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
import GPUtil
|
6 |
+
import time
|
7 |
+
import psutil
|
8 |
+
import threading
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def remove_colons_and_double_apostrophes(text):
    """Return ``text`` with every colon and every double-quote character deleted."""
    # Mapping a code point to None tells str.translate to drop that character.
    unwanted = {ord(":"): None, ord('"'): None}
    return text.translate(unwanted)
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
# Hugging Face tokenizers already loaded in this process, keyed by model id.
_HF_TOKENIZER_CACHE = {}


def _get_hf_tokenizer(model_id):
    """Return the Hugging Face tokenizer for ``model_id``, loading it only once per process."""
    if model_id not in _HF_TOKENIZER_CACHE:
        _HF_TOKENIZER_CACHE[model_id] = AutoTokenizer.from_pretrained(model_id)
    return _HF_TOKENIZER_CACHE[model_id]


def count_tokens(string, vendor, model_name):
    """Count the tokens in ``string`` plus LangChain's JSON format instructions.

    Args:
        string: Prompt text to measure.
        vendor: ``'mistral'`` selects the ``mistralai/Mistral-7B-v0.1``
            Hugging Face tokenizer; any other value goes through tiktoken.
        model_name: Model name handed to ``tiktoken.encoding_for_model``
            (not used on the mistral path).

    Returns:
        int: The token count, or ``0`` when tokenization raised; the error
        is printed rather than propagated.
    """
    full_string = string + JSON_FORMAT_INSTRUCTIONS

    try:
        if vendor == 'mistral':
            # Reuse the cached tokenizer instead of reloading it on every call.
            tokenizer = _get_hf_tokenizer("mistralai/Mistral-7B-v0.1")
            return len(tokenizer.tokenize(full_string))

        encoding = tiktoken.encoding_for_model(model_name)
        return len(encoding.encode(full_string))

    except Exception as e:
        # Deliberately best-effort: a failed count must not abort the caller.
        print(f"An error occurred: {e}")
        return 0
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
class SystemLoadMonitor():
    """Record peak CPU, RAM and (when CUDA is available) GPU usage on a background thread.

    Call ``start_monitoring_usage()`` before the work to be measured and
    ``stop_monitoring_report_usage()`` afterwards; the latter logs the elapsed
    time and the peaks through the supplied logger.
    """

    def __init__(self, logger) -> None:
        self.monitoring_thread = None
        self.logger = logger
        # Running maxima plus the sampling loop's run flag ('monitoring').
        self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'monitoring': True}
        self.start_time = None
        self.has_GPU = torch.cuda.is_available()
        self.monitor_interval = 2  # seconds between samples
        # Set by stop_monitoring_report_usage() to wake the sampler immediately.
        self._stop_event = threading.Event()

    def start_monitoring_usage(self):
        """Start the sampling thread and the elapsed-time clock.

        The run flag and stop event are re-armed so the monitor can be started
        again after a previous stop; previously recorded maxima are kept.
        """
        self.gpu_usage['monitoring'] = True
        self._stop_event.clear()
        self.start_time = time.time()
        # Daemon thread: if the caller fails before stopping the monitor, the
        # sampler must not keep the interpreter alive.
        self.monitoring_thread = threading.Thread(
            target=self.monitor_usage,
            args=(self.monitor_interval,),
            daemon=True,
        )
        self.monitoring_thread.start()

    def monitor_usage(self, interval):
        """Sample usage every ``interval`` seconds until monitoring is stopped."""
        while self.gpu_usage['monitoring']:
            # GPU monitoring
            if self.has_GPU:
                GPUs = GPUtil.getGPUs()
                for gpu in GPUs:
                    self.gpu_usage['max_load'] = max(self.gpu_usage['max_load'], gpu.load)
                    # Convert memory usage to GB
                    memory_usage_gb = gpu.memoryUsed / 1024.0
                    self.gpu_usage['max_vram_usage'] = max(self.gpu_usage.get('max_vram_usage', 0), memory_usage_gb)

            # RAM monitoring
            ram_usage = psutil.virtual_memory().used / (1024.0 ** 3)  # Get RAM usage in GB
            self.gpu_usage['max_ram_usage'] = max(self.gpu_usage.get('max_ram_usage', 0), ram_usage)

            # CPU monitoring (non-blocking: usage since the previous call)
            cpu_usage = psutil.cpu_percent(interval=None)
            self.gpu_usage['max_cpu_usage'] = max(self.gpu_usage.get('max_cpu_usage', 0), cpu_usage)

            # Sleep for the interval, but return at once when stop is requested.
            if self._stop_event.wait(interval):
                break

    def stop_monitoring_report_usage(self):
        """Stop the sampling thread and log the elapsed time and recorded peaks.

        Safe to call when monitoring was never started: nothing is joined and
        the elapsed time is reported as 0.
        """
        self.gpu_usage['monitoring'] = False
        self._stop_event.set()
        if self.monitoring_thread is not None:
            self.monitoring_thread.join()

        elapsed_time = time.time() - self.start_time if self.start_time is not None else 0.0
        self.logger.info(f"Inference Time: {round(elapsed_time,2)} seconds")

        self.logger.info(f"Max CPU Usage: {round(self.gpu_usage['max_cpu_usage'],2)}%")
        self.logger.info(f"Max RAM Usage: {round(self.gpu_usage['max_ram_usage'],2)}GB")

        if self.has_GPU:
            # gpu.load is a fraction; scale it to a percentage for the log.
            self.logger.info(f"Max GPU Load: {round(self.gpu_usage['max_load']*100,2)}%")
            self.logger.info(f"Max GPU Memory Usage: {round(self.gpu_usage['max_vram_usage'],2)}GB")
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
vouchervision/utils_LLM_JSON_validation.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
# Strings LLMs emit to mean "no value". Entries are stored lower-cased because
# candidate values are lower-cased before the membership test.
_EMPTY_VALUE_PLACEHOLDERS = frozenset((
    'unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null',
    'not provided in the text', 'not found in the text',
    'not in the text', 'not found',
    'not provided in the ocr', 'not found in the ocr',
    'not in the ocr',
    'not provided in the ocr text', 'not found in the ocr text',
    'not specified in the given text.',
    'not specified in the given text',
    'not specified in the text.',
    'not specified in the text',
    'not specified in text.',
    'not specified in text',
    'not specified in ocr',
    'not specified',
    'not in the ocr text',
    'n/a n/a', 'n/a, n/a',
    'n/a, n/a, n/a', 'n/a n/a, n/a', 'n/a, n/a n/a', 'n/a n/a n/a',
    'n/a, n/a, n/a, n/a', 'n/a n/a n/a n/a', 'n/a n/a, n/a, n/a', 'n/a, n/a n/a, n/a', 'n/a, n/a, n/a n/a',
    'n/a n/a n/a, n/a', 'n/a, n/a n/a n/a',
    'n/a n/a, n/a n/a',
))


def validate_and_align_JSON_keys_with_template(data, JSON_dict_structure):
    """Clean an LLM JSON response and re-key it to match the template.

    Args:
        data: The LLM response; anything ``validate_JSON`` accepts.
        JSON_dict_structure: Iterable of the required keys (e.g. the template
            dict). Its iteration order defines the order of the result.

    Returns:
        dict | None: ``None`` when ``validate_JSON`` rejects ``data``.
        Otherwise a new dict holding exactly the template keys, each mapped to
        a string. ``None`` values and "no value" placeholder strings become
        ``''``, lists are joined with ``', '``, and template keys missing from
        the response get ``''``. The input dict is not modified.
    """
    data = validate_JSON(data)
    if data is None:
        return None

    # Normalize the values into a fresh dict so the caller's object is untouched.
    cleaned = {}
    for key, value in data.items():
        if value is None:
            value = ''
        elif isinstance(value, str):
            if value.strip().lower() in _EMPTY_VALUE_PLACEHOLDERS:
                value = ''
        elif isinstance(value, list):
            # Join the list elements into a single string
            value = ', '.join(str(item) for item in value)
        cleaned[key] = value

    # Case-insensitive index of the non-empty values, used as a last resort
    # when the LLM changed the capitalization of a key.
    by_folded_key = {}
    for key, value in cleaned.items():
        if isinstance(key, str) and str(value):
            by_folded_key.setdefault(key.lower(), value)

    ### align the keys with the template
    ordered_data = {}
    for key in JSON_dict_structure:
        truth_key_lower = key.lower()

        # Preference order: exact key, all-lower-case key, any capitalization.
        llm_value = str(cleaned.get(key, ''))
        if not llm_value:
            llm_value = str(cleaned.get(truth_key_lower, ''))
        if not llm_value:
            llm_value = str(by_folded_key.get(truth_key_lower, ''))

        ordered_data[key] = llm_value

    return ordered_data
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
def validate_JSON(data):
    """Coerce an LLM response into a dict, or return ``None`` if that is not possible.

    Args:
        data: A dict, a JSON string, or a list whose first element is one of
            those.

    Returns:
        dict | None: ``data`` itself when it already is a dict; the first
        element when ``data`` is a list that starts with a dict; otherwise the
        dict decoded from the JSON text (or the first element of a decoded
        JSON list). ``None`` for empty lists, undecodable input, or JSON that
        does not yield a dict.
    """
    if isinstance(data, dict):
        return data

    if isinstance(data, list):
        if not data:
            return None
        data = data[0]
        # The first element may already be parsed; json.loads would reject it.
        if isinstance(data, dict):
            return data

    try:
        json_candidate = json.loads(data)  # decoding the JSON data
    except (TypeError, ValueError, RecursionError):
        # TypeError: not str/bytes; ValueError: malformed JSON or bad encoding.
        return None

    if isinstance(json_candidate, list):
        json_candidate = json_candidate[0] if json_candidate else None

    return json_candidate if isinstance(json_candidate, dict) else None
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
#### A manual method for pulling a JSON object out of a verbose LLM response.
|
82 |
+
#### It's messy but works well enough. Only use it if JSON parsing is not available.
|
83 |
+
def extract_json_dict_manual(text):
    """Pull a JSON object out of a verbose LLM response and parse it.

    Newlines and doubled backslashes are removed, then everything from the
    first ``{`` to the last ``}`` is taken as the candidate object. Stray
    backslashes directly inside the outer braces (e.g. a trailing ``\\}``)
    are dropped before parsing.

    Args:
        text: Raw response text that should contain one JSON object.

    Returns:
        The parsed object, or ``None`` when ``text`` is not a string, contains
        no ``{...}`` span, or the span is not valid JSON (the parse error is
        printed).
    """
    if not isinstance(text, str):
        return None

    text = text.strip().replace("\n", "").replace("\\\\", "")

    # Locate the outermost braces once; without them there is nothing to parse.
    start_index = text.find('{')
    end_index = text.rfind('}')
    if start_index == -1 or end_index < start_index:
        return None

    # A backslash can never legally sit right after the opening brace or right
    # before the closing one, so any found there are escaping artifacts.
    inner = text[start_index + 1:end_index].strip("\\")
    json_str = '{' + inner + '}'

    # Convert JSON string to Python dictionary
    try:
        return json.loads(json_str)
    except ValueError as e:
        print("Error parsing JSON:", e)
        return None
|
139 |
+
'''
|
140 |
+
if __name__ == "__main__":
|
141 |
+
tex = """Extracted JSON string: {"catalogNumber": "MPU395640",
|
142 |
+
"order": "Monimizeae",
|
143 |
+
"family": "Monimiaceae",
|
144 |
+
"scientificName": "Hedycarya parvifolia",
|
145 |
+
"scientificNameAuthorship": "Perkins & Schltr.",
|
146 |
+
"genus": "Hedycarya",
|
147 |
+
"subgenus": null,
|
148 |
+
"specificEpithet": "parvifolia",
|
149 |
+
"infraspecificEpithet": null,
|
150 |
+
"identifiedBy": null,
|
151 |
+
"recordedBy": "R. Pouteau & J. Munzinger",
|
152 |
+
"recordNumber": "RP 1",
|
153 |
+
"verbatimEventDate": "26-2-2013",
|
154 |
+
"eventDate": "2013-02-26",
|
155 |
+
"habitat": "Ultramafique Long. 165 ° 52'21 E, Lat. 21 ° 29'19 S, Maquis",
|
156 |
+
"occurrenceRemarks": "Fruit Arbuste, Ht. 1,5 mètre (s), Fruits verts (immatures) à noirs (matures), Coll. R. Pouteau & J. Munzinger N° RP 1, Dupl. P - MPU, RECOLNAT, Herbier IRD de Nouvelle - Calédonie Poutan 1, Golden Thread, Alt. 818 mètre, Diam. 2 cm, Date 26-2-2013",
|
157 |
+
"country": "Nouvelle Calédonie",
|
158 |
+
"stateProvince": null,
|
159 |
+
"county": null,
|
160 |
+
"municipality": null,
|
161 |
+
"locality": "Antenne du plateau de Bouakaine",
|
162 |
+
"degreeOfEstablishment": "cultivated",
|
163 |
+
"decimalLatitude": -21.488611,
|
164 |
+
"decimalLongitude": 165.8725,
|
165 |
+
"verbatimCoordinates": "Long. 165 ° 52'21 E, Lat. 21 ° 29'19 S",
|
166 |
+
"minimumElevationInMeters": 818,
|
167 |
+
"maximumElevationInMeters": 818
|
168 |
+
\\}"""
|
169 |
+
new = extract_json_dict(tex)
|
170 |
+
'''
|
vouchervision/utils_VoucherVision.py
CHANGED
@@ -1,26 +1,23 @@
|
|
1 |
import openai
|
2 |
-
import os,
|
3 |
import openpyxl
|
4 |
from openpyxl import Workbook, load_workbook
|
5 |
-
import
|
6 |
-
from
|
7 |
-
from
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
from
|
17 |
-
from
|
18 |
-
from
|
19 |
-
from
|
20 |
-
|
21 |
-
# from LLM_Falcon import OCR_to_dict_Falcon
|
22 |
-
from prompts import PROMPT_UMICH_skeleton_all_asia, PROMPT_OCR_Organized, PROMPT_UMICH_skeleton_all_asia_GPT4, PROMPT_OCR_Organized_GPT4, PROMPT_JSON
|
23 |
-
from prompt_catalog import PromptCatalog
|
24 |
'''
|
25 |
* For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
|
26 |
but removed for output.
|
@@ -29,24 +26,11 @@ from prompt_catalog import PromptCatalog
|
|
29 |
- Look for ####################### Catalog Number pre-defined
|
30 |
'''
|
31 |
|
32 |
-
'''
|
33 |
-
Prior to StructuredOutputParser:
|
34 |
-
response = openai.ChatCompletion.create(
|
35 |
-
model=MODEL,
|
36 |
-
temperature = 0,
|
37 |
-
messages=[
|
38 |
-
{"role": "system", "content": "You are a helpful assistant acting as a transcription expert and your job is to transcribe herbarium specimen labels based on OCR data and reformat it to meet Darwin Core Archive Standards into a Python dictionary based on certain rules."},
|
39 |
-
{"role": "user", "content": prompt},
|
40 |
-
],
|
41 |
-
max_tokens=2048,
|
42 |
-
)
|
43 |
-
# print the model's response
|
44 |
-
return response.choices[0].message['content']
|
45 |
-
'''
|
46 |
|
|
|
47 |
class VoucherVision():
|
48 |
|
49 |
-
def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs):
|
50 |
self.cfg = cfg
|
51 |
self.logger = logger
|
52 |
self.dir_home = dir_home
|
@@ -55,7 +39,12 @@ class VoucherVision():
|
|
55 |
self.Dirs = Dirs
|
56 |
self.headers = None
|
57 |
self.prompt_version = None
|
58 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
self.set_API_keys()
|
61 |
self.setup()
|
@@ -78,15 +67,25 @@ class VoucherVision():
|
|
78 |
self.prompt_version0 = self.cfg['leafmachine']['project']['prompt_version']
|
79 |
self.use_domain_knowledge = self.cfg['leafmachine']['project']['use_domain_knowledge']
|
80 |
|
81 |
-
self.catalog_name_options = ["Catalog Number", "catalog_number"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
-
self.
|
84 |
|
85 |
self.map_prompt_versions()
|
86 |
self.map_dir_labels()
|
87 |
self.map_API_options()
|
88 |
-
self.init_embeddings()
|
89 |
self.init_transcription_xlsx()
|
|
|
90 |
|
91 |
'''Logging'''
|
92 |
self.logger.info(f'Transcribing dataset --- {self.dir_labels}')
|
@@ -98,41 +97,31 @@ class VoucherVision():
|
|
98 |
self.logger.info(f' Model name passed to API --> {self.model_name}')
|
99 |
self.logger.info(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
# raise Exception(f"Unsupported LLM: {self.chat_version}. Requires one of: {supported_LLMs}")
|
113 |
-
|
114 |
-
# self.version_name, self.is_azure, self.model_name, self.has_key = version_mapping[self.chat_version]
|
115 |
def map_API_options(self):
|
116 |
-
self.chat_version = self.cfg['leafmachine']['LLM_version']
|
117 |
-
|
118 |
-
# Assuming you have set your environment variables for each key like 'OPENAI_API_KEY', 'AZURE_API_KEY', 'PALM_API_KEY'
|
119 |
-
openai_api_key = os.getenv('OPENAI_API_KEY')
|
120 |
-
azure_api_key = os.getenv('AZURE_API_KEY')
|
121 |
-
palm_api_key = os.getenv('PALM_API_KEY')
|
122 |
-
|
123 |
-
version_mapping = {
|
124 |
-
'GPT 4': ('OpenAI GPT 4', False, 'GPT_4', bool(openai_api_key)),
|
125 |
-
'GPT 3.5': ('OpenAI GPT 3.5', False, 'GPT_3_5', bool(openai_api_key)),
|
126 |
-
'Azure GPT 3.5': ('(Azure) OpenAI GPT 3.5', True, 'Azure_GPT_3_5', bool(azure_api_key)),
|
127 |
-
'Azure GPT 4': ('(Azure) OpenAI GPT 4', True, 'Azure_GPT_4', bool(azure_api_key)),
|
128 |
-
'PaLM 2': ('Google PaLM 2', None, None, bool(palm_api_key))
|
129 |
-
}
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
134 |
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
def map_prompt_versions(self):
|
138 |
self.prompt_version_map = {
|
@@ -149,13 +138,13 @@ class VoucherVision():
|
|
149 |
def is_in_prompt_version_map(self, value):
    """Report whether ``value`` is one of the values stored in ``self.prompt_version_map``."""
    known_versions = self.prompt_version_map.values()
    return value in known_versions
|
151 |
|
152 |
-
def init_embeddings(self):
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
|
160 |
def map_dir_labels(self):
|
161 |
if self.cfg['leafmachine']['use_RGB_label_images']:
|
@@ -176,53 +165,41 @@ class VoucherVision():
|
|
176 |
|
177 |
def generate_xlsx_headers(self):
|
178 |
# Extract headers from the 'Dictionary' keys in the JSON template rules
|
179 |
-
xlsx_headers = list(self.rules_config_json['rules']["Dictionary"].keys())
|
|
|
180 |
xlsx_headers = xlsx_headers + self.utility_headers
|
181 |
return xlsx_headers
|
182 |
|
183 |
def init_transcription_xlsx(self):
|
184 |
-
self.HEADERS_v1_n22 = ["Catalog Number","Genus","Species","subspecies","variety","forma","Country","State","County","Locality Name","Min Elevation","Max Elevation","Elevation Units","Verbatim Coordinates","Datum","Cultivated","Habitat","Collectors","Collector Number","Verbatim Date","Date","End Date"]
|
185 |
-
self.HEADERS_v2_n26 = ["catalog_number","genus","species","subspecies","variety","forma","country","state","county","locality_name","min_elevation","max_elevation","elevation_units","verbatim_coordinates","decimal_coordinates","datum","cultivated","habitat","plant_description","collectors","collector_number","determined_by","multiple_names","verbatim_date","date","end_date"]
|
186 |
-
self.HEADERS_v1_n22 = self.HEADERS_v1_n22 + self.utility_headers
|
187 |
-
self.HEADERS_v2_n26 = self.HEADERS_v2_n26 + self.utility_headers
|
188 |
# Initialize output file
|
189 |
self.path_transcription = os.path.join(self.Dirs.transcription,"transcribed.xlsx")
|
190 |
|
191 |
-
if self.prompt_version in ['prompt_v2_json_rules','prompt_v2_palm2']:
|
192 |
-
self.headers = self.HEADERS_v2_n26
|
193 |
-
self.headers_used = 'HEADERS_v2_n26'
|
194 |
|
195 |
-
elif self.prompt_version in ['prompt_v1_verbose', 'prompt_v1_verbose_noDomainKnowledge','prompt_v1_palm2', 'prompt_v1_palm2_noDomainKnowledge']:
|
196 |
-
self.headers = self.HEADERS_v1_n22
|
197 |
-
self.headers_used = 'HEADERS_v1_n22'
|
198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
else:
|
200 |
-
|
201 |
-
|
202 |
-
self.rules_config_json = self.load_rules_config()
|
203 |
-
# Generate the headers from the configuration
|
204 |
-
self.headers = self.generate_xlsx_headers()
|
205 |
-
# Set the headers used to the dynamically generated headers
|
206 |
-
self.headers_used = 'CUSTOM'
|
207 |
-
else:
|
208 |
-
# If it's a predefined prompt, raise an exception as we don't have further instructions
|
209 |
-
raise ValueError("Predefined prompt is not handled in this context.")
|
210 |
|
211 |
self.create_or_load_excel_with_headers(os.path.join(self.Dirs.transcription,"transcribed.xlsx"), self.headers)
|
212 |
|
213 |
-
|
214 |
-
def pick_model(self, vendor, nt):
|
215 |
-
if vendor == 'GPT_3_5':
|
216 |
-
if nt > 6000:
|
217 |
-
return "gpt-3.5-turbo-16k-0613", True
|
218 |
-
else:
|
219 |
-
return "gpt-3.5-turbo", False
|
220 |
-
if vendor == 'GPT_4':
|
221 |
-
return "gpt-4", False
|
222 |
-
if vendor == 'Azure_GPT_3_5':
|
223 |
-
return "gpt-35-turbo", False
|
224 |
-
if vendor == 'Azure_GPT_4':
|
225 |
-
return "gpt-4", False
|
226 |
|
227 |
def create_or_load_excel_with_headers(self, file_path, headers, show_head=False):
|
228 |
output_dir_names = ['Archival_Components', 'Config_File', 'Cropped_Images', 'Logs', 'Original_Images', 'Transcription']
|
@@ -324,7 +301,14 @@ class VoucherVision():
|
|
324 |
|
325 |
|
326 |
|
327 |
-
def add_data_to_excel_from_response(self, path_transcription, response, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
wb = openpyxl.load_workbook(path_transcription)
|
329 |
sheet = wb.active
|
330 |
|
@@ -385,68 +369,182 @@ class VoucherVision():
|
|
385 |
sheet.cell(row=next_row, column=i, value=nt_in)
|
386 |
elif header.value == "tokens_out":
|
387 |
sheet.cell(row=next_row, column=i, value=nt_out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
# save the workbook
|
389 |
wb.save(path_transcription)
|
|
|
390 |
|
391 |
-
def
|
392 |
-
|
393 |
-
|
394 |
-
azure_api_key = os.getenv('AZURE_API_KEY')
|
395 |
-
azure_api_base = os.getenv('AZURE_API_BASE')
|
396 |
-
azure_organization = os.getenv('AZURE_ORGANIZATION')
|
397 |
-
azure_api_type = os.getenv('AZURE_API_TYPE')
|
398 |
-
azure_deployment_name = os.getenv('AZURE_DEPLOYMENT_NAME')
|
399 |
-
|
400 |
-
# Check if all required Azure configurations are present
|
401 |
-
if azure_api_version and azure_api_key and azure_api_base and azure_organization and azure_api_type and azure_deployment_name:
|
402 |
-
self.llm = AzureChatOpenAI(
|
403 |
-
deployment_name=azure_deployment_name,
|
404 |
-
openai_api_version=azure_api_version,
|
405 |
-
openai_api_key=azure_api_key,
|
406 |
-
openai_api_base=azure_api_base,
|
407 |
-
openai_organization=azure_organization,
|
408 |
-
openai_api_type=azure_api_type
|
409 |
-
)
|
410 |
else:
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
|
413 |
def set_API_keys(self):
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
self.
|
435 |
-
|
436 |
-
|
437 |
-
self.
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
|
|
|
|
|
|
|
446 |
|
447 |
-
def initialize_embeddings(self):
|
448 |
-
'''Loading embedding search __init__(self, db_name, path_domain_knowledge, logger, build_new_db=False, model_name="hkunlp/instructor-xl", device="cuda")'''
|
449 |
-
self.Voucher_Vision_Embedding = VoucherVisionEmbedding(self.db_name, self.path_domain_knowledge, logger=self.logger, build_new_db=self.build_new_db)
|
450 |
|
451 |
def clean_catalog_number(self, data, filename_without_extension):
|
452 |
#Cleans up the catalog number in data if it's a dict
|
@@ -470,7 +568,7 @@ class VoucherVision():
|
|
470 |
if self.headers_used == 'HEADERS_v1_n22':
|
471 |
return modify_catalog_key("Catalog Number", filename_without_extension, data)
|
472 |
elif self.headers_used in ['HEADERS_v2_n26', 'CUSTOM']:
|
473 |
-
return modify_catalog_key("
|
474 |
else:
|
475 |
raise ValueError("Invalid headers used.")
|
476 |
else:
|
@@ -484,23 +582,19 @@ class VoucherVision():
|
|
484 |
data = json.dumps(data, indent=4, sort_keys=False)
|
485 |
txt_file.write(data)
|
486 |
|
487 |
-
|
488 |
-
|
|
|
489 |
|
|
|
490 |
def remove_non_numbers(self, s):
    """Return only the digit characters of ``s``, in their original order."""
    digits_only = filter(str.isdigit, s)
    return ''.join(digits_only)
|
492 |
|
|
|
493 |
def create_null_row(self, filename_without_extension, path_to_crop, path_to_content, path_to_helper):
|
494 |
json_dict = {header: '' for header in self.headers}
|
495 |
for header, value in json_dict.items():
|
496 |
-
if header
|
497 |
-
if self.prefix_removal:
|
498 |
-
json_dict[header] = filename_without_extension.replace(self.prefix_removal, "")
|
499 |
-
if self.suffix_removal:
|
500 |
-
json_dict[header] = filename_without_extension.replace(self.suffix_removal, "")
|
501 |
-
if self.catalog_numerical_only:
|
502 |
-
json_dict[header] = self.remove_non_numbers(json_dict[header])
|
503 |
-
elif header == "path_to_crop":
|
504 |
json_dict[header] = path_to_crop
|
505 |
elif header == "path_to_original":
|
506 |
fname = os.path.basename(path_to_crop)
|
@@ -511,231 +605,231 @@ class VoucherVision():
|
|
511 |
json_dict[header] = path_to_content
|
512 |
elif header == "path_to_helper":
|
513 |
json_dict[header] = path_to_helper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
return json_dict
|
|
|
515 |
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
|
517 |
-
|
518 |
-
|
519 |
-
self.
|
|
|
520 |
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
# Find a similar example from the domain knowledge
|
525 |
-
domain_knowledge_example = self.Voucher_Vision_Embedding.query_db(self.OCR, 1)
|
526 |
-
similarity= self.Voucher_Vision_Embedding.get_similarity()
|
527 |
|
528 |
-
|
529 |
-
|
530 |
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
|
535 |
-
|
536 |
-
|
537 |
-
else:
|
538 |
-
prompt, n_fields, xlsx_headers = Catalog.prompt_v2_custom(self.path_custom_prompts, OCR=self.OCR)
|
539 |
-
|
540 |
|
|
|
|
|
541 |
|
542 |
-
|
543 |
-
|
|
|
|
|
544 |
|
545 |
-
|
546 |
-
|
|
|
|
|
|
|
|
|
547 |
|
548 |
-
|
|
|
|
|
549 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
# domain_knowledge_example = self.Voucher_Vision_Embedding.query_db(self.OCR, 1)
|
555 |
-
# similarity= self.Voucher_Vision_Embedding.get_similarity()
|
556 |
|
557 |
-
|
|
|
558 |
|
559 |
-
|
560 |
-
|
|
|
561 |
|
562 |
-
|
563 |
-
|
564 |
|
565 |
-
|
|
|
|
|
|
|
566 |
|
567 |
-
|
568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
-
|
571 |
-
|
572 |
-
# elif opt == 'helper':
|
573 |
-
# prompt = PROMPT_OCR_Organized_GPT4(self.OCR)
|
574 |
-
# nt = num_tokens_from_string(prompt, "cl100k_base")
|
575 |
|
576 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
577 |
|
578 |
-
|
579 |
-
|
580 |
-
|
|
|
|
|
|
|
|
|
581 |
|
582 |
-
|
583 |
|
|
|
584 |
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
final_JSON_response =
|
589 |
-
|
590 |
-
|
591 |
-
for i, path_to_crop in enumerate(self.img_paths):
|
592 |
-
if progress_report is not None:
|
593 |
-
progress_report.update_batch(f"Working on image {i+1} of {len(self.img_paths)}")
|
594 |
|
595 |
-
|
596 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
597 |
else:
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
self.write_json_to_file(txt_file_path_OCR, {"OCR":self.OCR})
|
611 |
-
self.write_json_to_file(txt_file_path_OCR_bounds, {"OCR_Bounds":self.bounds})
|
612 |
-
self.overlay_image.save(jpg_file_path_OCR_helper)
|
613 |
-
|
614 |
-
# Setup Dict
|
615 |
-
MODEL, prompt, use_long_form, n_fields, xlsx_headers, nt_in = self.setup_GPT(self.prompt_version, gpt)
|
616 |
-
|
617 |
-
if is_azure:
|
618 |
-
self.llm.deployment_name = MODEL
|
619 |
-
else:
|
620 |
-
self.llm = None
|
621 |
-
|
622 |
-
# Send OCR to chatGPT and return formatted dictonary
|
623 |
-
if use_long_form:
|
624 |
-
response_candidate = OCR_to_dict_16k(is_azure, self.logger, MODEL, prompt, self.llm, self.prompt_version)
|
625 |
-
nt_out = num_tokens_from_string(response_candidate, "cl100k_base")
|
626 |
-
else:
|
627 |
-
response_candidate = OCR_to_dict(is_azure, self.logger, MODEL, prompt, self.llm, self.prompt_version)
|
628 |
-
nt_out = num_tokens_from_string(response_candidate, "cl100k_base")
|
629 |
-
else:
|
630 |
-
response_candidate = None
|
631 |
-
nt_out = 0
|
632 |
-
|
633 |
-
total_tokens_in += nt_in
|
634 |
-
total_tokens_out += nt_out
|
635 |
-
|
636 |
-
final_JSON_response0 = self.save_json_and_xlsx(response_candidate, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
637 |
-
if response_candidate is not None:
|
638 |
-
final_JSON_response = final_JSON_response0
|
639 |
-
|
640 |
-
self.logger.info(f'Formatted JSON\n{final_JSON_response}')
|
641 |
-
self.logger.info(f'Finished {MODEL} API calls\n')
|
642 |
-
|
643 |
-
if progress_report is not None:
|
644 |
-
progress_report.reset_batch(f"Batch Complete")
|
645 |
-
try:
|
646 |
-
final_JSON_response = json.loads(final_JSON_response.strip('```').replace('json\n', '', 1).replace('json', '', 1))
|
647 |
-
except:
|
648 |
-
pass
|
649 |
-
return final_JSON_response, total_tokens_in, total_tokens_out
|
650 |
|
651 |
-
|
|
|
|
|
652 |
|
653 |
-
|
654 |
-
|
655 |
-
total_tokens_out = 0
|
656 |
-
final_JSON_response = None
|
657 |
if progress_report is not None:
|
658 |
progress_report.set_n_batches(len(self.img_paths))
|
659 |
-
for i, path_to_crop in enumerate(self.img_paths):
|
660 |
-
if progress_report is not None:
|
661 |
-
progress_report.update_batch(f"Working on image {i+1} of {len(self.img_paths)}")
|
662 |
-
if os.path.basename(path_to_crop) in self.completed_specimens:
|
663 |
-
self.logger.info(f'[Skipping] specimen {os.path.basename(path_to_crop)} already processed')
|
664 |
-
else:
|
665 |
-
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper = self.generate_paths(path_to_crop, i)
|
666 |
-
|
667 |
-
# Use Google Vision API to get OCR
|
668 |
-
self.OCR, self.bounds, self.text_to_box_mapping = detect_text(path_to_crop, self.client)
|
669 |
-
if len(self.OCR) > 0:
|
670 |
-
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- Starting OCR')
|
671 |
-
self.OCR = self.OCR.replace("\'", "Minutes").replace('\"', "Seconds")
|
672 |
-
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- Finished OCR')
|
673 |
-
|
674 |
-
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- Creating OCR Overlay Image')
|
675 |
-
self.overlay_image = overlay_boxes_on_image(path_to_crop, self.bounds, self.cfg['leafmachine']['do_create_OCR_helper_image'])
|
676 |
-
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- Saved OCR Overlay Image')
|
677 |
-
|
678 |
-
self.write_json_to_file(txt_file_path_OCR, {"OCR":self.OCR})
|
679 |
-
self.write_json_to_file(txt_file_path_OCR_bounds, {"OCR_Bounds":self.bounds})
|
680 |
-
self.overlay_image.save(jpg_file_path_OCR_helper)
|
681 |
-
|
682 |
-
# Send OCR to chatGPT and return formatted dictonary
|
683 |
-
response_candidate, nt_in = OCR_to_dict_PaLM(self.logger, self.OCR, self.prompt_version, self.Voucher_Vision_Embedding)
|
684 |
-
nt_out = num_tokens_from_string(response_candidate, "cl100k_base")
|
685 |
-
|
686 |
-
else:
|
687 |
-
response_candidate = None
|
688 |
-
nt_out = 0
|
689 |
-
|
690 |
-
total_tokens_in += nt_in
|
691 |
-
total_tokens_out += nt_out
|
692 |
|
693 |
-
final_JSON_response0 = self.save_json_and_xlsx(response_candidate, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
694 |
-
if response_candidate is not None:
|
695 |
-
final_JSON_response = final_JSON_response0
|
696 |
-
self.logger.info(f'Formatted JSON\n{final_JSON_response}')
|
697 |
-
self.logger.info(f'Finished PaLM 2 API calls\n')
|
698 |
|
|
|
699 |
if progress_report is not None:
|
700 |
-
progress_report.
|
701 |
-
return final_JSON_response, total_tokens_in, total_tokens_out
|
702 |
|
703 |
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
else:
|
711 |
-
filename_without_extension = os.path.splitext(os.path.basename(path_to_crop))[0]
|
712 |
-
txt_file_path = os.path.join(self.Dirs.transcription_ind, filename_without_extension + '.json')
|
713 |
-
txt_file_path_helper = os.path.join(self.Dirs.transcription_ind_helper, filename_without_extension + '.json')
|
714 |
-
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- {filename_without_extension}')
|
715 |
-
|
716 |
-
# Use Google Vision API to get OCR
|
717 |
-
self.OCR, self.bounds, self.text_to_box_mapping = detect_text(path_to_crop)
|
718 |
-
if len(self.OCR) > 0:
|
719 |
-
self.OCR = self.OCR.replace("\'", "Minutes").replace('\"', "Seconds")
|
720 |
-
|
721 |
-
# Send OCR to Falcon and return formatted dictionary
|
722 |
-
response = OCR_to_dict_Falcon(self.logger, self.OCR, self.Voucher_Vision_Embedding)
|
723 |
-
# response_helper = OCR_to_helper_Falcon(self.logger, OCR) # Assuming you have a similar helper function for Falcon
|
724 |
-
response_helper = None
|
725 |
-
|
726 |
-
self.logger.info(f'Finished Falcon API calls\n')
|
727 |
-
else:
|
728 |
-
response = None
|
729 |
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
734 |
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
739 |
|
740 |
def generate_paths(self, path_to_crop, i):
|
741 |
filename_without_extension = os.path.splitext(os.path.basename(path_to_crop))[0]
|
@@ -748,49 +842,58 @@ class VoucherVision():
|
|
748 |
|
749 |
return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper
|
750 |
|
751 |
-
|
|
|
752 |
if response is None:
|
753 |
-
response = self.
|
|
|
|
|
754 |
self.write_json_to_file(txt_file_path, response)
|
755 |
|
756 |
# Then add the null info to the spreadsheet
|
757 |
response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
|
758 |
-
self.add_data_to_excel_from_response(self.path_transcription, response_null, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
|
759 |
|
760 |
### Set completed JSON
|
761 |
else:
|
762 |
response = self.clean_catalog_number(response, filename_without_extension)
|
763 |
self.write_json_to_file(txt_file_path, response)
|
764 |
# add to the xlsx file
|
765 |
-
self.add_data_to_excel_from_response(self.path_transcription, response, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
766 |
return response
|
767 |
|
768 |
-
|
769 |
-
|
770 |
-
|
|
|
|
|
|
|
771 |
try:
|
772 |
-
if
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
except:
|
782 |
if progress_report is not None:
|
783 |
progress_report.reset_batch(f"Batch Failed")
|
784 |
-
self.
|
785 |
-
for handler in self.logger.handlers[:]:
|
786 |
-
handler.close()
|
787 |
-
self.logger.removeHandler(handler)
|
788 |
raise
|
789 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
790 |
def process_specimen_batch_OCR_test(self, path_to_crop):
|
791 |
for img_filename in os.listdir(path_to_crop):
|
792 |
img_path = os.path.join(path_to_crop, img_filename)
|
793 |
-
self.OCR, self.bounds, self.text_to_box_mapping = detect_text(img_path
|
794 |
|
795 |
|
796 |
|
|
|
1 |
import openai
|
2 |
+
import os, json, glob, shutil, yaml, torch, logging
|
3 |
import openpyxl
|
4 |
from openpyxl import Workbook, load_workbook
|
5 |
+
import vertexai
|
6 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
7 |
+
from langchain_openai import AzureChatOpenAI
|
8 |
+
from OCR_google_cloud_vision import OCRGoogle
|
9 |
+
|
10 |
+
from vouchervision.LLM_OpenAI import OpenAIHandler
|
11 |
+
from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
|
12 |
+
from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
|
13 |
+
from vouchervision.LLM_MistralAI import MistralHandler
|
14 |
+
from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
|
15 |
+
from vouchervision.LLM_local_MistralAI import LocalMistralHandler
|
16 |
+
from vouchervision.utils_LLM import remove_colons_and_double_apostrophes
|
17 |
+
from vouchervision.prompt_catalog import PromptCatalog
|
18 |
+
from vouchervision.model_maps import ModelMaps
|
19 |
+
from vouchervision.general_utils import get_cfg_from_full_path
|
20 |
+
|
|
|
|
|
|
|
21 |
'''
|
22 |
* For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
|
23 |
but removed for output.
|
|
|
26 |
- Look for ####################### Catalog Number pre-defined
|
27 |
'''
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
|
31 |
class VoucherVision():
|
32 |
|
33 |
+
def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf):
|
34 |
self.cfg = cfg
|
35 |
self.logger = logger
|
36 |
self.dir_home = dir_home
|
|
|
39 |
self.Dirs = Dirs
|
40 |
self.headers = None
|
41 |
self.prompt_version = None
|
42 |
+
self.is_hf = is_hf
|
43 |
+
|
44 |
+
# self.trOCR_model_version = "microsoft/trocr-large-handwritten"
|
45 |
+
self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
46 |
+
self.trOCR_processor = None
|
47 |
+
self.trOCR_model = None
|
48 |
|
49 |
self.set_API_keys()
|
50 |
self.setup()
|
|
|
67 |
self.prompt_version0 = self.cfg['leafmachine']['project']['prompt_version']
|
68 |
self.use_domain_knowledge = self.cfg['leafmachine']['project']['use_domain_knowledge']
|
69 |
|
70 |
+
self.catalog_name_options = ["Catalog Number", "catalog_number", "catalogNumber"]
|
71 |
+
|
72 |
+
self.utility_headers = ["filename",
|
73 |
+
"WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
74 |
+
|
75 |
+
"GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
76 |
+
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
77 |
+
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",
|
78 |
+
|
79 |
+
"tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
80 |
|
81 |
+
self.do_create_OCR_helper_image = self.cfg['leafmachine']['do_create_OCR_helper_image']
|
82 |
|
83 |
self.map_prompt_versions()
|
84 |
self.map_dir_labels()
|
85 |
self.map_API_options()
|
86 |
+
# self.init_embeddings()
|
87 |
self.init_transcription_xlsx()
|
88 |
+
self.init_trOCR_model()
|
89 |
|
90 |
'''Logging'''
|
91 |
self.logger.info(f'Transcribing dataset --- {self.dir_labels}')
|
|
|
97 |
self.logger.info(f' Model name passed to API --> {self.model_name}')
|
98 |
self.logger.info(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
|
99 |
|
100 |
+
def init_trOCR_model(self):
    """Load the TrOCR processor and model and place the model on the chosen device.

    Reads ``self.trOCR_model_version`` and sets ``self.trOCR_processor``,
    ``self.trOCR_model`` and ``self.device`` ('cuda' when available, else 'cpu').
    """
    # Only surface errors from the transformers logger while loading.
    logging.getLogger('transformers').setLevel(logging.ERROR)

    version = self.trOCR_model_version
    self.trOCR_processor = TrOCRProcessor.from_pretrained(version)
    self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(version)

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    self.device = torch.device(device_name)
    self.trOCR_model.to(self.device)
|
110 |
+
|
|
|
|
|
|
|
111 |
def map_API_options(self):
    """Resolve the configured LLM version through ``ModelMaps``.

    Sets ``self.chat_version``, ``self.model_name``, ``self.is_azure`` and
    ``self.has_key``; on success also sets ``self.version_name``.

    Raises:
        Exception: when ``ModelMaps`` returns no model name for the configured version.
    """
    llm_version = self.cfg['leafmachine']['LLM_version']
    self.chat_version = llm_version

    # Everything ModelMaps reports for this version.
    self.model_name = ModelMaps.get_version_mapping_cost(llm_version)
    self.is_azure = ModelMaps.get_version_mapping_is_azure(llm_version)
    self.has_key = ModelMaps.get_version_has_key(
        llm_version,
        self.has_key_openai,
        self.has_key_azure_openai,
        self.has_key_palm2,
        self.has_key_mistral,
    )

    if self.model_name is not None:
        self.version_name = self.chat_version
        return

    # Unsupported version: tell the user which names are accepted.
    supported_LLMs = ", ".join(ModelMaps.get_models_gui_list())
    raise Exception(f"Unsupported LLM: {self.chat_version}. Requires one of: {supported_LLMs}")
|
125 |
|
126 |
def map_prompt_versions(self):
|
127 |
self.prompt_version_map = {
|
|
|
138 |
def is_in_prompt_version_map(self, value):
    """Return True when ``value`` is one of the values of ``self.prompt_version_map``."""
    known_versions = self.prompt_version_map.values()
    return value in known_versions
|
140 |
|
141 |
+
# def init_embeddings(self):
|
142 |
+
# if self.use_domain_knowledge:
|
143 |
+
# self.logger.info(f'*** USING DOMAIN KNOWLEDGE ***')
|
144 |
+
# self.logger.info(f'*** Initializing vector embeddings database ***')
|
145 |
+
# self.initialize_embeddings()
|
146 |
+
# else:
|
147 |
+
# self.Voucher_Vision_Embedding = None
|
148 |
|
149 |
def map_dir_labels(self):
|
150 |
if self.cfg['leafmachine']['use_RGB_label_images']:
|
|
|
165 |
|
166 |
def generate_xlsx_headers(self):
    """Build the spreadsheet header row.

    Returns the keys of ``self.rules_config_json['rules']`` followed by
    ``self.utility_headers``.
    """
    rule_headers = [*self.rules_config_json['rules'].keys()]
    return rule_headers + self.utility_headers
|
172 |
|
173 |
def init_transcription_xlsx(self):
    """Prepare the ``transcribed.xlsx`` output workbook with headers from the custom prompt.

    Sets ``self.path_transcription``, ``self.rules_config_json``, ``self.headers`` and
    ``self.headers_used`` ('CUSTOM'), then creates or loads the workbook.

    Raises:
        ValueError: when ``self.is_predefined_prompt`` is true; only custom prompts
            are handled here.
    """
    # Initialize output file
    self.path_transcription = os.path.join(self.Dirs.transcription, "transcribed.xlsx")

    if self.is_predefined_prompt:
        # If it's a predefined prompt, raise an exception as we don't have further instructions
        raise ValueError("Predefined prompt is not handled in this context.")

    # Headers are generated dynamically from the loaded rules configuration.
    self.rules_config_json = self.load_rules_config()
    self.headers = self.generate_xlsx_headers()
    self.headers_used = 'CUSTOM'

    workbook_path = os.path.join(self.Dirs.transcription, "transcribed.xlsx")
    self.create_or_load_excel_with_headers(workbook_path, self.headers)
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
def create_or_load_excel_with_headers(self, file_path, headers, show_head=False):
|
205 |
output_dir_names = ['Archival_Components', 'Config_File', 'Cropped_Images', 'Logs', 'Original_Images', 'Transcription']
|
|
|
301 |
|
302 |
|
303 |
|
304 |
+
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
305 |
+
geo_headers = ["GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
306 |
+
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
307 |
+
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
308 |
+
|
309 |
+
# WFO_candidate_names is separate, bc it may be type --> list
|
310 |
+
wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
311 |
+
|
312 |
wb = openpyxl.load_workbook(path_transcription)
|
313 |
sheet = wb.active
|
314 |
|
|
|
369 |
sheet.cell(row=next_row, column=i, value=nt_in)
|
370 |
elif header.value == "tokens_out":
|
371 |
sheet.cell(row=next_row, column=i, value=nt_out)
|
372 |
+
elif header.value == "filename":
|
373 |
+
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
374 |
+
|
375 |
+
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
376 |
+
elif header.value in wfo_headers:
|
377 |
+
sheet.cell(row=next_row, column=i, value=WFO_record.get(header.value, ''))
|
378 |
+
# elif header.value == "WFO_exact_match":
|
379 |
+
# sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_exact_match",''))
|
380 |
+
# elif header.value == "WFO_exact_match_name":
|
381 |
+
# sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_exact_match_name",''))
|
382 |
+
# elif header.value == "WFO_best_match":
|
383 |
+
# sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_best_match",''))
|
384 |
+
# elif header.value == "WFO_placement":
|
385 |
+
# sheet.cell(row=next_row, column=i, value= WFO_record.get("WFO_placement",''))
|
386 |
+
elif header.value == "WFO_candidate_names":
|
387 |
+
candidate_names = WFO_record.get("WFO_candidate_names", '')
|
388 |
+
# Check if candidate_names is a list and convert to a string if it is
|
389 |
+
if isinstance(candidate_names, list):
|
390 |
+
candidate_names_str = '|'.join(candidate_names)
|
391 |
+
else:
|
392 |
+
candidate_names_str = candidate_names
|
393 |
+
sheet.cell(row=next_row, column=i, value=candidate_names_str)
|
394 |
+
|
395 |
+
# "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat", "GEO_decimal_long",
|
396 |
+
# "GEO_city", "GEO_county", "GEO_state", "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent"
|
397 |
+
elif header.value in geo_headers:
|
398 |
+
sheet.cell(row=next_row, column=i, value=GEO_record.get(header.value, ''))
|
399 |
+
|
400 |
# save the workbook
|
401 |
wb.save(path_transcription)
|
402 |
+
|
403 |
|
404 |
+
def has_API_key(self, val):
    """Report whether a configuration value holds a usable API key.

    Args:
        val: The raw value read from the private configuration. A key left blank
            in YAML (``key:``) is loaded as ``None``, and ``key: ''`` as ``''``.

    Returns:
        bool: False for ``None``, the empty string, or a whitespace-only string;
        True for anything else.
    """
    # `None != ''` is True, so None must be rejected explicitly.
    if val is None:
        return False
    return str(val).strip() != ''
|
409 |
+
|
410 |
+
|
411 |
+
|
412 |
+
|
413 |
+
|
414 |
+
|
415 |
|
416 |
def set_API_keys(self):
    """Detect which API credentials are available and configure the matching clients.

    When ``self.is_hf`` is true the credentials are read from environment variables.
    Otherwise they are read from ``PRIVATE_DATA.yaml`` (one directory above this
    file's directory) and selected values are exported to ``os.environ``.

    Side effects:
        * sets the ``self.has_key_*`` / ``self.has_open_cage_geocode`` flags;
        * assigns ``openai.api_key`` when an OpenAI key is present;
        * calls ``vertexai.init`` when Google project settings are available;
        * stores an ``AzureChatOpenAI`` client in ``self.llm`` when Azure OpenAI is
          configured; ``self.llm`` is ``None`` otherwise.
    """
    # Defaults so later reads never hit an unset attribute when Azure is not
    # configured: map_API_options reads self.has_key_azure_openai and send_to_LLM
    # reads self.llm.
    self.has_key_azure_openai = False
    self.llm = None

    if self.is_hf:
        openai_api_key = os.getenv('OPENAI_API_KEY')
        google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
        palm_api_key = os.getenv('PALM_API_KEY')
        mistral_api_key = os.getenv('MISTRAL_API_KEY')
        here_api_key = os.getenv('here_api_key')
        here_app_id = os.getenv('here_app_id')
        open_cage_api_key = os.getenv('open_cage_geocode')
        google_project_id = os.getenv('GOOGLE_PROJECT_ID')
        google_project_location = os.getenv('GOOGLE_LOCATION')

        self.has_key_openai = openai_api_key is not None
        self.has_key_google_OCR = google_application_credentials is not None
        self.has_key_palm2 = palm_api_key is not None
        self.has_key_mistral = mistral_api_key is not None
        self.has_key_here = here_api_key is not None
        # NOTE(review): attribute name "has_hey_here" (sic) is kept as-is because
        # code outside this view may read it; it tracks the HERE app id.
        self.has_hey_here = here_app_id is not None
        self.has_open_cage_geocode = open_cage_api_key is not None
        self.has_key_google_project_id = google_project_id is not None
        self.has_key_google_project_location = google_project_location is not None

        if self.has_key_openai:
            openai.api_key = openai_api_key

        if self.has_key_google_project_id and self.has_key_google_project_location:
            vertexai.init(project=google_project_id, location=google_project_location)

        if os.getenv('AZURE_API_KEY') is not None:
            azure_api_version = os.getenv('AZURE_API_VERSION')
            azure_api_key = os.getenv('AZURE_API_KEY')
            azure_api_base = os.getenv('AZURE_API_BASE')
            azure_organization = os.getenv('AZURE_ORGANIZATION')
            azure_api_type = os.getenv('AZURE_API_TYPE')
            azure_deployment_name = os.getenv('AZURE_DEPLOYMENT_NAME')

            # All six settings must be present (and non-empty) to enable Azure.
            if azure_api_version and azure_api_key and azure_api_base and azure_organization and azure_api_type and azure_deployment_name:
                self.has_key_azure_openai = True
                # NOTE(review): the deployment name is hard-coded even though
                # AZURE_DEPLOYMENT_NAME is required above; confirm this is intended.
                self.llm = AzureChatOpenAI(
                    deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
                    openai_api_version = azure_api_version,
                    openai_api_key = azure_api_key,
                    azure_endpoint = azure_api_base,
                    openai_organization = azure_organization,
                )

    else:
        self.dir_home = os.path.dirname(os.path.dirname(__file__))
        self.path_cfg_private = os.path.join(self.dir_home, 'PRIVATE_DATA.yaml')
        self.cfg_private = get_cfg_from_full_path(self.path_cfg_private)

        self.has_key_openai = self.has_API_key(self.cfg_private['openai']['OPENAI_API_KEY'])

        self.has_key_azure_openai = self.has_API_key(self.cfg_private['openai_azure']['api_version'])

        self.has_key_google_OCR = self.has_API_key(self.cfg_private['google_cloud']['path_json_file'])

        self.has_key_palm2 = self.has_API_key(self.cfg_private['google_palm']['google_palm_api'])
        self.has_key_google_project_id = self.has_API_key(self.cfg_private['google_palm']['project_id'])
        self.has_key_google_project_location = self.has_API_key(self.cfg_private['google_palm']['location'])

        self.has_key_mistral = self.has_API_key(self.cfg_private['mistral']['mistral_key'])

        self.has_key_here = self.has_API_key(self.cfg_private['here']['api_key'])

        self.has_open_cage_geocode = self.has_API_key(self.cfg_private['open_cage_geocode']['api_key'])

        if self.has_key_openai:
            openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
            os.environ["OPENAI_API_KEY"] = self.cfg_private['openai']['OPENAI_API_KEY']

        if self.has_key_azure_openai:
            # os.environ["OPENAI_API_KEY"] = self.cfg_private['openai_azure']['openai_api_key']
            self.llm = AzureChatOpenAI(
                deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
                openai_api_version = self.cfg_private['openai_azure']['api_version'],
                openai_api_key = self.cfg_private['openai_azure']['openai_api_key'],
                azure_endpoint = self.cfg_private['openai_azure']['openai_api_base'],
                # openai_api_base=self.cfg_private['openai_azure']['openai_api_base'],
                openai_organization = self.cfg_private['openai_azure']['openai_organization'],
                # openai_api_type = self.cfg_private['openai_azure']['openai_api_type']
            )

        # This is frustrating. a #TODO is to figure out when/why these methods conflict with the permissions set in the Palm/Gemini calls
        name_check = self.cfg['leafmachine']['LLM_version'].lower().split(' ')
        if ('google' in name_check) or ('palm' in name_check) or ('gemini' in name_check):
            os.environ['GOOGLE_PROJECT_ID'] = self.cfg_private['google_palm']['project_id'] # gemini
            os.environ['GOOGLE_LOCATION'] = self.cfg_private['google_palm']['location'] # gemini
            # genai.configure(api_key=self.cfg_private['google_palm']['google_palm_api'])
            vertexai.init(project=os.environ['GOOGLE_PROJECT_ID'], location=os.environ['GOOGLE_LOCATION'])
            # os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)
            # os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg_private['google_cloud']['path_json_file'] ####
            # os.environ['GOOGLE_API_KEY'] = self.cfg_private['google_palm']['google_palm_api']

        ##### NOTE: this is how you can use ONLY OCR. If you get a vertexAI login it should work without loading all this
        # else:
        #     if self.has_key_google_OCR:
        #         if os.path.exists(self.cfg_private['google_cloud']['path_json_file']):
        #             os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg_private['google_cloud']['path_json_file']
        #         elif os.path.exists(self.cfg_private['google_cloud']['path_json_file_service_account2']):
        #             os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg_private['google_cloud']['path_json_file_service_account2']
        #         else:
        #             raise f"Google JSON API key file not found"

        ##### NOTE: This should also be covered by vertexAI now
        # if self.has_key_palm2:
        #     os.environ['PALM'] = self.cfg_private['google_palm']['google_palm_api']
        #     os.environ['GOOGLE_PROJECT_ID'] = self.cfg_private['google_palm']['project_id'] # gemini
        #     os.environ['GOOGLE_LOCATION'] = self.cfg_private['google_palm']['location'] # gemini
        #     os.environ['GOOGLE_API_KEY'] = self.cfg_private['google_palm']['google_palm_api']

        if self.has_key_mistral:
            os.environ['MISTRAL_API_KEY'] = self.cfg_private['mistral']['mistral_key']

        if self.has_key_here:
            os.environ['here_app_id'] = self.cfg_private['here']['app_id']
            os.environ['here_api_key'] = self.cfg_private['here']['api_key']

        if self.has_open_cage_geocode:
            os.environ['open_cage_geocode'] = self.cfg_private['open_cage_geocode']['api_key']
|
541 |
+
|
542 |
+
|
543 |
|
544 |
+
# def initialize_embeddings(self):
|
545 |
+
# '''Loading embedding search __init__(self, db_name, path_domain_knowledge, logger, build_new_db=False, model_name="hkunlp/instructor-xl", device="cuda")'''
|
546 |
+
# self.Voucher_Vision_Embedding = VoucherVisionEmbedding(self.db_name, self.path_domain_knowledge, logger=self.logger, build_new_db=self.build_new_db)
|
547 |
|
|
|
|
|
|
|
548 |
|
549 |
def clean_catalog_number(self, data, filename_without_extension):
|
550 |
#Cleans up the catalog number in data if it's a dict
|
|
|
568 |
if self.headers_used == 'HEADERS_v1_n22':
|
569 |
return modify_catalog_key("Catalog Number", filename_without_extension, data)
|
570 |
elif self.headers_used in ['HEADERS_v2_n26', 'CUSTOM']:
|
571 |
+
return modify_catalog_key("filename", filename_without_extension, data)
|
572 |
else:
|
573 |
raise ValueError("Invalid headers used.")
|
574 |
else:
|
|
|
582 |
data = json.dumps(data, indent=4, sort_keys=False)
|
583 |
txt_file.write(data)
|
584 |
|
585 |
+
|
586 |
+
# def create_null_json(self):
|
587 |
+
# return {}
|
588 |
|
589 |
+
|
590 |
def remove_non_numbers(self, s):
    """Return ``s`` with every non-digit character removed."""
    return ''.join(filter(str.isdigit, s))
|
592 |
|
593 |
+
|
594 |
def create_null_row(self, filename_without_extension, path_to_crop, path_to_content, path_to_helper):
|
595 |
json_dict = {header: '' for header in self.headers}
|
596 |
for header, value in json_dict.items():
|
597 |
+
if header == "path_to_crop":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
598 |
json_dict[header] = path_to_crop
|
599 |
elif header == "path_to_original":
|
600 |
fname = os.path.basename(path_to_crop)
|
|
|
605 |
json_dict[header] = path_to_content
|
606 |
elif header == "path_to_helper":
|
607 |
json_dict[header] = path_to_helper
|
608 |
+
elif header == "filename":
|
609 |
+
json_dict[header] = filename_without_extension
|
610 |
+
|
611 |
+
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
612 |
+
elif header == "WFO_exact_match":
|
613 |
+
json_dict[header] =''
|
614 |
+
elif header == "WFO_exact_match_name":
|
615 |
+
json_dict[header] = ''
|
616 |
+
elif header == "WFO_best_match":
|
617 |
+
json_dict[header] = ''
|
618 |
+
elif header == "WFO_candidate_names":
|
619 |
+
json_dict[header] = ''
|
620 |
+
elif header == "WFO_placement":
|
621 |
+
json_dict[header] = ''
|
622 |
return json_dict
|
623 |
+
|
624 |
|
625 |
+
##################################################################################################################################
|
626 |
+
################################################## OCR ##################################################################
|
627 |
+
##################################################################################################################################
|
628 |
+
def perform_OCR_and_save_results(self, image_index, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
    """Run OCR on ``self.path_to_crop`` and write the OCR outputs to disk.

    Stores the OCR result in ``self.OCR`` and always writes
    ``ocr_google.OCR_JSON_to_file`` to ``txt_file_path_OCR``. When OCR produced
    something, also saves the overlay image to ``jpg_file_path_OCR_helper`` and
    writes the available text-to-box mappings to ``txt_file_path_OCR_bounds``.

    Args:
        image_index: Zero-based index of the image, used only for log messages.
        jpg_file_path_OCR_helper: Destination of the OCR overlay image.
        txt_file_path_OCR: Destination of the OCR JSON.
        txt_file_path_OCR_bounds: Destination of the OCR bounds JSON.
    """
    self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')

    ### Process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
    ocr_google = OCRGoogle(self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
    ocr_google.process_image(self.do_create_OCR_helper_image, self.logger)
    self.OCR = ocr_google.OCR

    self.write_json_to_file(txt_file_path_OCR, ocr_google.OCR_JSON_to_file)

    self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Finished OCR')

    # Truthiness (not len()) so a None result is treated like empty text instead
    # of raising TypeError; send_to_LLM applies the same `not self.OCR` test.
    if not self.OCR:
        self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- No OCR text detected, skipped OCR overlay image and bounds')
        return

    ocr_google.overlay_image.save(jpg_file_path_OCR_helper)

    # Collect only the mappings that the OCR step actually produced.
    OCR_bounds = {}
    if ocr_google.hand_text_to_box_mapping is not None:
        OCR_bounds['OCR_bounds_handwritten'] = ocr_google.hand_text_to_box_mapping

    if ocr_google.normal_text_to_box_mapping is not None:
        OCR_bounds['OCR_bounds_printed'] = ocr_google.normal_text_to_box_mapping

    if ocr_google.trOCR_text_to_box_mapping is not None:
        OCR_bounds['OCR_bounds_trOCR'] = ocr_google.trOCR_text_to_box_mapping

    self.write_json_to_file(txt_file_path_OCR_bounds, OCR_bounds)
    self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Saved OCR Overlay Image')
|
658 |
|
659 |
+
##################################################################################################################################
|
660 |
+
####################################################### LLM Switchboard ########################################################
|
661 |
+
##################################################################################################################################
|
662 |
+
def send_to_LLM(self, is_azure, progress_report, json_report, model_name):
    """OCR every image in ``self.img_paths`` and send the text to the chosen LLM.

    For each image that is not already in ``self.completed_specimens``: run OCR,
    build the prompt, call the handler method matching ``model_name``, then save
    the result through ``update_final_response`` and publish it on ``json_report``.

    Args:
        is_azure: Forwarded to ``initialize_llm_model``.
        progress_report: Progress object, or None; passed to the
            ``update_progress_report_*`` helpers.
        json_report: Report object receiving status text and the latest JSON records.
        model_name: Model identifier; its ``_``-separated parts select the handler.

    Returns:
        tuple: ``(final_JSON_response, final_WFO_record, final_GEO_record,
        total_tokens_in, total_tokens_out)``, where the first three come from the
        last image processed.
    """
    self.n_failed_LLM_calls = 0
    self.n_failed_OCR = 0

    final_JSON_response = None
    final_WFO_record = None
    final_GEO_record = None

    self.initialize_token_counters()
    self.update_progress_report_initial(progress_report)

    MODEL_NAME_FORMATTED = ModelMaps.get_API_name(model_name)
    name_parts = model_name.split("_")

    self.setup_JSON_dict_structure()

    json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
    json_report.set_JSON({}, {}, {})
    llm_model = self.initialize_llm_model(self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)

    for i, path_to_crop in enumerate(self.img_paths):
        self.update_progress_report_batch(progress_report, i)

        if self.should_skip_specimen(path_to_crop):
            self.log_skipping_specimen(path_to_crop)
            continue

        paths = self.generate_paths(path_to_crop, i)
        self.path_to_crop = path_to_crop

        filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper = paths
        json_report.set_text(text_main='Starting OCR')
        self.perform_OCR_and_save_results(i, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
        json_report.set_text(text_main='Finished OCR')

        if not self.OCR:
            self.n_failed_OCR += 1
            response_candidate = None
            nt_in = 0
            nt_out = 0
            # Without these, the update_final_response call below raises NameError
            # on the first image or reuses the previous image's records. Empty
            # dicts match the placeholder passed to json_report.set_JSON above.
            WFO_record = {}
            GEO_record = {}
        else:
            ### Format prompt
            prompt = self.setup_prompt()
            prompt = remove_colons_and_double_apostrophes(prompt)

            ### Send prompt to chosen LLM
            self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')

            if 'PALM2' in name_parts:
                response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_GooglePalm2(prompt, json_report)

            elif 'GEMINI' in name_parts:
                response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_GoogleGemini(prompt, json_report)

            elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
                response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_MistralAI(prompt, json_report)

            elif 'LOCAL' in name_parts:
                if 'MISTRAL' in name_parts or 'MIXTRAL' in name_parts:
                    if 'CPU' in name_parts:
                        response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_local_cpu_MistralAI(prompt, json_report)
                    else:
                        response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_local_MistralAI(prompt, json_report)
            else:
                response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_OpenAI(prompt, json_report)

        self.n_failed_LLM_calls += 1 if response_candidate is None else 0

        ### Estimate n tokens returned
        self.logger.info(f'Prompt tokens IN --- {nt_in}')
        self.logger.info(f'Prompt tokens OUT --- {nt_out}')

        self.update_token_counters(nt_in, nt_out)

        final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, paths, path_to_crop, nt_in, nt_out)

        self.log_completion_info(final_JSON_response)

        json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)

    self.update_progress_report_final(progress_report)
    final_JSON_response = self.parse_final_json_response(final_JSON_response)
    return final_JSON_response, final_WFO_record, final_GEO_record, self.total_tokens_in, self.total_tokens_out
|
745 |
+
|
|
|
|
|
|
|
746 |
|
747 |
+
##################################################################################################################################
|
748 |
+
################################################## LLM Helper Funcs ##############################################################
|
749 |
+
##################################################################################################################################
|
750 |
+
def initialize_llm_model(self, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
    """Build the LLM handler selected by the tokens in ``name_parts``.

    Local models get a local Mistral handler (the CPU variant when 'CPU' is
    present); a local model that is neither MISTRAL nor MIXTRAL yields None.
    Hosted models map to the PaLM 2, Gemini or Mistral handler, and anything else
    falls through to the OpenAI handler, which also receives ``is_azure`` and
    ``llm_object``.
    """
    if 'LOCAL' in name_parts:
        is_mistral_family = ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts)
        if not is_mistral_family:
            return None
        if 'CPU' in name_parts:
            return LocalCPUMistralHandler(logger, model_name, JSON_dict_structure)
        return LocalMistralHandler(logger, model_name, JSON_dict_structure)

    # From here on 'LOCAL' is known to be absent from name_parts.
    if 'PALM2' in name_parts:
        return GooglePalm2Handler(logger, model_name, JSON_dict_structure)
    if 'GEMINI' in name_parts:
        return GoogleGeminiHandler(logger, model_name, JSON_dict_structure)
    if 'MISTRAL' in name_parts:
        return MistralHandler(logger, model_name, JSON_dict_structure)
    return OpenAIHandler(logger, model_name, JSON_dict_structure, is_azure, llm_object)
|
766 |
+
|
767 |
+
def setup_prompt(self):
    """Return the prompt built by ``PromptCatalog.prompt_SLTP`` for the current ``self.OCR``."""
    catalog = PromptCatalog()
    # prompt_SLTP returns a pair; only the first element (the prompt) is used here.
    sltp_result = catalog.prompt_SLTP(self.path_custom_prompts, OCR=self.OCR)
    prompt, _ = sltp_result
    return prompt
|
771 |
+
|
772 |
+
def setup_JSON_dict_structure(self):
    """Store the second value returned by ``PromptCatalog.prompt_SLTP`` in ``self.JSON_dict_structure``.

    The call uses the placeholder OCR string 'Text'; the prompt it returns is discarded.
    """
    catalog = PromptCatalog()
    sltp_result = catalog.prompt_SLTP(self.path_custom_prompts, OCR='Text')
    _, self.JSON_dict_structure = sltp_result
|
775 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
776 |
|
777 |
+
def initialize_token_counters(self):
    """Reset the running input and output token totals to zero."""
    self.total_tokens_in = self.total_tokens_out = 0
|
780 |
|
781 |
+
|
782 |
+
def update_progress_report_initial(self, progress_report):
    """Tell the progress reporter how many images will be processed.

    Does nothing when ``progress_report`` is None.
    """
    if progress_report is None:
        return
    image_count = len(self.img_paths)
    progress_report.set_n_batches(image_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
785 |
|
|
|
|
|
|
|
|
|
|
|
786 |
|
787 |
+
def update_progress_report_batch(self, progress_report, batch_index):
    """Report which image (1-based) is being worked on.

    Args:
        progress_report: Reporter exposing ``update_batch``; may be None,
            in which case nothing happens.
        batch_index: Zero-based index of the current image.
    """
    if progress_report is None:
        return
    position = batch_index + 1
    total = len(self.img_paths)
    progress_report.update_batch(f"Working on image {position} of {total}")
|
|
|
790 |
|
791 |
|
792 |
+
def should_skip_specimen(self, path_to_crop):
    """Return True when the image's file name is listed in ``self.completed_specimens``."""
    image_name = os.path.basename(path_to_crop)
    return image_name in self.completed_specimens
|
794 |
+
|
795 |
+
|
796 |
+
def log_skipping_specimen(self, path_to_crop):
    """Log that the image at ``path_to_crop`` is skipped because it was already processed."""
    image_name = os.path.basename(path_to_crop)
    self.logger.info(f'[Skipping] specimen {image_name} already processed')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
798 |
|
799 |
+
|
800 |
+
def update_token_counters(self, nt_in, nt_out):
    """Add one call's token usage to the running totals.

    Args:
        nt_in: Tokens sent in the call.
        nt_out: Tokens received from the call.
    """
    self.total_tokens_out += nt_out if False else 0  # no-op placeholder removed below
    self.total_tokens_in += nt_in
    self.total_tokens_out += nt_out
|
803 |
+
|
804 |
+
|
805 |
+
def update_final_response(self, response_candidate, WFO_record, GEO_record, paths, path_to_crop, nt_in, nt_out):
    """Persist one specimen's result and return it with its WFO/GEO records.

    Args:
        response_candidate: Parsed LLM output, or None when no usable
            response was produced (``save_json_and_xlsx`` handles both).
        WFO_record: Taxonomy record passed through to the spreadsheet row.
        GEO_record: Geolocation record passed through to the spreadsheet row.
        paths: 5-tuple ``(filename_without_extension, txt_file_path,
            txt_file_path_OCR, txt_file_path_OCR_bounds,
            jpg_file_path_OCR_helper)``.
        path_to_crop: Path of the image that was transcribed.
        nt_in: Tokens sent for this specimen.
        nt_out: Tokens received for this specimen.

    Returns:
        ``(final_JSON_response_updated, WFO_record, GEO_record)``.
    """
    # The two OCR text paths are unpacked only to validate the tuple's shape.
    filename_without_extension, txt_file_path, _txt_file_path_OCR, _txt_file_path_OCR_bounds, jpg_file_path_OCR_helper = paths

    # The original branched on ``response_candidate is not None`` but both
    # branches were identical; save_json_and_xlsx handles the None case itself.
    final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
    return final_JSON_response_updated, WFO_record, GEO_record
|
814 |
|
815 |
+
|
816 |
+
def log_completion_info(self, final_JSON_response):
    """Log the formatted JSON result, then a marker that the API calls are done."""
    for message in (f'Formatted JSON\n{final_JSON_response}', 'Finished API calls\n'):
        self.logger.info(message)
|
819 |
+
|
820 |
+
|
821 |
+
def update_progress_report_final(self, progress_report):
    """Mark the batch as complete on the progress reporter, if one was given."""
    if progress_report is None:
        return
    progress_report.reset_batch("Batch Complete")
|
824 |
+
|
825 |
+
|
826 |
+
def parse_final_json_response(self, final_JSON_response):
    """Parse an LLM reply that may be wrapped in a Markdown code fence.

    Surrounding whitespace and backticks are removed, and a leading ``json``
    language tag is dropped before the text is handed to ``json.loads``.

    Args:
        final_JSON_response: Raw reply text (or an already parsed object).

    Returns:
        The decoded JSON value, or ``final_JSON_response`` unchanged when it
        is not a string or cannot be decoded.
    """
    try:
        text = final_JSON_response.strip().strip('`').strip()
        # Only a *leading* tag is removed. The previous ``replace('json', '', 1)``
        # also deleted the first "json" inside the payload itself.
        if text.startswith('json'):
            text = text[len('json'):]
        return json.loads(text)
    except Exception:
        # Best effort: hand back the input untouched. Unlike a bare ``except``
        # this no longer swallows KeyboardInterrupt / SystemExit.
        return final_JSON_response
|
831 |
+
|
832 |
+
|
833 |
|
834 |
def generate_paths(self, path_to_crop, i):
|
835 |
filename_without_extension = os.path.splitext(os.path.basename(path_to_crop))[0]
|
|
|
842 |
|
843 |
return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper
|
844 |
|
845 |
+
|
846 |
+
def save_json_and_xlsx(self, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
    """Write one specimen's JSON file and append its row to the spreadsheet.

    With a real ``response`` the catalog number is cleaned, the JSON is
    written and the row is appended with the given token counts. With
    ``response`` None, the empty template ``self.JSON_dict_structure`` (with
    'filename' inserted as the first key) is written instead and a null row
    is appended with zero token counts.

    Returns:
        The dict that was written to ``txt_file_path``.
    """
    if response is not None:
        ### Set completed JSON
        response = self.clean_catalog_number(response, filename_without_extension)
        self.write_json_to_file(txt_file_path, response)
        # add to the xlsx file
        self.add_data_to_excel_from_response(self.path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
        return response

    # No usable response: fall back to the template, 'filename' first.
    template = self.JSON_dict_structure
    response = {'filename': filename_without_extension}
    for key, value in template.items():
        if key != 'filename':
            response[key] = value
    self.write_json_to_file(txt_file_path, response)

    # Then add the null info to the spreadsheet
    response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
    self.add_data_to_excel_from_response(self.path_transcription, response_null, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
    return response
|
864 |
|
865 |
+
|
866 |
+
def process_specimen_batch(self, progress_report, json_report, is_real_run=False):
    """Run the LLM transcription for the whole batch.

    Args:
        progress_report: Optional progress reporter; may be None.
        json_report: Passed through to ``send_to_LLM``.
        is_real_run: When True, the overall progress label is updated
            before transcription starts.

    Returns:
        ``(final_json_response, final_WFO_record, final_GEO_record,
        total_tokens_in, total_tokens_out)`` as produced by ``send_to_LLM``.

    Raises:
        Exception: If no API key is available for the selected model, or
            whatever ``send_to_LLM`` raised (re-raised after the failure is
            logged and the logger handlers are closed).
    """
    if not self.has_key:
        self.logger.error(f'No API key found for {self.version_name}')
        raise Exception(f"No API key found for {self.version_name}")

    try:
        # progress_report is optional in the failure path below, so it must
        # be guarded here as well; previously a None reporter aborted a real
        # run before the LLM was ever called.
        if is_real_run and progress_report is not None:
            progress_report.update_overall("Transcribing Labels")

        final_json_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out = self.send_to_LLM(self.is_azure, progress_report, json_report, self.model_name)

        return final_json_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out

    except Exception as e:
        self.logger.error(f"LLM call failed in process_specimen_batch: {e}")
        if progress_report is not None:
            progress_report.reset_batch("Batch Failed")
        # Release log file handles before propagating the failure.
        self.close_logger_handlers()
        raise
|
885 |
|
886 |
+
|
887 |
+
def close_logger_handlers(self):
    """Close every handler attached to ``self.logger`` and detach it."""
    # Iterate over a snapshot because removeHandler mutates logger.handlers.
    attached = list(self.logger.handlers)
    for log_handler in attached:
        log_handler.close()
        self.logger.removeHandler(log_handler)
|
891 |
+
|
892 |
+
|
893 |
def process_specimen_batch_OCR_test(self, path_to_crop):
|
894 |
for img_filename in os.listdir(path_to_crop):
|
895 |
img_path = os.path.join(path_to_crop, img_filename)
|
896 |
+
self.OCR, self.bounds, self.text_to_box_mapping = detect_text(img_path)
|
897 |
|
898 |
|
899 |
|
vouchervision/utils_VoucherVision_batch.py
ADDED
@@ -0,0 +1,863 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import os, sys, json, inspect, glob, tiktoken, shutil, yaml, torch, logging
|
3 |
+
import openpyxl
|
4 |
+
from openpyxl import Workbook, load_workbook
|
5 |
+
import google.generativeai as genai
|
6 |
+
import vertexai
|
7 |
+
from langchain_openai import AzureChatOpenAI
|
8 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
9 |
+
|
10 |
+
currentdir = os.path.dirname(os.path.abspath(
|
11 |
+
inspect.getfile(inspect.currentframe())))
|
12 |
+
parentdir = os.path.dirname(currentdir)
|
13 |
+
sys.path.append(parentdir)
|
14 |
+
parentdir = os.path.dirname(parentdir)
|
15 |
+
sys.path.append(parentdir)
|
16 |
+
|
17 |
+
from general_utils import get_cfg_from_full_path
|
18 |
+
# from embeddings_db import VoucherVisionEmbedding
|
19 |
+
from OCR_google_cloud_vision import OCRGoogle
|
20 |
+
|
21 |
+
from vouchervision.LLM_OpenAI import OpenAIHandler
|
22 |
+
from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
|
23 |
+
from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
|
24 |
+
from vouchervision.LLM_MistralAI import MistralHandler
|
25 |
+
from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
|
26 |
+
from vouchervision.LLM_local_MistralAI import LocalMistralHandler #call_llm_local_MistralAI_8x7b
|
27 |
+
from vouchervision.utils_LLM import remove_colons_and_double_apostrophes
|
28 |
+
|
29 |
+
# from LLM_Falcon import OCR_to_dict_Falcon
|
30 |
+
from prompt_catalog import PromptCatalog
|
31 |
+
from vouchervision.model_maps import ModelMaps
|
32 |
+
|
33 |
+
'''
|
34 |
+
* For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
|
35 |
+
but removed for output.
|
36 |
+
* There is also code active to replace the LLM-predicted "Catalog Number" with the correct number since it is known.
|
37 |
+
The LLMs do usually assign the barcode to the correct field, but it's not needed since it is already known.
|
38 |
+
- Look for ####################### Catalog Number pre-defined
|
39 |
+
'''
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
class VoucherVision():
|
44 |
+
|
45 |
+
def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs):
    """Store the run context, load the API keys and run ``setup``.

    Args:
        cfg: Nested configuration mapping (read under ``cfg['leafmachine']``).
        logger: Logger used for all transcription messages.
        dir_home: Project home directory. ``set_API_keys`` reassigns
            ``self.dir_home`` from this module's location afterwards.
        path_custom_prompts: Path to the custom prompt YAML file.
        Project: Project object, stored as-is.
        Dirs: Object exposing the run's output directories.
    """
    # Caller-supplied context.
    self.cfg, self.logger = cfg, logger
    self.dir_home, self.path_custom_prompts = dir_home, path_custom_prompts
    self.Project, self.Dirs = Project, Dirs

    # Placeholders assigned during setup().
    self.headers = self.prompt_version = None

    # TrOCR checkpoint; the larger alternative is
    # "microsoft/trocr-large-handwritten".
    self.trOCR_model_version = "microsoft/trocr-base-handwritten"
    self.trOCR_processor = self.trOCR_model = None

    self.set_API_keys()
    self.setup()
|
62 |
+
|
63 |
+
|
64 |
+
def setup(self):
|
65 |
+
self.logger.name = f'[Transcription]'
|
66 |
+
self.logger.info(f'Setting up OCR and LLM')
|
67 |
+
|
68 |
+
self.db_name = self.cfg['leafmachine']['project']['embeddings_database_name']
|
69 |
+
self.path_domain_knowledge = self.cfg['leafmachine']['project']['path_to_domain_knowledge_xlsx']
|
70 |
+
self.build_new_db = self.cfg['leafmachine']['project']['build_new_embeddings_database']
|
71 |
+
|
72 |
+
self.continue_run_from_partial_xlsx = self.cfg['leafmachine']['project']['continue_run_from_partial_xlsx']
|
73 |
+
|
74 |
+
self.prefix_removal = self.cfg['leafmachine']['project']['prefix_removal']
|
75 |
+
self.suffix_removal = self.cfg['leafmachine']['project']['suffix_removal']
|
76 |
+
self.catalog_numerical_only = self.cfg['leafmachine']['project']['catalog_numerical_only']
|
77 |
+
|
78 |
+
self.prompt_version0 = self.cfg['leafmachine']['project']['prompt_version']
|
79 |
+
self.use_domain_knowledge = self.cfg['leafmachine']['project']['use_domain_knowledge']
|
80 |
+
|
81 |
+
self.catalog_name_options = ["Catalog Number", "catalog_number"]
|
82 |
+
|
83 |
+
self.utility_headers = ["filename",
|
84 |
+
"WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
85 |
+
|
86 |
+
"GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
87 |
+
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
88 |
+
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",
|
89 |
+
|
90 |
+
"tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
91 |
+
|
92 |
+
self.do_create_OCR_helper_image = self.cfg['leafmachine']['do_create_OCR_helper_image']
|
93 |
+
|
94 |
+
self.map_prompt_versions()
|
95 |
+
self.map_dir_labels()
|
96 |
+
self.map_API_options()
|
97 |
+
# self.init_embeddings()
|
98 |
+
self.init_transcription_xlsx()
|
99 |
+
self.init_trOCR_model()
|
100 |
+
|
101 |
+
'''Logging'''
|
102 |
+
self.logger.info(f'Transcribing dataset --- {self.dir_labels}')
|
103 |
+
self.logger.info(f'Saving transcription batch to --- {self.path_transcription}')
|
104 |
+
self.logger.info(f'Saving individual transcription files to --- {self.Dirs.transcription_ind}')
|
105 |
+
self.logger.info(f'Starting transcription...')
|
106 |
+
self.logger.info(f' LLM MODEL --> {self.version_name}')
|
107 |
+
self.logger.info(f' Using Azure API --> {self.is_azure}')
|
108 |
+
self.logger.info(f' Model name passed to API --> {self.model_name}')
|
109 |
+
self.logger.info(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
|
110 |
+
|
111 |
+
def init_trOCR_model(self):
    """Load the TrOCR processor and model named by ``self.trOCR_model_version``.

    Sets ``self.trOCR_processor``, ``self.trOCR_model`` and ``self.device``,
    and moves the model to that device.
    """
    # Only show errors from the 'transformers' logger.
    logging.getLogger('transformers').setLevel(logging.ERROR)

    checkpoint = self.trOCR_model_version
    self.trOCR_processor = TrOCRProcessor.from_pretrained(checkpoint)
    self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

    # Use CUDA when torch reports it as available, otherwise the CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device_name)
    self.trOCR_model.to(self.device)
|
121 |
+
|
122 |
+
def map_API_options(self):
    """Resolve the configured LLM version into API settings via ``ModelMaps``.

    Sets ``self.chat_version``, ``self.model_name``, ``self.is_azure``,
    ``self.has_key`` and ``self.version_name``.

    Raises:
        Exception: If ``ModelMaps`` has no model name for the configured
            version.
    """
    selected_version = self.cfg['leafmachine']['LLM_version']
    self.chat_version = selected_version

    # Get the required values from ModelMaps
    self.model_name = ModelMaps.get_version_mapping_cost(selected_version)
    self.is_azure = ModelMaps.get_version_mapping_is_azure(selected_version)
    self.has_key = ModelMaps.get_version_has_key(selected_version, self.has_key_openai, self.has_key_azure_openai, self.has_key_palm2, self.has_key_mistral)

    if self.model_name is not None:
        self.version_name = self.chat_version
        return

    # Unknown version: list what the GUI offers so the error is actionable.
    supported_LLMs = ", ".join(ModelMaps.get_models_gui_list())
    raise Exception(f"Unsupported LLM: {self.chat_version}. Requires one of: {supported_LLMs}")
|
136 |
+
|
137 |
+
def map_prompt_versions(self):
    """Translate the configured prompt label into an internal prompt id.

    Sets ``self.prompt_version_map``, ``self.prompt_version`` (the mapped id,
    or ``self.path_custom_prompts`` when the label is not a built-in one) and
    ``self.is_predefined_prompt``.
    """
    builtin_prompts = {
        "Version 1": "prompt_v1_verbose",
        "Version 1 No Domain Knowledge": "prompt_v1_verbose_noDomainKnowledge",
        "Version 2": "prompt_v2_json_rules",
        "Version 1 PaLM 2": "prompt_v1_palm2",
        "Version 1 PaLM 2 No Domain Knowledge": "prompt_v1_palm2_noDomainKnowledge",
        "Version 2 PaLM 2": "prompt_v2_palm2",
    }
    self.prompt_version_map = builtin_prompts

    requested_label = self.prompt_version0
    self.prompt_version = builtin_prompts.get(requested_label, self.path_custom_prompts)
    self.is_predefined_prompt = self.is_in_prompt_version_map(self.prompt_version)
|
148 |
+
|
149 |
+
def is_in_prompt_version_map(self, value):
    """Return True when ``value`` is one of the built-in prompt ids."""
    known_prompt_ids = self.prompt_version_map.values()
    return value in known_prompt_ids
|
151 |
+
|
152 |
+
# def init_embeddings(self):
|
153 |
+
# if self.use_domain_knowledge:
|
154 |
+
# self.logger.info(f'*** USING DOMAIN KNOWLEDGE ***')
|
155 |
+
# self.logger.info(f'*** Initializing vector embeddings database ***')
|
156 |
+
# self.initialize_embeddings()
|
157 |
+
# else:
|
158 |
+
# self.Voucher_Vision_Embedding = None
|
159 |
+
|
160 |
+
def map_dir_labels(self):
    """Choose the image directory to transcribe and list its contents.

    Sets ``self.dir_labels`` to the 'label' sub-folder of
    ``Dirs.save_per_annotation_class`` when ``use_RGB_label_images`` is
    enabled, otherwise to ``Dirs.save_original``; then sets
    ``self.img_paths`` to every entry matched by ``<dir_labels>/*``.
    """
    use_label_crops = self.cfg['leafmachine']['use_RGB_label_images']
    self.dir_labels = (
        os.path.join(self.Dirs.save_per_annotation_class, 'label')
        if use_label_crops
        else self.Dirs.save_original
    )

    # Use glob to get all image paths in the directory
    search_pattern = os.path.join(self.dir_labels, "*")
    self.img_paths = glob.glob(search_pattern)
|
168 |
+
|
169 |
+
def load_rules_config(self):
    """Load the custom prompt YAML at ``self.path_custom_prompts``.

    Returns:
        The parsed YAML content, or None when the file is not valid YAML
        (the parser error is printed).
    """
    with open(self.path_custom_prompts, 'r') as rules_file:
        try:
            parsed_rules = yaml.safe_load(rules_file)
        except yaml.YAMLError as yaml_error:
            print(yaml_error)
            parsed_rules = None
    return parsed_rules
|
176 |
+
|
177 |
+
def generate_xlsx_headers(self):
    """Return the spreadsheet columns for a custom prompt.

    The columns are the keys of ``self.rules_config_json['rules']`` followed
    by ``self.utility_headers``.
    """
    # Earlier prompt files nested the fields under rules -> "Dictionary";
    # the fields are now read directly from 'rules'.
    rule_columns = list(self.rules_config_json['rules'].keys())
    return rule_columns + self.utility_headers
|
183 |
+
|
184 |
+
def init_transcription_xlsx(self):
    """Choose the spreadsheet columns for the active prompt and create the workbook.

    Sets ``self.HEADERS_v1_n22``, ``self.HEADERS_v2_n26``,
    ``self.path_transcription``, ``self.headers`` and ``self.headers_used``,
    then calls ``create_or_load_excel_with_headers``. Custom prompts take
    their columns from the rules file.

    Raises:
        ValueError: If the prompt is predefined but matches neither the v1
            nor the v2 prompt ids.
    """
    v1_prompt_ids = ('prompt_v1_verbose', 'prompt_v1_verbose_noDomainKnowledge', 'prompt_v1_palm2', 'prompt_v1_palm2_noDomainKnowledge')
    v2_prompt_ids = ('prompt_v2_json_rules', 'prompt_v2_palm2')

    v1_fields = ["Catalog Number","Genus","Species","subspecies","variety","forma","Country","State","County","Locality Name","Min Elevation","Max Elevation","Elevation Units","Verbatim Coordinates","Datum","Cultivated","Habitat","Collectors","Collector Number","Verbatim Date","Date","End Date"]
    v2_fields = ["catalog_number","genus","species","subspecies","variety","forma","country","state","county","locality_name","min_elevation","max_elevation","elevation_units","verbatim_coordinates","decimal_coordinates","datum","cultivated","habitat","plant_description","collectors","collector_number","determined_by","multiple_names","verbatim_date","date","end_date"]
    self.HEADERS_v1_n22 = v1_fields + self.utility_headers
    self.HEADERS_v2_n26 = v2_fields + self.utility_headers

    # Initialize output file
    self.path_transcription = os.path.join(self.Dirs.transcription, "transcribed.xlsx")

    if self.prompt_version in v2_prompt_ids:
        self.headers, self.headers_used = self.HEADERS_v2_n26, 'HEADERS_v2_n26'
    elif self.prompt_version in v1_prompt_ids:
        self.headers, self.headers_used = self.HEADERS_v1_n22, 'HEADERS_v1_n22'
    elif self.is_predefined_prompt:
        # A predefined prompt that is neither v1 nor v2 has no column set.
        raise ValueError("Predefined prompt is not handled in this context.")
    else:
        # Custom prompt: derive the columns from the rules configuration.
        self.rules_config_json = self.load_rules_config()
        self.headers = self.generate_xlsx_headers()
        self.headers_used = 'CUSTOM'

    self.create_or_load_excel_with_headers(os.path.join(self.Dirs.transcription, "transcribed.xlsx"), self.headers)
|
213 |
+
|
214 |
+
|
215 |
+
def create_or_load_excel_with_headers(self, file_path, headers, show_head=False):
    """Create the transcription workbook, or resume from a partial one.

    When ``self.continue_run_from_partial_xlsx`` names an existing file, that
    workbook is loaded, the basenames of its ``path_to_crop`` entries are
    recorded in ``self.completed_specimens``, the files beside the first data
    row's ``path_to_content`` / ``path_to_helper`` entries are copied into
    the new run, and the four path columns are rewritten to point into the
    new run. Otherwise a new workbook holding only the header row is created.
    Either way the workbook is saved to ``file_path``.

    Args:
        file_path: Destination .xlsx path. The text before 'Transcription'
            in it is used as the root of the new run.
        headers: Column names; their positions locate the path columns in
            the resumed sheet and form the header row of a new sheet.
        show_head: Print the first rows of the sheet after saving. Forced on
            when resuming.

    Raises:
        ValueError: When resuming and ``headers`` lacks one of the four path
            columns. (Previously this was only printed and the method then
            failed on an unbound variable.)
    """
    output_dir_names = ['Archival_Components', 'Config_File', 'Cropped_Images', 'Logs', 'Original_Images', 'Transcription']
    path_column_names = ['path_to_crop', 'path_to_original', 'path_to_content', 'path_to_helper']
    self.completed_specimens = []

    partial_xlsx = self.continue_run_from_partial_xlsx
    if partial_xlsx is not None and os.path.isfile(partial_xlsx):
        workbook = load_workbook(filename=partial_xlsx)
        sheet = workbook.active
        show_head = True

        # Locate the path columns (1-based, as openpyxl expects).
        try:
            path_to_crop_col, path_to_original_col, path_to_content_col, path_to_helper_col = [
                headers.index(name) + 1 for name in path_column_names
            ]
        except ValueError as err:
            raise ValueError(f"Required path column not found in the header row: {err}") from err

        new_root = file_path.split('Transcription')[0]

        def relocated(old_path):
            """Return old_path re-rooted under new_root, once per output folder name it contains."""
            if not isinstance(old_path, str):
                return []  # empty cell: nothing to remap
            return [new_root + dir_name + old_path.split(dir_name)[1]
                    for dir_name in output_dir_names if dir_name in old_path]

        # Record which specimens the partial run already finished.
        crop_paths = list(sheet.iter_cols(min_col=path_to_crop_col, max_col=path_to_crop_col, values_only=True, min_row=2))[0]
        for old_path in crop_paths:
            self.completed_specimens.extend(os.path.basename(p) for p in relocated(old_path))
        print(f"{len(self.completed_specimens)} images are already completed")

        ### Copy the JSON files over
        # Only the first data row is inspected: its folder is copied whole.
        for colu in (path_to_content_col, path_to_helper_col):
            first_cell = next(sheet.iter_rows(min_row=2, min_col=colu, max_col=colu))[0]
            old_path = first_cell.value
            updated_path = new_root + 'Transcription' + old_path.split('Transcription')[1]

            old_dir = os.path.dirname(old_path)
            new_dir = os.path.dirname(updated_path)

            if os.path.isdir(old_dir):
                os.makedirs(new_dir, exist_ok=True)
                for filename in os.listdir(old_dir):
                    shutil.copy2(os.path.join(old_dir, filename), new_dir)  # copy2 preserves metadata

        ### Update the file names
        for colu in (path_to_crop_col, path_to_original_col, path_to_content_col, path_to_helper_col):
            for row in sheet.iter_rows(min_row=2, min_col=colu, max_col=colu):
                for cell in row:
                    candidates = relocated(cell.value)
                    if candidates:
                        # As before, the last matching folder name wins.
                        cell.value = candidates[-1]

    else:
        # Create a new workbook and write the header row.
        workbook = Workbook()
        sheet = workbook.active
        for i, header in enumerate(headers, start=1):
            sheet.cell(row=1, column=i, value=header)

    # Save the workbook
    workbook.save(file_path)

    if show_head:
        print("continue_run_from_partial_xlsx:")
        for i, row in enumerate(sheet.iter_rows(values_only=True)):
            print(row)
            if i == 3:  # stop after the first 4 rows (0-indexed)
                print("\n")
                break
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
    """Append one specimen's row to the transcription workbook.

    Each column of the sheet's header row is filled from the first matching
    source: the LLM ``response``, the file name (catalog-number columns),
    the given paths and token counts, or the WFO / GEO records. Columns that
    match nothing are left empty. The workbook is saved in place.

    Args:
        path_transcription: Path of the .xlsx file to update.
        response: Dict of field values, or a JSON string encoding one. If
            the string cannot be decoded, the failure is printed and nothing
            is written.
        WFO_record: Mapping read with ``.get`` for the ``WFO_*`` columns.
        GEO_record: Mapping read with ``.get`` for the ``GEO_*`` columns.
        filename_without_extension: Written to 'filename' and to the
            catalog-number column.
        path_to_crop: Image path; also used to derive 'path_to_original'.
        path_to_content: Written to 'path_to_content'.
        path_to_helper: Written to 'path_to_helper'.
        nt_in: Written to 'tokens_in'.
        nt_out: Written to 'tokens_out'.
    """
    geo_headers = ["GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
                   "GEO_decimal_long", "GEO_city", "GEO_county", "GEO_state",
                   "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent"]

    # WFO_candidate_names is separate, bc it may be type --> list
    wfo_headers = ["WFO_override_OCR", "WFO_exact_match", "WFO_exact_match_name", "WFO_best_match", "WFO_placement"]

    wb = openpyxl.load_workbook(path_transcription)
    sheet = wb.active

    # find the next empty row
    next_row = sheet.max_row + 1

    if isinstance(response, str):
        try:
            response = json.loads(response)
        except json.JSONDecodeError:
            print(f"Failed to parse response: {response}")
            return

    # iterate over headers in the first row
    for i, header in enumerate(sheet[1], start=1):
        column_name = header.value

        if (column_name in response) and (column_name not in self.catalog_name_options):  ####################### Catalog Number pre-defined
            cell_value = response[column_name]
            if isinstance(cell_value, dict):
                # nested field: keep only its 'value' entry
                cell_value = cell_value.get('value', '')
            try:
                sheet.cell(row=next_row, column=i, value=cell_value)
            except Exception:
                # presumably a value openpyxl cannot store (e.g. a list) -
                # fall back to its first element; TODO confirm.
                # Narrowed from a bare ``except`` so Ctrl-C is not swallowed.
                sheet.cell(row=next_row, column=i, value=cell_value[0])

        elif column_name in self.catalog_name_options:
            # The catalog number is known from the file name, so the LLM's
            # value is not used for this column.
            sheet.cell(row=next_row, column=i, value=filename_without_extension)
        elif column_name == "path_to_crop":
            sheet.cell(row=next_row, column=i, value=path_to_crop)
        elif column_name == "path_to_original":
            # Label crops sit four directory levels below the run root, other
            # images two; the original lives in <root>/Original_Images.
            levels_up = 4 if self.cfg['leafmachine']['use_RGB_label_images'] else 2
            fname = os.path.basename(path_to_crop)
            base = path_to_crop
            for _ in range(levels_up):
                base = os.path.dirname(base)
            path_to_original = os.path.join(base, 'Original_Images', fname)
            sheet.cell(row=next_row, column=i, value=path_to_original)
        elif column_name == "path_to_content":
            sheet.cell(row=next_row, column=i, value=path_to_content)
        elif column_name == "path_to_helper":
            sheet.cell(row=next_row, column=i, value=path_to_helper)
        elif column_name == "tokens_in":
            sheet.cell(row=next_row, column=i, value=nt_in)
        elif column_name == "tokens_out":
            sheet.cell(row=next_row, column=i, value=nt_out)
        elif column_name == "filename":
            sheet.cell(row=next_row, column=i, value=filename_without_extension)

        # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
        elif column_name in wfo_headers:
            sheet.cell(row=next_row, column=i, value=WFO_record.get(column_name, ''))
        elif column_name == "WFO_candidate_names":
            candidate_names = WFO_record.get("WFO_candidate_names", '')
            # A list of candidates is flattened to a '|'-separated string.
            if isinstance(candidate_names, list):
                candidate_names_str = '|'.join(candidate_names)
            else:
                candidate_names_str = candidate_names
            sheet.cell(row=next_row, column=i, value=candidate_names_str)

        # "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat", "GEO_decimal_long",
        # "GEO_city", "GEO_county", "GEO_state", "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent"
        elif column_name in geo_headers:
            sheet.cell(row=next_row, column=i, value=GEO_record.get(column_name, ''))

    # save the workbook
    wb.save(path_transcription)
|
413 |
+
|
414 |
+
|
415 |
+
def has_API_key(self, val):
    """Return True unless ``val`` equals the empty string."""
    return bool(val != '')
|
420 |
+
|
421 |
+
|
422 |
+
def set_API_keys(self):
    """Load PRIVATE_DATA.yaml and publish the configured credentials.

    Reads ``PRIVATE_DATA.yaml`` from the parent of this file's directory,
    stores one ``has_*`` flag per service on ``self`` and, for every service
    whose value is non-empty, copies the credential into ``os.environ``
    (and ``openai.api_key`` for OpenAI). When the Azure flag is set an
    ``AzureChatOpenAI`` client is stored on ``self.llm``.

    Raises:
        FileNotFoundError: a Google credential path is configured, the
            selected LLM is not a Google model, and neither configured JSON
            key file exists on disk.
    """
    self.dir_home = os.path.dirname(os.path.dirname(__file__))
    self.path_cfg_private = os.path.join(self.dir_home, 'PRIVATE_DATA.yaml')
    self.cfg_private = get_cfg_from_full_path(self.path_cfg_private)

    self.has_key_openai = self.has_API_key(self.cfg_private['openai']['OPENAI_API_KEY'])

    # NOTE(review): Azure availability is inferred from 'api_version' rather
    # than from the Azure key itself -- confirm this is intended.
    self.has_key_azure_openai = self.has_API_key(self.cfg_private['openai_azure']['api_version'])

    self.has_key_palm2 = self.has_API_key(self.cfg_private['google_palm']['google_palm_api'])

    self.has_key_google_OCR = self.has_API_key(self.cfg_private['google_cloud']['path_json_file'])

    self.has_key_mistral = self.has_API_key(self.cfg_private['mistral']['mistral_key'])

    self.has_key_here = self.has_API_key(self.cfg_private['here']['api_key'])

    self.has_open_cage_geocode = self.has_API_key(self.cfg_private['open_cage_geocode']['api_key'])

    if self.has_key_openai:
        openai.api_key = self.cfg_private['openai']['OPENAI_API_KEY']
        os.environ["OPENAI_API_KEY"] = self.cfg_private['openai']['OPENAI_API_KEY']

    if self.has_key_azure_openai:
        # os.environ["OPENAI_API_KEY"] = self.cfg_private['openai_azure']['openai_api_key']
        self.llm = AzureChatOpenAI(
            deployment_name = 'gpt-35-turbo',#'gpt-35-turbo',
            openai_api_version = self.cfg_private['openai_azure']['api_version'],
            openai_api_key = self.cfg_private['openai_azure']['openai_api_key'],
            azure_endpoint = self.cfg_private['openai_azure']['openai_api_base'],
            # openai_api_base=self.cfg_private['openai_azure']['openai_api_base'],
            openai_organization = self.cfg_private['openai_azure']['openai_organization'],
            # openai_api_type = self.cfg_private['openai_azure']['openai_api_type']
        )

    # Enable this for all LLMs EXCEPT GOOGLE
    name_check = self.cfg['leafmachine']['LLM_version'].lower().split(' ')
    if 'google' not in name_check and self.has_key_google_OCR:
        google_cloud = self.cfg_private['google_cloud']
        if os.path.exists(google_cloud['path_json_file']):
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = google_cloud['path_json_file']
        elif os.path.exists(google_cloud['path_json_file_service_account2']):
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = google_cloud['path_json_file_service_account2']
        else:
            # Raising a bare string is itself a TypeError; raise a real
            # exception so the actual problem reaches the caller.
            raise FileNotFoundError("Google JSON API key file not found")

    if self.has_key_palm2:
        os.environ['PALM'] = self.cfg_private['google_palm']['google_palm_api']
        os.environ['PALM_PROJECT_ID'] = self.cfg_private['google_palm']['project_id']
        os.environ['PALM_LOCATION'] = self.cfg_private['google_palm']['location']
        os.environ['GOOGLE_API_KEY'] = self.cfg_private['google_palm']['google_palm_api']

        # os.environ["GOOGLE_SERVICE_ACCOUNT"] = self.cfg_private['google_cloud']['path_json_file_service_account']
        # vertexai.init(project='directed-curve-401601', location='us-central1')

        # genai.configure(api_key=os.environ['PALM'])

    # Gate on the Mistral flag: gating on the PaLM flag skipped the export
    # when only a Mistral key was configured, and exported an empty key when
    # only a PaLM key was configured.
    if self.has_key_mistral:
        os.environ['MISTRAL_API_KEY'] = self.cfg_private['mistral']['mistral_key']

    if self.has_key_here:
        os.environ['here_app_id'] = self.cfg_private['here']['app_id']
        os.environ['here_api_key'] = self.cfg_private['here']['api_key']

    if self.has_open_cage_geocode:
        os.environ['open_cage_geocode'] = self.cfg_private['open_cage_geocode']['api_key']
|
490 |
+
|
491 |
+
|
492 |
+
|
493 |
+
# def initialize_embeddings(self):
|
494 |
+
# '''Loading embedding search __init__(self, db_name, path_domain_knowledge, logger, build_new_db=False, model_name="hkunlp/instructor-xl", device="cuda")'''
|
495 |
+
# self.Voucher_Vision_Embedding = VoucherVisionEmbedding(self.db_name, self.path_domain_knowledge, logger=self.logger, build_new_db=self.build_new_db)
|
496 |
+
|
497 |
+
|
498 |
+
def clean_catalog_number(self, data, filename_without_extension):
    """Write the catalog identifier for one specimen into ``data``.

    The target key is "Catalog Number" when ``self.headers_used`` is
    'HEADERS_v1_n22' and "filename" for 'HEADERS_v2_n26' or 'CUSTOM'.
    ``self.prefix_removal`` / ``self.suffix_removal`` are removed from the
    filename; when ``self.catalog_numerical_only`` is set the stored value is
    reduced to digits only.

    Args:
        data: Transcription dict. Updated in place when the key already
            exists; otherwise a new dict with the key first is returned.
        filename_without_extension: Image file stem.

    Returns:
        The dict holding the cleaned identifier.

    Raises:
        ValueError: ``self.headers_used`` is not a recognized layout.
        TypeError: ``data`` is not a dict.
    """

    def modify_catalog_key(catalog_key, filename_without_extension, data):
        # Helper function to apply modifications on catalog number
        if catalog_key not in data:
            # Build a new dict so the identifier becomes the first key.
            data = {catalog_key: None, **data}

        if self.prefix_removal:
            filename_without_extension = filename_without_extension.replace(self.prefix_removal, "")
        if self.suffix_removal:
            filename_without_extension = filename_without_extension.replace(self.suffix_removal, "")
        if self.catalog_numerical_only:
            existing = data[catalog_key]
            # A freshly inserted (or null) value is None and cannot be
            # scanned for digits, so fall back to the cleaned filename.
            # str() also covers numeric values.
            digit_source = filename_without_extension if existing is None else str(existing)
            filename_without_extension = self.remove_non_numbers(digit_source)
        data[catalog_key] = filename_without_extension
        return data

    if not isinstance(data, dict):
        raise TypeError("Data is not of type dict.")

    if self.headers_used == 'HEADERS_v1_n22':
        return modify_catalog_key("Catalog Number", filename_without_extension, data)
    if self.headers_used in ['HEADERS_v2_n26', 'CUSTOM']:
        return modify_catalog_key("filename", filename_without_extension, data)
    raise ValueError("Invalid headers used.")
|
525 |
+
|
526 |
+
|
527 |
+
def write_json_to_file(self, filepath, data):
    """Write ``data`` to ``filepath``, overwriting any existing file.

    A dict is serialized as JSON with a 4-space indent and its key order
    kept; any other value is passed to ``write`` unchanged.
    """
    with open(filepath, 'w') as handle:
        payload = json.dumps(data, indent=4, sort_keys=False) if isinstance(data, dict) else data
        handle.write(payload)
|
533 |
+
|
534 |
+
|
535 |
+
# def create_null_json(self):
|
536 |
+
# return {}
|
537 |
+
|
538 |
+
|
539 |
+
def remove_non_numbers(self, s):
    """Return only the digit characters of ``s``, in their original order."""
    return ''.join(filter(str.isdigit, s))
|
541 |
+
|
542 |
+
|
543 |
+
def create_null_row(self, filename_without_extension, path_to_crop, path_to_content, path_to_helper):
    """Build an otherwise-empty spreadsheet row for one specimen.

    Every header in ``self.headers`` maps to ``''`` except the bookkeeping
    columns that are present: "filename", "path_to_crop", "path_to_content",
    "path_to_helper" and "path_to_original". The last one is derived by
    going four directories up from ``path_to_crop`` and pointing at the file
    of the same name under 'Original_Images'.
    """
    null_row = {column: '' for column in self.headers}

    known_values = {
        "path_to_crop": path_to_crop,
        "path_to_content": path_to_content,
        "path_to_helper": path_to_helper,
        "filename": filename_without_extension,
    }
    for column, value in known_values.items():
        if column in null_row:
            null_row[column] = value

    # Only derive the original-image path when the column exists, so the
    # path helpers are never called otherwise.
    if "path_to_original" in null_row:
        project_root = path_to_crop
        for _ in range(4):
            project_root = os.path.dirname(project_root)
        null_row["path_to_original"] = os.path.join(
            project_root, 'Original_Images', os.path.basename(path_to_crop))

    return null_row
|
572 |
+
|
573 |
+
|
574 |
+
##################################################################################################################################
|
575 |
+
################################################## OCR ##################################################################
|
576 |
+
##################################################################################################################################
|
577 |
+
def perform_OCR_and_save_results(self, image_index, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
    """Run OCR on ``self.path_to_crop`` and write its outputs to disk.

    Stores the OCR result on ``self.OCR`` and always writes the OCR JSON to
    ``txt_file_path_OCR``. The overlay image and the bounding-box JSON are
    written only when the OCR result is non-empty.
    """
    self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Starting OCR')

    ### process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
    ocr_engine = OCRGoogle(self.path_to_crop, self.cfg, self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
    ocr_engine.process_image(self.do_create_OCR_helper_image, self.logger)
    self.OCR = ocr_engine.OCR

    self.write_json_to_file(txt_file_path_OCR, ocr_engine.OCR_JSON_to_file)

    self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Finished OCR')

    if len(self.OCR) == 0:
        # TODO: fix logic for no OCR -- nothing further is saved for this image.
        return

    ocr_engine.overlay_image.save(jpg_file_path_OCR_helper)

    # Keep only the box mappings the OCR engine actually produced.
    box_sources = (
        ('OCR_bounds_handwritten', ocr_engine.hand_text_to_box_mapping),
        ('OCR_bounds_printed', ocr_engine.normal_text_to_box_mapping),
        ('OCR_bounds_trOCR', ocr_engine.trOCR_text_to_box_mapping),
    )
    bounds = {label: mapping for label, mapping in box_sources if mapping is not None}

    self.write_json_to_file(txt_file_path_OCR_bounds, bounds)
    self.logger.info(f'Working on {image_index + 1}/{len(self.img_paths)} --- Saved OCR Overlay Image')
|
607 |
+
|
608 |
+
##################################################################################################################################
|
609 |
+
####################################################### LLM Switchboard ########################################################
|
610 |
+
##################################################################################################################################
|
611 |
+
def send_to_LLM(self, is_azure, progress_report, json_report, model_name):
    """OCR every pending image, then transcribe all prompts in one LLM batch.

    Phase 1 walks ``self.img_paths``, skips specimens that are already
    completed, runs OCR and queues one prompt per image whose OCR produced
    text. Phase 2 sends the queued prompts to the local Mistral/Mixtral
    handler in a single call and saves each result under the specimen it
    belongs to.

    Args:
        is_azure: Forwarded to ``initialize_llm_model``.
        progress_report: Progress sink; may be None.
        json_report: Receives each saved response through ``set_JSON``.
        model_name: Underscore-separated model identifier.

    Returns:
        ``(final_JSON_response, final_WFO_record, final_GEO_record,
        total_tokens_in, total_tokens_out)``. The first three describe the
        last processed specimen and stay None when nothing was processed.

    Raises:
        NotImplementedError: ``model_name`` is not a local Mistral/Mixtral
            model; no batch call exists for other handlers.
    """
    self.n_failed_LLM_calls = 0
    self.n_failed_OCR = 0

    final_JSON_response = None
    final_WFO_record = None
    final_GEO_record = None

    self.initialize_token_counters()
    self.update_progress_report_initial(progress_report)

    MODEL_NAME_FORMATTED = ModelMaps.get_API_name(model_name)
    name_parts = model_name.split("_")

    self.setup_JSON_dict_structure()

    llm_model = self.initialize_llm_model(self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm)

    prompts = []
    # (paths, path_to_crop) of every queued prompt, in prompt order, so each
    # batch result is saved under its own specimen rather than under
    # whichever image the OCR loop visited last.
    queued_specimens = []
    for i, path_to_crop in enumerate(self.img_paths):
        self.update_progress_report_batch(progress_report, i)

        if self.should_skip_specimen(path_to_crop):
            self.log_skipping_specimen(path_to_crop)
            continue

        paths = self.generate_paths(path_to_crop, i)
        self.path_to_crop = path_to_crop

        _, _, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper = paths
        self.perform_OCR_and_save_results(i, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)

        if not self.OCR:
            self.n_failed_OCR += 1
            continue

        # Format prompt
        prompt = self.setup_prompt()
        prompt = remove_colons_and_double_apostrophes(prompt)
        prompts.append(prompt)
        queued_specimens.append((paths, path_to_crop))

    # Process prompts in batch
    if 'LOCAL' in name_parts and ('MISTRAL' in name_parts or 'MIXTRAL' in name_parts):
        # assumes one result per prompt, returned in prompt order -- TODO confirm
        batch_results = llm_model.call_llm_local_MistralAI(prompts)
    else:
        # Previously this path hit an undefined name further down; fail with
        # an explicit message instead.
        raise NotImplementedError(
            f"Batch processing is only implemented for local Mistral/Mixtral models, got '{model_name}'")

    # Process each result from the batch
    for (paths, path_to_crop), result in zip(queued_specimens, batch_results):
        response_candidate, nt_in, nt_out, WFO_record, GEO_record = result.values()

        self.n_failed_LLM_calls += 1 if response_candidate is None else 0

        # Estimate n tokens returned
        self.logger.info(f'Prompt tokens IN --- {nt_in}')
        self.logger.info(f'Prompt tokens OUT --- {nt_out}')

        self.update_token_counters(nt_in, nt_out)

        final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, paths, path_to_crop, nt_in, nt_out)

        self.log_completion_info(final_JSON_response)

        json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)

    self.update_progress_report_final(progress_report)
    final_JSON_response = self.parse_final_json_response(final_JSON_response)
    return final_JSON_response, final_WFO_record, final_GEO_record, self.total_tokens_in, self.total_tokens_out
|
676 |
+
|
677 |
+
|
678 |
+
|
679 |
+
##################################################################################################################################
|
680 |
+
################################################## LLM Helper Funcs ##############################################################
|
681 |
+
##################################################################################################################################
|
682 |
+
def initialize_llm_model(self, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
    """Pick and construct the LLM handler that matches ``name_parts``.

    Local models: Mistral/Mixtral get the CPU handler when 'CPU' is present,
    otherwise the regular local handler; any other local model yields None.
    Hosted models: PALM2, GEMINI and MISTRAL are checked in that order, and
    everything else falls through to the OpenAI handler.
    """
    if 'LOCAL' in name_parts:
        if ('MIXTRAL' not in name_parts) and ('MISTRAL' not in name_parts):
            return None
        local_handler = LocalCPUMistralHandler if 'CPU' in name_parts else LocalMistralHandler
        return local_handler(logger, model_name, JSON_dict_structure)

    # From here on 'LOCAL' is known to be absent from name_parts.
    if 'PALM2' in name_parts:
        return GooglePalm2Handler(logger, model_name, JSON_dict_structure)
    if 'GEMINI' in name_parts:
        return GoogleGeminiHandler(logger, model_name, JSON_dict_structure)
    if 'MISTRAL' in name_parts:
        return MistralHandler(logger, model_name, JSON_dict_structure)
    return OpenAIHandler(logger, model_name, JSON_dict_structure, is_azure, llm_object)
|
698 |
+
|
699 |
+
def setup_prompt(self):
    """Return the prompt built from ``self.path_custom_prompts`` and ``self.OCR``."""
    prompt_text, _ = PromptCatalog().prompt_SLTP(self.path_custom_prompts, OCR=self.OCR)
    return prompt_text
|
703 |
+
|
704 |
+
def setup_JSON_dict_structure(self):
    """Store the prompt catalog's JSON structure on ``self.JSON_dict_structure``.

    The placeholder OCR value 'Text' is passed because only the second
    returned item is kept here, not the prompt.
    """
    _, structure = PromptCatalog().prompt_SLTP(self.path_custom_prompts, OCR='Text')
    self.JSON_dict_structure = structure
|
707 |
+
|
708 |
+
|
709 |
+
def initialize_token_counters(self):
    """Reset the running input and output token totals to zero."""
    self.total_tokens_in = self.total_tokens_out = 0
|
712 |
+
|
713 |
+
|
714 |
+
def update_progress_report_initial(self, progress_report):
    """Tell the progress report how many images there are; no-op when it is None."""
    if progress_report is None:
        return
    progress_report.set_n_batches(len(self.img_paths))
|
717 |
+
|
718 |
+
|
719 |
+
def update_progress_report_batch(self, progress_report, batch_index):
    """Report the 1-based position of the current image; no-op when the report is None."""
    if progress_report is None:
        return
    progress_report.update_batch(f"Working on image {batch_index + 1} of {len(self.img_paths)}")
|
722 |
+
|
723 |
+
|
724 |
+
def should_skip_specimen(self, path_to_crop):
    """Return True when the image's file name is in ``self.completed_specimens``."""
    specimen_name = os.path.basename(path_to_crop)
    return specimen_name in self.completed_specimens
|
726 |
+
|
727 |
+
|
728 |
+
def log_skipping_specimen(self, path_to_crop):
    """Log that the specimen at ``path_to_crop`` is skipped as already processed."""
    specimen_name = os.path.basename(path_to_crop)
    self.logger.info(f'[Skipping] specimen {specimen_name} already processed')
|
730 |
+
|
731 |
+
|
732 |
+
def update_token_counters(self, nt_in, nt_out):
    """Add one call's input and output token counts to the running totals."""
    self.total_tokens_in = self.total_tokens_in + nt_in
    self.total_tokens_out = self.total_tokens_out + nt_out
|
735 |
+
|
736 |
+
|
737 |
+
def update_final_response(self, response_candidate, WFO_record, GEO_record, paths, path_to_crop, nt_in, nt_out):
    """Save one specimen's response and return it with its WFO/GEO records.

    ``paths`` is the 5-tuple produced by ``generate_paths``. The response is
    handed to ``save_json_and_xlsx`` whether or not it is None.

    Returns:
        ``(saved_response, WFO_record, GEO_record)``.
    """
    filename_without_extension, txt_file_path, _ocr_path, _ocr_bounds_path, jpg_file_path_OCR_helper = paths
    saved_response = self.save_json_and_xlsx(
        response_candidate, WFO_record, GEO_record, filename_without_extension,
        path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
    return saved_response, WFO_record, GEO_record
|
746 |
+
|
747 |
+
|
748 |
+
def log_completion_info(self, final_JSON_response):
    """Log the formatted response, then a line marking the end of the API calls."""
    log = self.logger.info
    log(f'Formatted JSON\n{final_JSON_response}')
    log(f'Finished API calls\n')
|
751 |
+
|
752 |
+
|
753 |
+
def update_progress_report_final(self, progress_report):
    """Mark the batch as complete on the progress report; no-op when it is None."""
    if progress_report is None:
        return
    progress_report.reset_batch("Batch Complete")
|
756 |
+
|
757 |
+
|
758 |
+
def parse_final_json_response(self, final_JSON_response):
    """Decode the final response into Python data when it is JSON text.

    A Markdown code fence around the text (backticks with an optional
    leading ``json`` tag) is removed before decoding.

    Args:
        final_JSON_response: A JSON string, an already-parsed object, or None.

    Returns:
        The decoded object when the input is a string holding valid JSON;
        otherwise the input unchanged.
    """
    if not isinstance(final_JSON_response, str):
        # Already-parsed dicts and None have nothing to decode.
        return final_JSON_response

    # Strip whitespace first so a trailing newline after the closing fence
    # does not shield the backticks.
    text = final_JSON_response.strip().strip('`').strip()

    # Remove the fence's language tag only when it leads the payload;
    # deleting the first "json" found anywhere would corrupt keys or values
    # that contain that word.
    if text[:4].lower() == 'json':
        text = text[4:]

    try:
        return json.loads(text)
    except ValueError:
        # json.JSONDecodeError is a ValueError: not JSON, hand back as-is.
        return final_JSON_response
|
763 |
+
|
764 |
+
|
765 |
+
|
766 |
+
def generate_paths(self, path_to_crop, i):
    """Derive the output file paths for one image and log progress.

    Returns:
        ``(stem, transcription_json, ocr_json, ocr_bounds_json,
        ocr_helper_jpg)``, where ``stem`` is the image file name without its
        extension and each path lives in the matching ``self.Dirs`` folder.
    """
    stem = os.path.splitext(os.path.basename(path_to_crop))[0]
    json_name = stem + '.json'

    txt_file_path = os.path.join(self.Dirs.transcription_ind, json_name)
    txt_file_path_OCR = os.path.join(self.Dirs.transcription_ind_OCR, json_name)
    txt_file_path_OCR_bounds = os.path.join(self.Dirs.transcription_ind_OCR_bounds, json_name)
    jpg_file_path_OCR_helper = os.path.join(self.Dirs.transcription_ind_OCR_helper, stem + '.jpg')

    self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- {stem}')

    return stem, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper
|
776 |
+
|
777 |
+
|
778 |
+
def save_json_and_xlsx(self, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
    """Write one specimen's response to its JSON file and to the spreadsheet.

    With a real response, the catalog number is cleaned first and the token
    counts are recorded. With ``response is None``, the JSON file receives
    the empty ``self.JSON_dict_structure`` with 'filename' as first key, and
    the spreadsheet receives a null row with zero token counts.

    Returns:
        The dict that was written to the JSON file.
    """
    if response is not None:
        ### Set completed JSON
        cleaned = self.clean_catalog_number(response, filename_without_extension)
        self.write_json_to_file(txt_file_path, cleaned)
        # add to the xlsx file
        self.add_data_to_excel_from_response(self.path_transcription, cleaned, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
        return cleaned

    # Insert 'filename' as the first key
    placeholder = {'filename': filename_without_extension}
    for key, value in self.JSON_dict_structure.items():
        if key != 'filename':
            placeholder[key] = value
    self.write_json_to_file(txt_file_path, placeholder)

    # Then add the null info to the spreadsheet
    null_row = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
    self.add_data_to_excel_from_response(self.path_transcription, null_row, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
    return placeholder
|
796 |
+
|
797 |
+
|
798 |
+
def process_specimen_batch(self, progress_report, json_report, is_real_run=False):
    """Run the whole batch through ``send_to_LLM`` and return its results.

    Returns:
        ``(final_json_response, final_WFO_record, final_GEO_record,
        total_tokens_in, total_tokens_out)``.

    Raises:
        Exception: ``self.has_key`` is falsy, or anything raised while
            transcribing. In the second case the failure is logged, the
            progress report (if any) is reset and the logger handlers are
            closed before the exception is re-raised.
    """
    if not self.has_key:
        self.logger.error(f'No API key found for {self.version_name}')
        raise Exception(f"No API key found for {self.version_name}")

    try:
        if is_real_run:
            progress_report.update_overall(f"Transcribing Labels")

        response, wfo_record, geo_record, tokens_in, tokens_out = self.send_to_LLM(
            self.is_azure, progress_report, json_report, self.model_name)

        return response, wfo_record, geo_record, tokens_in, tokens_out

    except Exception as e:
        self.logger.error(f"LLM call failed in process_specimen_batch: {e}")
        if progress_report is not None:
            progress_report.reset_batch(f"Batch Failed")
        self.close_logger_handlers()
        raise
|
817 |
+
|
818 |
+
|
819 |
+
def close_logger_handlers(self):
    """Close every handler attached to ``self.logger`` and detach it."""
    # Iterate over a snapshot because removeHandler mutates the live list.
    for attached in list(self.logger.handlers):
        attached.close()
        self.logger.removeHandler(attached)
|
823 |
+
|
824 |
+
|
825 |
+
def process_specimen_batch_OCR_test(self, path_to_crop):
    """Run ``detect_text`` on every entry of the directory ``path_to_crop``.

    Each result overwrites ``self.OCR``, ``self.bounds`` and
    ``self.text_to_box_mapping``, so only the last entry's output remains.
    """
    for entry_name in os.listdir(path_to_crop):
        ocr_result = detect_text(os.path.join(path_to_crop, entry_name))
        self.OCR, self.bounds, self.text_to_box_mapping = ocr_result
|
829 |
+
|
830 |
+
|
831 |
+
|
832 |
+
def space_saver(cfg, Dirs, logger):
    """Delete temporary output folders of a finished run, as configured.

    The run folder is ``<dir_output>/<Dirs.run_name>``.

    * ``delete_temps_keep_VVE``: removes 'Archival_Components' and
      'Config_File' only.
    * otherwise ``delete_all_temps``: also removes 'Original_Images' and
      'Cropped_Images', plus every sub-directory of 'Transcription'; files
      directly inside 'Transcription' (the xlsx) are kept.

    Nothing is deleted when neither flag is set.
    """
    project_cfg = cfg['leafmachine']['project']
    path_project = os.path.join(project_cfg['dir_output'], Dirs.run_name)

    def _remove_subdirs(folder_names):
        # Remove each named folder of the run directory if it exists.
        for folder_name in folder_names:
            target = os.path.join(path_project, folder_name)
            if os.path.exists(target):
                shutil.rmtree(target)

    if project_cfg['delete_temps_keep_VVE']:
        logger.name = '[DELETE TEMP FILES]'
        logger.info("Deleting temporary files. Keeping files required for VoucherVisionEditor.")
        _remove_subdirs(['Archival_Components', 'Config_File'])
        return

    if not project_cfg['delete_all_temps']:
        return

    logger.name = '[DELETE TEMP FILES]'
    logger.info("Deleting ALL temporary files!")
    _remove_subdirs(['Archival_Components', 'Config_File', 'Original_Images', 'Cropped_Images'])

    # Delete the contents of the transcription folder, but keep the xlsx
    transcription_path = os.path.join(path_project, 'Transcription')
    if not os.path.exists(transcription_path):
        return
    for entry in os.listdir(transcription_path):
        entry_path = os.path.join(transcription_path, entry)
        if os.path.isdir(entry_path):  # only directories are removed
            shutil.rmtree(entry_path)
|
vouchervision/utils_geolocate_HERE.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, requests
|
2 |
+
import pycountry_convert as pc
|
3 |
+
import unicodedata
|
4 |
+
import pycountry_convert as pc
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_country_name(name):
    """Return ``name`` reduced to ASCII.

    The text is NFKD-decomposed, then every character that cannot be encoded
    as ASCII (such as the combining accents produced by decomposition) is
    dropped.
    """
    decomposed = unicodedata.normalize('NFKD', name)
    ascii_bytes = decomposed.encode('ASCII', 'ignore')
    return ascii_bytes.decode('ASCII')
|
10 |
+
|
11 |
+
def get_continent(country_name):
    """Return the continent name for ``country_name``, or '' when unresolved.

    The name is ASCII-normalized, mapped to an alpha-2 country code and then
    to a continent code through ``pycountry_convert``. Any failure along the
    way is printed and reported as an empty string.
    """
    warnings.filterwarnings("ignore", category=UserWarning, module='pycountry')

    continent_names = dict(
        AF="Africa",
        NA="North America",
        OC="Oceania",
        AN="Antarctica",
        AS="Asia",
        EU="Europe",
        SA="South America",
    )

    try:
        ascii_name = normalize_country_name(country_name)
        # country name -> alpha-2 code -> continent code -> continent name
        alpha2 = pc.country_name_to_country_alpha2(ascii_name)
        code = pc.country_alpha2_to_continent_code(alpha2)
        return continent_names.get(code, '')
    except Exception as e:
        print(str(e))
        return ''
|
35 |
+
|
36 |
+
def validate_coordinates_here(record, replace_if_success_geo=False):
|
37 |
+
forward_url = 'https://geocode.search.hereapi.com/v1/geocode'
|
38 |
+
reverse_url = 'https://revgeocode.search.hereapi.com/v1/revgeocode'
|
39 |
+
|
40 |
+
pinpoint = ['GEO_city','GEO_county','GEO_state','GEO_country',]
|
41 |
+
GEO_dict_null = {
|
42 |
+
'GEO_override_OCR': False,
|
43 |
+
'GEO_method': '',
|
44 |
+
'GEO_formatted_full_string': '',
|
45 |
+
'GEO_decimal_lat': '',
|
46 |
+
'GEO_decimal_long': '',
|
47 |
+
'GEO_city': '',
|
48 |
+
'GEO_county': '',
|
49 |
+
'GEO_state': '',
|
50 |
+
'GEO_state_code': '',
|
51 |
+
'GEO_country': '',
|
52 |
+
'GEO_country_code': '',
|
53 |
+
'GEO_continent': '',
|
54 |
+
}
|
55 |
+
GEO_dict = {
|
56 |
+
'GEO_override_OCR': False,
|
57 |
+
'GEO_method': '',
|
58 |
+
'GEO_formatted_full_string': '',
|
59 |
+
'GEO_decimal_lat': '',
|
60 |
+
'GEO_decimal_long': '',
|
61 |
+
'GEO_city': '',
|
62 |
+
'GEO_county': '',
|
63 |
+
'GEO_state': '',
|
64 |
+
'GEO_state_code': '',
|
65 |
+
'GEO_country': '',
|
66 |
+
'GEO_country_code': '',
|
67 |
+
'GEO_continent': '',
|
68 |
+
}
|
69 |
+
GEO_dict_rev = {
|
70 |
+
'GEO_override_OCR': False,
|
71 |
+
'GEO_method': '',
|
72 |
+
'GEO_formatted_full_string': '',
|
73 |
+
'GEO_decimal_lat': '',
|
74 |
+
'GEO_decimal_long': '',
|
75 |
+
'GEO_city': '',
|
76 |
+
'GEO_county': '',
|
77 |
+
'GEO_state': '',
|
78 |
+
'GEO_state_code': '',
|
79 |
+
'GEO_country': '',
|
80 |
+
'GEO_country_code': '',
|
81 |
+
'GEO_continent': '',
|
82 |
+
}
|
83 |
+
GEO_dict_rev_verbatim = {
|
84 |
+
'GEO_override_OCR': False,
|
85 |
+
'GEO_method': '',
|
86 |
+
'GEO_formatted_full_string': '',
|
87 |
+
'GEO_decimal_lat': '',
|
88 |
+
'GEO_decimal_long': '',
|
89 |
+
'GEO_city': '',
|
90 |
+
'GEO_county': '',
|
91 |
+
'GEO_state': '',
|
92 |
+
'GEO_state_code': '',
|
93 |
+
'GEO_country': '',
|
94 |
+
'GEO_country_code': '',
|
95 |
+
'GEO_continent': '',
|
96 |
+
}
|
97 |
+
GEO_dict_forward = {
|
98 |
+
'GEO_override_OCR': False,
|
99 |
+
'GEO_method': '',
|
100 |
+
'GEO_formatted_full_string': '',
|
101 |
+
'GEO_decimal_lat': '',
|
102 |
+
'GEO_decimal_long': '',
|
103 |
+
'GEO_city': '',
|
104 |
+
'GEO_county': '',
|
105 |
+
'GEO_state': '',
|
106 |
+
'GEO_state_code': '',
|
107 |
+
'GEO_country': '',
|
108 |
+
'GEO_country_code': '',
|
109 |
+
'GEO_continent': '',
|
110 |
+
}
|
111 |
+
GEO_dict_forward_locality = {
|
112 |
+
'GEO_override_OCR': False,
|
113 |
+
'GEO_method': '',
|
114 |
+
'GEO_formatted_full_string': '',
|
115 |
+
'GEO_decimal_lat': '',
|
116 |
+
'GEO_decimal_long': '',
|
117 |
+
'GEO_city': '',
|
118 |
+
'GEO_county': '',
|
119 |
+
'GEO_state': '',
|
120 |
+
'GEO_state_code': '',
|
121 |
+
'GEO_country': '',
|
122 |
+
'GEO_country_code': '',
|
123 |
+
'GEO_continent': '',
|
124 |
+
}
|
125 |
+
|
126 |
+
|
127 |
+
# For production
|
128 |
+
query_forward = ', '.join(filter(None, [record.get('municipality', '').strip(),
|
129 |
+
record.get('county', '').strip(),
|
130 |
+
record.get('stateProvince', '').strip(),
|
131 |
+
record.get('country', '').strip()])).strip()
|
132 |
+
query_forward_locality = ', '.join(filter(None, [record.get('locality', '').strip(),
|
133 |
+
record.get('municipality', '').strip(),
|
134 |
+
record.get('county', '').strip(),
|
135 |
+
record.get('stateProvince', '').strip(),
|
136 |
+
record.get('country', '').strip()])).strip()
|
137 |
+
query_reverse = ','.join(filter(None, [record.get('decimalLatitude', '').strip(),
|
138 |
+
record.get('decimalLongitude', '').strip()])).strip()
|
139 |
+
query_reverse_verbatim = record.get('verbatimCoordinates', '').strip()
|
140 |
+
|
141 |
+
|
142 |
+
'''
|
143 |
+
#For testing
|
144 |
+
# query_forward = 'Ann bor, michign'
|
145 |
+
query_forward = 'michigan'
|
146 |
+
query_forward_locality = 'Ann bor, michign'
|
147 |
+
# query_gps = "42 N,-83 W" # cannot have any spaces
|
148 |
+
# query_reverse_verbatim = "42.278366,-83.744718" # cannot have any spaces
|
149 |
+
query_reverse_verbatim = "42,-83" # cannot have any spaces
|
150 |
+
query_reverse = "42,-83" # cannot have any spaces
|
151 |
+
# params = {
|
152 |
+
# 'q': query_loc,
|
153 |
+
# 'apiKey': os.environ['here_api_key'],
|
154 |
+
# }'''
|
155 |
+
|
156 |
+
|
157 |
+
params_rev = {
|
158 |
+
'at': query_reverse,
|
159 |
+
'apiKey': os.environ['here_api_key'],
|
160 |
+
'lang': 'en',
|
161 |
+
}
|
162 |
+
params_reverse_verbatim = {
|
163 |
+
'at': query_reverse_verbatim,
|
164 |
+
'apiKey': os.environ['here_api_key'],
|
165 |
+
'lang': 'en',
|
166 |
+
}
|
167 |
+
params_forward = {
|
168 |
+
'q': query_forward,
|
169 |
+
'apiKey': os.environ['here_api_key'],
|
170 |
+
'lang': 'en',
|
171 |
+
}
|
172 |
+
params_forward_locality = {
|
173 |
+
'q': query_forward_locality,
|
174 |
+
'apiKey': os.environ['here_api_key'],
|
175 |
+
'lang': 'en',
|
176 |
+
}
|
177 |
+
|
178 |
+
### REVERSE
|
179 |
+
# If there are two string in the coordinates, try a reverse first based on the literal coordinates
|
180 |
+
response = requests.get(reverse_url, params=params_rev)
|
181 |
+
if response.status_code == 200:
|
182 |
+
data = response.json()
|
183 |
+
if data.get('items'):
|
184 |
+
first_result = data['items'][0]
|
185 |
+
GEO_dict_rev['GEO_method'] = 'HERE_Geocode_reverse'
|
186 |
+
GEO_dict_rev['GEO_formatted_full_string'] = first_result.get('title', '')
|
187 |
+
GEO_dict_rev['GEO_decimal_lat'] = first_result['position']['lat']
|
188 |
+
GEO_dict_rev['GEO_decimal_long'] = first_result['position']['lng']
|
189 |
+
|
190 |
+
address = first_result.get('address', {})
|
191 |
+
GEO_dict_rev['GEO_city'] = address.get('city', '')
|
192 |
+
GEO_dict_rev['GEO_county'] = address.get('county', '')
|
193 |
+
GEO_dict_rev['GEO_state'] = address.get('state', '')
|
194 |
+
GEO_dict_rev['GEO_state_code'] = address.get('stateCode', '')
|
195 |
+
GEO_dict_rev['GEO_country'] = address.get('countryName', '')
|
196 |
+
GEO_dict_rev['GEO_country_code'] = address.get('countryCode', '')
|
197 |
+
GEO_dict_rev['GEO_continent'] = get_continent(address.get('countryName', ''))
|
198 |
+
|
199 |
+
### REVERSE Verbatim
|
200 |
+
# If there are two string in the coordinates, try a reverse first based on the literal coordinates
|
201 |
+
if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
|
202 |
+
GEO_dict = GEO_dict_rev
|
203 |
+
else:
|
204 |
+
response = requests.get(reverse_url, params=params_reverse_verbatim)
|
205 |
+
if response.status_code == 200:
|
206 |
+
data = response.json()
|
207 |
+
if data.get('items'):
|
208 |
+
first_result = data['items'][0]
|
209 |
+
GEO_dict_rev_verbatim['GEO_method'] = 'HERE_Geocode_reverse_verbatimCoordinates'
|
210 |
+
GEO_dict_rev_verbatim['GEO_formatted_full_string'] = first_result.get('title', '')
|
211 |
+
GEO_dict_rev_verbatim['GEO_decimal_lat'] = first_result['position']['lat']
|
212 |
+
GEO_dict_rev_verbatim['GEO_decimal_long'] = first_result['position']['lng']
|
213 |
+
|
214 |
+
address = first_result.get('address', {})
|
215 |
+
GEO_dict_rev_verbatim['GEO_city'] = address.get('city', '')
|
216 |
+
GEO_dict_rev_verbatim['GEO_county'] = address.get('county', '')
|
217 |
+
GEO_dict_rev_verbatim['GEO_state'] = address.get('state', '')
|
218 |
+
GEO_dict_rev_verbatim['GEO_state_code'] = address.get('stateCode', '')
|
219 |
+
GEO_dict_rev_verbatim['GEO_country'] = address.get('countryName', '')
|
220 |
+
GEO_dict_rev_verbatim['GEO_country_code'] = address.get('countryCode', '')
|
221 |
+
GEO_dict_rev_verbatim['GEO_continent'] = get_continent(address.get('countryName', ''))
|
222 |
+
|
223 |
+
### FORWARD
|
224 |
+
### Try forward, if failes, try reverse using deci, then verbatim
|
225 |
+
if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
|
226 |
+
GEO_dict = GEO_dict_rev
|
227 |
+
elif GEO_dict_rev_verbatim['GEO_city']:
|
228 |
+
GEO_dict = GEO_dict_rev_verbatim
|
229 |
+
else:
|
230 |
+
response = requests.get(forward_url, params=params_forward)
|
231 |
+
if response.status_code == 200:
|
232 |
+
data = response.json()
|
233 |
+
if data.get('items'):
|
234 |
+
first_result = data['items'][0]
|
235 |
+
GEO_dict_forward['GEO_method'] = 'HERE_Geocode_forward'
|
236 |
+
GEO_dict_forward['GEO_formatted_full_string'] = first_result.get('title', '')
|
237 |
+
GEO_dict_forward['GEO_decimal_lat'] = first_result['position']['lat']
|
238 |
+
GEO_dict_forward['GEO_decimal_long'] = first_result['position']['lng']
|
239 |
+
|
240 |
+
address = first_result.get('address', {})
|
241 |
+
GEO_dict_forward['GEO_city'] = address.get('city', '')
|
242 |
+
GEO_dict_forward['GEO_county'] = address.get('county', '')
|
243 |
+
GEO_dict_forward['GEO_state'] = address.get('state', '')
|
244 |
+
GEO_dict_forward['GEO_state_code'] = address.get('stateCode', '')
|
245 |
+
GEO_dict_forward['GEO_country'] = address.get('countryName', '')
|
246 |
+
GEO_dict_forward['GEO_country_code'] = address.get('countryCode', '')
|
247 |
+
GEO_dict_forward['GEO_continent'] = get_continent(address.get('countryName', ''))
|
248 |
+
|
249 |
+
### FORWARD locality
|
250 |
+
### Try forward, if failes, try reverse using deci, then verbatim
|
251 |
+
if GEO_dict_rev['GEO_city']: # If the reverse was successful, pass
|
252 |
+
GEO_dict = GEO_dict_rev
|
253 |
+
elif GEO_dict_rev_verbatim['GEO_city']:
|
254 |
+
GEO_dict = GEO_dict_rev_verbatim
|
255 |
+
elif GEO_dict_forward['GEO_city']:
|
256 |
+
GEO_dict = GEO_dict_forward
|
257 |
+
else:
|
258 |
+
response = requests.get(forward_url, params=params_forward_locality)
|
259 |
+
if response.status_code == 200:
|
260 |
+
data = response.json()
|
261 |
+
if data.get('items'):
|
262 |
+
first_result = data['items'][0]
|
263 |
+
GEO_dict_forward_locality['GEO_method'] = 'HERE_Geocode_forward_locality'
|
264 |
+
GEO_dict_forward_locality['GEO_formatted_full_string'] = first_result.get('title', '')
|
265 |
+
GEO_dict_forward_locality['GEO_decimal_lat'] = first_result['position']['lat']
|
266 |
+
GEO_dict_forward_locality['GEO_decimal_long'] = first_result['position']['lng']
|
267 |
+
|
268 |
+
address = first_result.get('address', {})
|
269 |
+
GEO_dict_forward_locality['GEO_city'] = address.get('city', '')
|
270 |
+
GEO_dict_forward_locality['GEO_county'] = address.get('county', '')
|
271 |
+
GEO_dict_forward_locality['GEO_state'] = address.get('state', '')
|
272 |
+
GEO_dict_forward_locality['GEO_state_code'] = address.get('stateCode', '')
|
273 |
+
GEO_dict_forward_locality['GEO_country'] = address.get('countryName', '')
|
274 |
+
GEO_dict_forward_locality['GEO_country_code'] = address.get('countryCode', '')
|
275 |
+
GEO_dict_forward_locality['GEO_continent'] = get_continent(address.get('countryName', ''))
|
276 |
+
|
277 |
+
|
278 |
+
# print(json.dumps(GEO_dict,indent=4))
|
279 |
+
|
280 |
+
|
281 |
+
# Pick the most detailed version
|
282 |
+
# if GEO_dict_rev['GEO_formatted_full_string'] and GEO_dict_forward['GEO_formatted_full_string']:
|
283 |
+
for loc in pinpoint:
|
284 |
+
rev = GEO_dict_rev.get(loc,'')
|
285 |
+
forward = GEO_dict_forward.get(loc,'')
|
286 |
+
forward_locality = GEO_dict_forward_locality.get(loc,'')
|
287 |
+
rev_verbatim = GEO_dict_rev_verbatim.get(loc,'')
|
288 |
+
|
289 |
+
if not rev and not forward and not forward_locality and not rev_verbatim:
|
290 |
+
pass
|
291 |
+
elif rev:
|
292 |
+
GEO_dict = GEO_dict_rev
|
293 |
+
break
|
294 |
+
elif forward:
|
295 |
+
GEO_dict = GEO_dict_forward
|
296 |
+
break
|
297 |
+
elif forward_locality:
|
298 |
+
GEO_dict = GEO_dict_forward_locality
|
299 |
+
break
|
300 |
+
elif rev_verbatim:
|
301 |
+
GEO_dict = GEO_dict_rev_verbatim
|
302 |
+
break
|
303 |
+
else:
|
304 |
+
GEO_dict = GEO_dict_null
|
305 |
+
|
306 |
+
|
307 |
+
if GEO_dict['GEO_formatted_full_string'] and replace_if_success_geo:
|
308 |
+
GEO_dict['GEO_override_OCR'] = True
|
309 |
+
record['country'] = GEO_dict.get('GEO_country')
|
310 |
+
record['stateProvince'] = GEO_dict.get('GEO_state')
|
311 |
+
record['county'] = GEO_dict.get('GEO_county')
|
312 |
+
record['municipality'] = GEO_dict.get('GEO_city')
|
313 |
+
|
314 |
+
# print(json.dumps(GEO_dict,indent=4))
|
315 |
+
return record, GEO_dict
|
316 |
+
|
317 |
+
|
318 |
+
if __name__ == "__main__":
|
319 |
+
validate_coordinates_here(None)
|
vouchervision/utils_geolocate_OpenCage.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from opencage.geocoder import OpenCageGeocode
|
3 |
+
import pycountry_convert as pc
|
4 |
+
import warnings
|
5 |
+
import unicodedata
|
6 |
+
import pycountry_convert as pc
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
|
10 |
+
### TODO 1/24/24
|
11 |
+
### If I want to use this instead of HERE, update the procedure for picking the best/most granular geolocation
|
12 |
+
|
13 |
+
|
14 |
+
def normalize_country_name(name):
    """Return *name* reduced to plain ASCII.

    The text is NFKD-decomposed so accented letters split into a base
    letter plus combining marks; anything that cannot be encoded as ASCII
    (the marks, and any other non-ASCII character) is then dropped.
    """
    decomposed = unicodedata.normalize('NFKD', name)
    ascii_bytes = decomposed.encode('ASCII', 'ignore')
    return ascii_bytes.decode('ASCII')
|
16 |
+
|
17 |
+
def get_continent(country_name):
    """Return the English continent name for *country_name*.

    The country name is ASCII-normalised, converted to an ISO alpha-2 code
    and then to a continent code via ``pycountry_convert``.

    Args:
        country_name: Country name as text. May be empty or ``None``.

    Returns:
        The continent name (e.g. ``"North America"``), or ``''`` when the
        name is empty or cannot be resolved. This function never raises.
    """
    # An empty/missing name can never resolve; return early instead of
    # letting the lookup fail and print an exception message for it.
    if not country_name:
        return ''

    warnings.filterwarnings("ignore", category=UserWarning, module='pycountry')

    continent_code_to_name = {
        "AF": "Africa",
        "NA": "North America",
        "OC": "Oceania",
        "AN": "Antarctica",
        "AS": "Asia",
        "EU": "Europe",
        "SA": "South America",
    }

    try:
        normalized_country_name = normalize_country_name(country_name)
        # Country name -> ISO alpha-2 code -> continent code
        country_code = pc.country_name_to_country_alpha2(normalized_country_name)
        continent_code = pc.country_alpha2_to_continent_code(country_code)
        return continent_code_to_name.get(continent_code, '')
    except Exception as e:
        # Best-effort lookup: an unknown country must not abort the caller,
        # so report the problem and fall back to "no continent".
        print(str(e))
        return ''
|
41 |
+
|
42 |
+
def validate_coordinates_opencage(record, replace_if_success_geo=False):
    """Forward-geocode a record's locality fields with OpenCage.

    The query is built from the record's ``municipality``, ``county``,
    ``stateProvince`` and ``country`` values. Only forward geocoding is
    performed; the record's coordinate fields are not used.

    Args:
        record: Dict of transcribed label fields. Updated in place when
            ``replace_if_success_geo`` is true and geocoding succeeds.
        replace_if_success_geo: When true, overwrite the record's
            ``country``, ``stateProvince``, ``county`` and ``municipality``
            with the geocoder's values.

    Returns:
        ``(record, GEO_dict)``. ``GEO_dict`` always has the same keys;
        values stay ``''`` (``False`` for ``GEO_override_OCR``) when there
        is nothing to geocode or the geocoder returns no result.

    Raises:
        KeyError: If the ``open_cage_geocode`` environment variable is not
            set and a query has to be sent.
    """
    GEO_dict = {
        # Present from the start so the result schema is the same whether
        # or not the record ends up being overridden.
        'GEO_override_OCR': False,
        'GEO_method': '',
        'GEO_formatted_full_string': '',
        'GEO_decimal_lat': '',
        'GEO_decimal_long': '',
        'GEO_city': '',
        'GEO_county': '',
        'GEO_state': '',
        'GEO_state_code': '',
        'GEO_country': '',
        'GEO_country_code': '',
        'GEO_continent': '',
    }

    def _field(key):
        # Record values may be missing, None or non-string; normalise to text.
        value = record.get(key)
        return '' if value is None else str(value).strip()

    query_loc = ', '.join(filter(None, (
        _field(key) for key in ('municipality', 'county', 'stateProvince', 'country')
    )))

    # Nothing to look up: skip the API call (and the API-key lookup) entirely.
    if not query_loc:
        return record, GEO_dict

    geocoder = OpenCageGeocode(os.environ['open_cage_geocode'])
    results = geocoder.geocode(query_loc, no_annotations='1')

    if results:
        best = results[0]
        geometry = best.get('geometry', {})
        components = best.get('components', {})

        GEO_dict['GEO_method'] = 'OpenCageGeocode_forward'
        GEO_dict['GEO_formatted_full_string'] = best.get('formatted', '')
        GEO_dict['GEO_decimal_lat'] = geometry.get('lat', '')
        GEO_dict['GEO_decimal_long'] = geometry.get('lng', '')

        # Which component keys are present varies per result, so every
        # lookup is optional; smaller settlements are reported as
        # town/village rather than city.
        GEO_dict['GEO_city'] = (components.get('city')
                                or components.get('town')
                                or components.get('village')
                                or '')
        GEO_dict['GEO_county'] = components.get('county', '')
        GEO_dict['GEO_state'] = components.get('state', '')
        GEO_dict['GEO_state_code'] = components.get('state_code', '')
        GEO_dict['GEO_country'] = components.get('country', '')
        GEO_dict['GEO_country_code'] = components.get('country_code', '')
        GEO_dict['GEO_continent'] = components.get('continent', '')

    if GEO_dict['GEO_formatted_full_string'] and replace_if_success_geo:
        GEO_dict['GEO_override_OCR'] = True
        record['country'] = GEO_dict.get('GEO_country')
        record['stateProvince'] = GEO_dict.get('GEO_state')
        record['county'] = GEO_dict.get('GEO_county')
        record['municipality'] = GEO_dict.get('GEO_city')

    return record, GEO_dict
|
94 |
+
|
95 |
+
|
vouchervision/{utils.py → utils_hf.py}
RENAMED
File without changes
|
vouchervision/utils_taxonomy_WFO.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from urllib.parse import urlencode
|
3 |
+
from Levenshtein import ratio
|
4 |
+
from fuzzywuzzy import fuzz
|
5 |
+
|
6 |
+
class WFONameMatcher:
    """Validate a record's taxonomy against the World Flora Online (WFO)
    name-matching REST service.

    Two query strings are built from a record: a *primary* one from
    ``scientificName`` + ``scientificNameAuthorship`` and a *secondary* one
    from genus / subgenus / specific epithet / infraspecific epithet. An
    exact match on either wins; otherwise the candidate lists of both are
    merged and ranked by string similarity to the query.
    """

    def __init__(self):
        self.base_url = "https://list.worldfloraonline.org/matching_rest.php?"
        self.N_BEST_CANDIDATES = 10  # max number of candidate names kept
        self.REQUEST_TIMEOUT = 30  # seconds; keeps a stalled request from hanging the run
        # Result returned by callers when validation fails outright.
        self.NULL_DICT = {
            "WFO_exact_match": False,
            "WFO_exact_match_name": "",
            "WFO_candidate_names": "",
            "WFO_best_match": "",
            "WFO_placement": "",
            "WFO_override_OCR": False,
        }
        self.SEP = '|'  # separator used when re-joining the placement path

    @staticmethod
    def _clean(record, key):
        """Return ``record[key]`` as stripped text; ``''`` if missing or None."""
        value = record.get(key)
        return '' if value is None else str(value).strip()

    @staticmethod
    def _merge_ranked_candidates(primary_ranked, secondary_ranked):
        """Merge two ``[(name, score), ...]`` lists into one ranked list.

        Each name appears once, with its highest score. The result is
        sorted by score, then name, both descending.
        """
        best_scores = {}
        for name, score in list(primary_ranked) + list(secondary_ranked):
            if name not in best_scores or score > best_scores[name]:
                best_scores[name] = score
        return sorted(best_scores.items(), key=lambda item: (item[1], item[0]), reverse=True)

    def extract_input_string(self, record):
        """Build the primary and secondary WFO query strings from *record*.

        Returns:
            ``(primary_input, secondary_input)``; either may be ``''``.
        """
        primary_input = (
            f"{self._clean(record, 'scientificName')} "
            f"{self._clean(record, 'scientificNameAuthorship')}"
        ).strip()
        secondary_input = ' '.join(filter(None, (
            self._clean(record, key)
            for key in ('genus', 'subgenus', 'specificEpithet', 'infraspecificEpithet')
        )))

        return primary_input, secondary_input

    def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
        """Query the WFO matching endpoint for *input_string*.

        Returns:
            The decoded JSON object on success. On an empty query, a
            network failure, a non-200 status or an unusable body, a dict
            of the form ``{"error": True, "message": ...}`` is returned
            instead; this method does not raise for those cases.
        """
        # An empty name cannot match anything; do not spend a request on it.
        if not input_string or not str(input_string).strip():
            return {"error": True, "message": "Empty input string; WFO API not queried"}

        params = {
            "input_string": input_string,
            "check_homonyms": check_homonyms,
            "check_rank": check_rank,
            "method": "full",
            "accept_single_candidate": accept_single_candidate,
        }
        full_url = self.base_url + urlencode(params)
        failure = {"error": True, "message": "Failed to fetch data from WFO API"}

        try:
            response = requests.get(full_url, timeout=self.REQUEST_TIMEOUT)
            if response.status_code != 200:
                return failure
            data = response.json()
        except (requests.RequestException, ValueError):
            # Connection problem, timeout, or a body that is not valid JSON.
            return failure

        # Callers use dict access on the result, so anything else is a failure.
        return data if isinstance(data, dict) else failure

    def _lookup_placement(self, name):
        """Return the WFO placement path for *name*, or ``''`` if unavailable."""
        if not name:
            return ''
        match = self.query_wfo_name_matching(name).get("match")
        if not isinstance(match, dict):
            return ''
        return match.get("placement", '') or ''

    def query_and_process(self, record):
        """Run the primary and, if needed, secondary query for *record*.

        Returns:
            A dict with ``WFO_exact_match``, ``WFO_candidate_names``,
            ``WFO_best_match`` and ``WFO_placement``.
        """
        primary_input, secondary_input = self.extract_input_string(record)

        # Query with primary input
        primary_result = self.query_wfo_name_matching(primary_input)
        primary_processed, primary_ranked_candidates = self.process_wfo_response(primary_result, primary_input)

        if primary_processed.get('WFO_exact_match'):
            print("Selected Primary --- Exact Primary & Unchecked Secondary")
            return primary_processed

        # Query with secondary input
        secondary_result = self.query_wfo_name_matching(secondary_input)
        secondary_processed, secondary_ranked_candidates = self.process_wfo_response(secondary_result, secondary_input)

        if secondary_processed.get('WFO_exact_match'):
            print("Selected Secondary --- Unchecked Primary & Exact Secondary")
            return secondary_processed

        has_primary = bool(primary_processed.get("WFO_candidate_names"))
        has_secondary = bool(secondary_processed.get("WFO_candidate_names"))

        # Both failed, just return the first failure
        if not has_primary and not has_secondary:
            print("Selected Primary --- Failed Primary & Failed Secondary")
            return primary_processed

        # 1st failed, just return the second
        if not has_primary:
            print("Selected Secondary --- Failed Primary & Partial Secondary")
            return secondary_processed

        # 2nd failed, just return the first
        if not has_secondary:
            print("Selected Primary --- Partial Primary & Failed Secondary")
            return primary_processed

        # Both have partial matches: merge (one entry per name) and rerank.
        combined_candidates = self._merge_ranked_candidates(
            primary_ranked_candidates or [], secondary_ranked_candidates or [])
        top_candidates = combined_candidates[:self.N_BEST_CANDIDATES]
        cleaned_candidates = [name for name, _score in top_candidates]

        # The ranked lists are sorted best-first, so element 0 carries the
        # best similarity score of each query.
        best_score_primary = primary_ranked_candidates[0][1]
        best_score_secondary = secondary_ranked_candidates[0][1]

        if best_score_primary >= best_score_secondary:
            selected, label = primary_processed, "Primary"
        else:
            selected, label = secondary_processed, "Secondary"

        selected["WFO_candidate_names"] = cleaned_candidates
        selected["WFO_best_match"] = cleaned_candidates[0]
        selected["WFO_placement"] = self._lookup_placement(cleaned_candidates[0])

        print(f"Selected {label} --- Partial Primary & Partial Secondary")
        return selected

    def process_wfo_response(self, response, query):
        """Reduce a raw WFO response to the fields this class reports.

        Returns:
            ``(simplified_response, ranked_candidates)``.
            ``ranked_candidates`` is a best-first list of ``(name, score)``
            tuples when there were candidates but no exact match,
            otherwise ``None``.
        """
        ranked_candidates = None

        exact_match = response.get("match")
        simplified_response = {"WFO_exact_match": bool(exact_match)}

        candidates = response.get("candidates") or []
        candidate_names = [candidate["full_name_plain"] for candidate in candidates
                           if candidate.get("full_name_plain")]

        placement = ''
        if exact_match:
            matched_name = exact_match.get("full_name_plain")
            simplified_response["WFO_candidate_names"] = matched_name
            simplified_response["WFO_best_match"] = matched_name
            # Use the placement on the match itself when present; otherwise
            # fall back to querying WFO again by name.
            placement = exact_match.get("placement") or self._lookup_placement(matched_name)
        elif candidate_names:
            cleaned_candidates, ranked_candidates = self._rank_candidates_by_similarity(query, candidate_names)
            simplified_response["WFO_candidate_names"] = cleaned_candidates
            simplified_response["WFO_best_match"] = cleaned_candidates[0] if cleaned_candidates else ''
            # Call WFO again to get the placement of the best candidate.
            placement = self._lookup_placement(simplified_response["WFO_best_match"])
        else:
            simplified_response["WFO_candidate_names"] = ''
            simplified_response["WFO_best_match"] = ''

        simplified_response["WFO_placement"] = placement

        return simplified_response, ranked_candidates

    def _rank_candidates_by_similarity(self, query, candidates):
        """Rank *candidates* by similarity to *query*.

        Returns:
            ``(cleaned_candidates, ranked_candidates)``: the top
            ``N_BEST_CANDIDATES`` names, and the full best-first list of
            ``(name, score)`` tuples.
        """
        query_words = query.split()
        string_similarities = []

        for candidate in candidates:
            # Word-by-word Levenshtein ratio; zip() stops at the shorter of
            # the two word lists, so surplus words are not compared.
            word_similarities = [ratio(query_word, candidate_word)
                                 for query_word, candidate_word in zip(query_words, candidate.split())]
            total_word_similarity = sum(word_similarities)

            # NOTE(review): the word sum and the fuzzy ratio appear to be on
            # different scales (presumably 0-1 per word vs 0-100), so the
            # fuzzy ratio likely dominates this average - confirm intent.
            fuzzy_similarity = fuzz.ratio(query, candidate)
            combined_similarity = (total_word_similarity + fuzzy_similarity) / 2
            string_similarities.append((candidate, combined_similarity))

        # Sort the candidates based on combined similarity, higher scores first
        ranked_candidates = sorted(string_similarities, key=lambda x: x[1], reverse=True)

        # Keep only the names of the top candidates
        top_candidates = ranked_candidates[:self.N_BEST_CANDIDATES]
        cleaned_candidates = [cand[0] for cand in top_candidates]

        return cleaned_candidates, ranked_candidates

    def check_WFO(self, record, replace_if_success_wfo):
        """Validate *record* against WFO and optionally overwrite its taxonomy.

        Args:
            record: Dict of transcribed label fields.
            replace_if_success_wfo: When true and WFO returns an exact
                match, overwrite ``order``, ``family``, ``genus``,
                ``specificEpithet`` and ``scientificName`` in *record*.

        Returns:
            ``(record, simplified_response)`` where the response holds the
            ``WFO_*`` fields, with ``WFO_placement`` re-joined using
            ``self.SEP``.
        """
        self.replace_if_success_wfo = replace_if_success_wfo

        # "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
        simplified_response = self.query_and_process(record)
        simplified_response['WFO_override_OCR'] = False

        # best_match
        is_exact = bool(simplified_response.get('WFO_exact_match'))
        simplified_response['WFO_exact_match_name'] = (
            simplified_response.get('WFO_best_match') if is_exact else '')

        # placement: drop the first path element and re-join with SEP
        wfo_placement = simplified_response.get('WFO_placement', '')
        if wfo_placement:
            parts = wfo_placement.split('/')[1:]
            simplified_response['WFO_placement'] = self.SEP.join(parts)
        else:
            simplified_response['WFO_placement'] = ''

        if is_exact and replace_if_success_wfo:
            name_parts = simplified_response['WFO_placement'].split('$')[0].split(self.SEP)
            # Positions 3..6 are read below; only override when the path is
            # deep enough, so the record is never left half-updated.
            if len(name_parts) > 6:
                simplified_response['WFO_override_OCR'] = True
                record['order'] = name_parts[3]
                record['family'] = name_parts[4]
                record['genus'] = name_parts[5]
                record['specificEpithet'] = name_parts[6]
                record['scientificName'] = simplified_response.get('WFO_exact_match_name')

        return record, simplified_response
|
207 |
+
|
208 |
+
def validate_taxonomy_WFO(record_dict, replace_if_success_wfo=False):
    """Validate a record's taxonomy against World Flora Online.

    Args:
        record_dict: Dict of transcribed label fields.
        replace_if_success_wfo: Passed through to
            ``WFONameMatcher.check_WFO``.

    Returns:
        ``(record_dict, WFO_dict)``. If validation raises, the record is
        returned together with the matcher's ``NULL_DICT``.
    """
    Matcher = WFONameMatcher()
    try:
        record_dict, WFO_dict = Matcher.check_WFO(record_dict, replace_if_success_wfo)
        return record_dict, WFO_dict
    except Exception as e:
        # Validation is best-effort and must not abort the caller, but
        # report why it failed. `Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit propagate.
        print(f"WFO taxonomy validation failed: {e}")
        return record_dict, Matcher.NULL_DICT
|
215 |
+
|
216 |
+
'''
|
217 |
+
if __name__ == "__main__":
|
218 |
+
Matcher = WFONameMatcher()
|
219 |
+
# input_string = "Rhopalocarpus alterfolius"
|
220 |
+
record_exact_match ={
|
221 |
+
"order": "Malpighiales",
|
222 |
+
"family": "Hypericaceae",
|
223 |
+
"scientificName": "Hypericum prolificum",
|
224 |
+
"scientificNameAuthorship": "",
|
225 |
+
|
226 |
+
"genus": "Hypericum",
|
227 |
+
"subgenus": "",
|
228 |
+
"specificEpithet": "prolificum",
|
229 |
+
"infraspecificEpithet": "",
|
230 |
+
}
|
231 |
+
record_partialPrimary_exactSecondary ={
|
232 |
+
"order": "Malpighiales",
|
233 |
+
"family": "Hypericaceae",
|
234 |
+
"scientificName": "Hyperic prolificum",
|
235 |
+
"scientificNameAuthorship": "",
|
236 |
+
|
237 |
+
"genus": "Hypericum",
|
238 |
+
"subgenus": "",
|
239 |
+
"specificEpithet": "prolificum",
|
240 |
+
"infraspecificEpithet": "",
|
241 |
+
}
|
242 |
+
record_exactPrimary_partialSecondary ={
|
243 |
+
"order": "Malpighiales",
|
244 |
+
"family": "Hypericaceae",
|
245 |
+
"scientificName": "Hypericum prolificum",
|
246 |
+
"scientificNameAuthorship": "",
|
247 |
+
|
248 |
+
"genus": "Hyperic",
|
249 |
+
"subgenus": "",
|
250 |
+
"specificEpithet": "prolificum",
|
251 |
+
"infraspecificEpithet": "",
|
252 |
+
}
|
253 |
+
record_partialPrimary_partialSecondary ={
|
254 |
+
"order": "Malpighiales",
|
255 |
+
"family": "Hypericaceae",
|
256 |
+
"scientificName": "Hyperic prolificum",
|
257 |
+
"scientificNameAuthorship": "",
|
258 |
+
|
259 |
+
"genus": "Hypericum",
|
260 |
+
"subgenus": "",
|
261 |
+
"specificEpithet": "prolific",
|
262 |
+
"infraspecificEpithet": "",
|
263 |
+
}
|
264 |
+
record_partialPrimary_partialSecondary_swap ={
|
265 |
+
"order": "Malpighiales",
|
266 |
+
"family": "Hypericaceae",
|
267 |
+
"scientificName": "Hypericum prolific",
|
268 |
+
"scientificNameAuthorship": "",
|
269 |
+
|
270 |
+
"genus": "Hyperic",
|
271 |
+
"subgenus": "",
|
272 |
+
"specificEpithet": "prolificum",
|
273 |
+
"infraspecificEpithet": "",
|
274 |
+
}
|
275 |
+
record_errorPrimary_partialSecondary ={
|
276 |
+
"order": "Malpighiales",
|
277 |
+
"family": "Hypericaceae",
|
278 |
+
"scientificName": "ricum proli",
|
279 |
+
"scientificNameAuthorship": "",
|
280 |
+
|
281 |
+
"genus": "Hyperic",
|
282 |
+
"subgenus": "",
|
283 |
+
"specificEpithet": "prolificum",
|
284 |
+
"infraspecificEpithet": "",
|
285 |
+
}
|
286 |
+
record_partialPrimary_errorSecondary ={
|
287 |
+
"order": "Malpighiales",
|
288 |
+
"family": "Hypericaceae",
|
289 |
+
"scientificName": "Hyperic prolificum",
|
290 |
+
"scientificNameAuthorship": "",
|
291 |
+
|
292 |
+
"genus": "ricum",
|
293 |
+
"subgenus": "",
|
294 |
+
"specificEpithet": "proli",
|
295 |
+
"infraspecificEpithet": "",
|
296 |
+
}
|
297 |
+
record_errorPrimary_errorSecondary ={
|
298 |
+
"order": "Malpighiales",
|
299 |
+
"family": "Hypericaceae",
|
300 |
+
"scientificName": "ricum proli",
|
301 |
+
"scientificNameAuthorship": "",
|
302 |
+
|
303 |
+
"genus": "ricum",
|
304 |
+
"subgenus": "",
|
305 |
+
"specificEpithet": "proli",
|
306 |
+
"infraspecificEpithet": "",
|
307 |
+
}
|
308 |
+
options = [record_exact_match,
|
309 |
+
record_partialPrimary_exactSecondary,
|
310 |
+
record_exactPrimary_partialSecondary,
|
311 |
+
record_partialPrimary_partialSecondary,
|
312 |
+
record_partialPrimary_partialSecondary_swap,
|
313 |
+
record_errorPrimary_partialSecondary,
|
314 |
+
record_partialPrimary_errorSecondary,
|
315 |
+
record_errorPrimary_errorSecondary]
|
316 |
+
for opt in options:
|
317 |
+
simplified_response = Matcher.check_WFO(opt)
|
318 |
+
print(json.dumps(simplified_response, indent=4))
|
319 |
+
'''
|
vouchervision/vouchervision_main.py
CHANGED
@@ -8,7 +8,7 @@ parentdir = os.path.dirname(currentdir)
|
|
8 |
sys.path.append(parentdir)
|
9 |
sys.path.append(currentdir)
|
10 |
from vouchervision.component_detector.component_detector import detect_plant_components, detect_archival_components
|
11 |
-
from general_utils import
|
12 |
from directory_structure_VV import Dir_Structure
|
13 |
from data_project import Project_Info
|
14 |
from LM2_logger import start_logging
|
@@ -16,10 +16,7 @@ from fetch_data import fetch_data
|
|
16 |
from utils_VoucherVision import VoucherVision, space_saver
|
17 |
|
18 |
|
19 |
-
def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progress_report, path_api_cost=None, test_ind = None, is_real_run=False):
|
20 |
-
# get_n_overall = progress_report.get_n_overall()
|
21 |
-
# progress_report.update_overall(f"Working on {test_ind+1} of {get_n_overall}")
|
22 |
-
|
23 |
t_overall = perf_counter()
|
24 |
|
25 |
# Load config file
|
@@ -29,19 +26,12 @@ def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progr
|
|
29 |
cfg = load_config_file(dir_home, cfg_file_path, system='VoucherVision') # For VoucherVision
|
30 |
else:
|
31 |
cfg = cfg_test
|
32 |
-
# user_cfg = load_config_file(dir_home, cfg_file_path)
|
33 |
-
# cfg = Config(user_cfg)
|
34 |
|
35 |
# Check to see if there are subdirs
|
36 |
# Yes --> use the names of the subsirs as run_name
|
37 |
run_name, dirs_list, has_subdirs = check_for_subdirs_VV(cfg)
|
38 |
print(f"run_name {run_name} dirs_list{dirs_list} has_subdirs{has_subdirs}")
|
39 |
|
40 |
-
# for dir_ind, dir_in in enumerate(dirs_list):
|
41 |
-
# if has_subdirs:
|
42 |
-
# cfg['leafmachine']['project']['dir_images_local'] = dir_in
|
43 |
-
# cfg['leafmachine']['project']['run_name'] = run_name[dir_ind]
|
44 |
-
|
45 |
# Dir structure
|
46 |
if is_real_run:
|
47 |
progress_report.update_overall(f"Creating Output Directory Structure")
|
@@ -67,21 +57,16 @@ def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progr
|
|
67 |
# Detect Archival Components
|
68 |
print_main_start("Locating Archival Components")
|
69 |
Project = detect_archival_components(cfg, logger, dir_home, Project, Dirs, is_real_run, progress_report)
|
70 |
-
|
71 |
# Save cropped detections
|
72 |
crop_detections_from_images_VV(cfg, logger, dir_home, Project, Dirs)
|
73 |
|
74 |
# Process labels
|
75 |
-
Voucher_Vision = VoucherVision(cfg, logger, dir_home, path_custom_prompts, Project, Dirs)
|
76 |
n_images = len(Voucher_Vision.img_paths)
|
77 |
-
last_JSON_response, total_tokens_in, total_tokens_out = Voucher_Vision.process_specimen_batch(progress_report, is_real_run)
|
78 |
-
|
79 |
-
|
80 |
-
cost_summary, data, total_cost = save_token_info_as_csv(Dirs, cfg['leafmachine']['LLM_version'], path_api_cost, total_tokens_in, total_tokens_out, n_images)
|
81 |
-
add_to_expense_report(dir_home, data)
|
82 |
-
logger.info(cost_summary)
|
83 |
-
else:
|
84 |
-
total_cost = None #TODO add config tests to expense_report
|
85 |
|
86 |
t_overall_s = perf_counter()
|
87 |
logger.name = 'Run Complete! :)'
|
@@ -89,20 +74,11 @@ def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progr
|
|
89 |
space_saver(cfg, Dirs, logger)
|
90 |
|
91 |
if is_real_run:
|
92 |
-
progress_report.update_overall(f"Run Complete!
|
93 |
-
|
94 |
-
for handler in logger.handlers[:]:
|
95 |
-
handler.close()
|
96 |
-
logger.removeHandler(handler)
|
97 |
|
98 |
-
|
99 |
-
dir_to_zip = os.path.join(Dirs.dir_home, Dirs.run_name)
|
100 |
-
zip_filename = Dirs.run_name
|
101 |
|
102 |
-
|
103 |
-
zip_filepath = make_zipfile(dir_to_zip, zip_filename)
|
104 |
-
|
105 |
-
return last_JSON_response, total_cost, zip_filepath
|
106 |
|
107 |
def voucher_vision_OCR_test(cfg_file_path, dir_home, cfg_test, path_to_crop):
|
108 |
# get_n_overall = progress_report.get_n_overall()
|
@@ -157,7 +133,6 @@ def voucher_vision_OCR_test(cfg_file_path, dir_home, cfg_test, path_to_crop):
|
|
157 |
Voucher_Vision = VoucherVision(cfg, logger, dir_home, None, Project, Dirs)
|
158 |
last_JSON_response = Voucher_Vision.process_specimen_batch_OCR_test(path_to_crop)
|
159 |
|
160 |
-
|
161 |
if __name__ == '__main__':
|
162 |
is_test = False
|
163 |
|
|
|
8 |
sys.path.append(parentdir)
|
9 |
sys.path.append(currentdir)
|
10 |
from vouchervision.component_detector.component_detector import detect_plant_components, detect_archival_components
|
11 |
+
from general_utils import add_to_expense_report, save_token_info_as_csv, print_main_start, check_for_subdirs_VV, load_config_file, load_config_file_testing, report_config, save_config_file, subset_dir_images, crop_detections_from_images_VV
|
12 |
from directory_structure_VV import Dir_Structure
|
13 |
from data_project import Project_Info
|
14 |
from LM2_logger import start_logging
|
|
|
16 |
from utils_VoucherVision import VoucherVision, space_saver
|
17 |
|
18 |
|
19 |
+
def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progress_report, json_report, path_api_cost=None, test_ind = None, is_hf = True, is_real_run=False):
|
|
|
|
|
|
|
20 |
t_overall = perf_counter()
|
21 |
|
22 |
# Load config file
|
|
|
26 |
cfg = load_config_file(dir_home, cfg_file_path, system='VoucherVision') # For VoucherVision
|
27 |
else:
|
28 |
cfg = cfg_test
|
|
|
|
|
29 |
|
30 |
# Check to see if there are subdirs
|
31 |
# Yes --> use the names of the subdirs as run_name
|
32 |
run_name, dirs_list, has_subdirs = check_for_subdirs_VV(cfg)
|
33 |
print(f"run_name {run_name} dirs_list{dirs_list} has_subdirs{has_subdirs}")
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
# Dir structure
|
36 |
if is_real_run:
|
37 |
progress_report.update_overall(f"Creating Output Directory Structure")
|
|
|
57 |
# Detect Archival Components
|
58 |
print_main_start("Locating Archival Components")
|
59 |
Project = detect_archival_components(cfg, logger, dir_home, Project, Dirs, is_real_run, progress_report)
|
60 |
+
|
61 |
# Save cropped detections
|
62 |
crop_detections_from_images_VV(cfg, logger, dir_home, Project, Dirs)
|
63 |
|
64 |
# Process labels
|
65 |
+
Voucher_Vision = VoucherVision(cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf)
|
66 |
n_images = len(Voucher_Vision.img_paths)
|
67 |
+
last_JSON_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out = Voucher_Vision.process_specimen_batch(progress_report, json_report, is_real_run)
|
68 |
+
|
69 |
+
total_cost = save_token_info_as_csv(Dirs, cfg['leafmachine']['LLM_version'], path_api_cost, total_tokens_in, total_tokens_out, n_images, dir_home, logger)
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
t_overall_s = perf_counter()
|
72 |
logger.name = 'Run Complete! :)'
|
|
|
74 |
space_saver(cfg, Dirs, logger)
|
75 |
|
76 |
if is_real_run:
|
77 |
+
progress_report.update_overall(f"Run Complete!")
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
Voucher_Vision.close_logger_handlers()
|
|
|
|
|
80 |
|
81 |
+
return last_JSON_response, final_WFO_record, final_GEO_record, total_cost, Voucher_Vision.n_failed_OCR, Voucher_Vision.n_failed_LLM_calls
|
|
|
|
|
|
|
82 |
|
83 |
def voucher_vision_OCR_test(cfg_file_path, dir_home, cfg_test, path_to_crop):
|
84 |
# get_n_overall = progress_report.get_n_overall()
|
|
|
133 |
Voucher_Vision = VoucherVision(cfg, logger, dir_home, None, Project, Dirs)
|
134 |
last_JSON_response = Voucher_Vision.process_specimen_batch_OCR_test(path_to_crop)
|
135 |
|
|
|
136 |
if __name__ == '__main__':
|
137 |
is_test = False
|
138 |
|