Spaces:
Running
Running
phyloforfun
committed on
Commit
·
567930d
1
Parent(s):
c6a70af
updates
Browse files
- .gitignore +15 -0
- .streamlit/config.toml +8 -2
- api_status.yaml +4 -2
- app.py +46 -25
- custom_prompts/SLTPvA_long.yaml +31 -36
- custom_prompts/SLTPvA_medium.yaml +2 -6
- custom_prompts/SLTPvA_short.yaml +2 -6
- custom_prompts/SLTPvB_long.yaml +107 -0
- custom_prompts/SLTPvB_medium.yaml +83 -0
- custom_prompts/SLTPvB_short.yaml +78 -0
- pages/prompt_builder.py +3 -1
- requirements.txt +0 -0
- requirements_conda.txt +0 -0
- requirements_with_versions.txt +0 -0
- run_VoucherVision.py +2 -1
- vouchervision/LLM_GoogleGemini.py +39 -19
- vouchervision/LLM_GooglePalm2.py +65 -20
- vouchervision/LLM_MistralAI.py +45 -19
- vouchervision/LLM_OpenAI.py +117 -37
- vouchervision/LLM_local_MistralAI.py +16 -9
- vouchervision/LLM_local_cpu_MistralAI.py +15 -8
- vouchervision/LM2_logger.py +22 -5
- vouchervision/OCR_google_cloud_vision.py +26 -6
- vouchervision/VoucherVision_Config_Builder.py +1 -1
- vouchervision/model_maps.py +42 -23
- vouchervision/prompt_catalog.py +28 -16
- vouchervision/tool_taxonomy_WFO.py +14 -5
- vouchervision/utils_LLM.py +1 -4
- vouchervision/utils_LLM_JSON_validation.py +2 -1
- vouchervision/utils_VoucherVision.py +23 -17
- vouchervision/utils_VoucherVision_parallel.py +1022 -0
- vouchervision/vouchervision_main.py +1 -0
- vouchervision/vouchervision_test_all_options_analysis.py +103 -0
- vouchervision/vouchervision_test_all_options_recipes.py +170 -0
.gitignore
CHANGED
@@ -7,15 +7,27 @@ yolov8x-pose.pt
|
|
7 |
yolov8n.pt
|
8 |
*PRIVATE_DATA*
|
9 |
|
|
|
|
|
10 |
# Prompts
|
11 |
/custom_prompts/*
|
12 |
!/custom_prompts/SLTPvA_long.yaml
|
13 |
!/custom_prompts/SLTPvA_medium.yaml
|
14 |
!/custom_prompts/SLTPvA_short.yaml
|
|
|
|
|
|
|
15 |
|
16 |
# Dirs
|
17 |
custom_prompts_deprecated/
|
18 |
demo/demo_output/*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
demo/demo_configs/*
|
20 |
uploads/*
|
21 |
uploads_small/*
|
@@ -59,6 +71,9 @@ vouchervision/component_detector/runs/
|
|
59 |
vouchervision/component_detector/architecture/
|
60 |
vouchervision/component_detector/yolov5x6.pt
|
61 |
|
|
|
|
|
|
|
62 |
vouchervision/instructor-xl/
|
63 |
vouchervision/instructor-embedding/
|
64 |
|
|
|
7 |
yolov8n.pt
|
8 |
*PRIVATE_DATA*
|
9 |
|
10 |
+
vouchervision/LLM_MistralAI_Azure_endpoints.py
|
11 |
+
|
12 |
# Prompts
|
13 |
/custom_prompts/*
|
14 |
!/custom_prompts/SLTPvA_long.yaml
|
15 |
!/custom_prompts/SLTPvA_medium.yaml
|
16 |
!/custom_prompts/SLTPvA_short.yaml
|
17 |
+
!/custom_prompts/SLTPvB_long.yaml
|
18 |
+
!/custom_prompts/SLTPvB_medium.yaml
|
19 |
+
!/custom_prompts/SLTPvB_short.yaml
|
20 |
|
21 |
# Dirs
|
22 |
custom_prompts_deprecated/
|
23 |
demo/demo_output/*
|
24 |
+
|
25 |
+
demo/validation_images_repeat/
|
26 |
+
demo/validation_json/
|
27 |
+
demo/validation_figs/
|
28 |
+
demo/validation_output/
|
29 |
+
demo/validation_xlsx/
|
30 |
+
|
31 |
demo/demo_configs/*
|
32 |
uploads/*
|
33 |
uploads_small/*
|
|
|
71 |
vouchervision/component_detector/architecture/
|
72 |
vouchervision/component_detector/yolov5x6.pt
|
73 |
|
74 |
+
vouchervision/vouchervision_test_all_options.py
|
75 |
+
vouchervision/prompt_arena.py
|
76 |
+
|
77 |
vouchervision/instructor-xl/
|
78 |
vouchervision/instructor-embedding/
|
79 |
|
.streamlit/config.toml
CHANGED
@@ -1,5 +1,11 @@
|
|
1 |
[theme]
|
2 |
-
primaryColor = "#
|
3 |
backgroundColor="#1a1a1a"
|
4 |
secondaryBackgroundColor="#303030"
|
5 |
-
textColor = "cccccc"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
[theme]
|
2 |
+
primaryColor = "#16a616"
|
3 |
backgroundColor="#1a1a1a"
|
4 |
secondaryBackgroundColor="#303030"
|
5 |
+
textColor = "cccccc"
|
6 |
+
|
7 |
+
[server]
|
8 |
+
enableStaticServing = false
|
9 |
+
runOnSave = true
|
10 |
+
port = 8524
|
11 |
+
maxUploadSize = 5000
|
api_status.yaml
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
-
date:
|
2 |
missing_keys: []
|
3 |
present_keys:
|
4 |
-
- Google OCR (Valid)
|
|
|
5 |
- OpenAI (Valid)
|
6 |
- Azure OpenAI (Valid)
|
7 |
- Palm2 (Valid)
|
|
|
8 |
- Gemini (Valid)
|
9 |
- Mistral (Valid)
|
10 |
- HERE Geocode (Valid)
|
|
|
1 |
+
date: February 29, 2024
|
2 |
missing_keys: []
|
3 |
present_keys:
|
4 |
+
- Google OCR Print (Valid)
|
5 |
+
- Google OCR Handwriting (Valid)
|
6 |
- OpenAI (Valid)
|
7 |
- Azure OpenAI (Valid)
|
8 |
- Palm2 (Valid)
|
9 |
+
- Palm2 LangChain (Valid)
|
10 |
- Gemini (Valid)
|
11 |
- Mistral (Valid)
|
12 |
- HERE Geocode (Valid)
|
app.py
CHANGED
@@ -7,7 +7,6 @@ import pandas as pd
|
|
7 |
from io import BytesIO
|
8 |
from streamlit_extras.let_it_rain import rain
|
9 |
from annotated_text import annotated_text
|
10 |
-
from transformers import AutoConfig
|
11 |
|
12 |
from vouchervision.LeafMachine2_Config_Builder import write_config_file
|
13 |
from vouchervision.VoucherVision_Config_Builder import build_VV_config, TestOptionsGPT, TestOptionsPalm, check_if_usable
|
@@ -18,6 +17,7 @@ from vouchervision.API_validation import APIvalidation
|
|
18 |
from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file, save_uploaded_local, save_uploaded_file_local
|
19 |
from vouchervision.data_project import convert_pdf_to_jpg
|
20 |
from vouchervision.utils_LLM import check_system_gpus
|
|
|
21 |
|
22 |
import cProfile
|
23 |
import pstats
|
@@ -250,14 +250,25 @@ def load_gallery(converted_files, uploaded_file):
|
|
250 |
file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
|
251 |
st.session_state['input_list_small'].append(file_path_small)
|
252 |
|
|
|
|
|
|
|
253 |
@st.cache_data
|
254 |
def handle_image_upload_and_gallery_hf(uploaded_files):
|
255 |
if uploaded_files:
|
|
|
256 |
# Clear input image gallery and input list
|
257 |
clear_image_uploads()
|
258 |
|
259 |
ind_small = 0
|
260 |
for uploaded_file in uploaded_files:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
# Determine the file type
|
262 |
if uploaded_file.name.lower().endswith('.pdf'):
|
263 |
# Handle PDF files
|
@@ -305,6 +316,8 @@ def handle_image_upload_and_gallery_hf(uploaded_files):
|
|
305 |
# If there are less than 100 images, take them all
|
306 |
images_to_display = st.session_state['input_list_small']
|
307 |
show_gallery_small_hf(images_to_display)
|
|
|
|
|
308 |
|
309 |
|
310 |
@st.cache_data
|
@@ -378,7 +391,7 @@ def content_input_images(col_left, col_right):
|
|
378 |
|
379 |
with col_right:
|
380 |
if st.session_state.is_hf:
|
381 |
-
handle_image_upload_and_gallery_hf(uploaded_files)
|
382 |
|
383 |
else:
|
384 |
st.session_state['view_local_gallery'] = st.toggle("View Image Gallery",)
|
@@ -427,7 +440,8 @@ def count_jpg_images(directory_path):
|
|
427 |
|
428 |
def create_download_button(zip_filepath, col, key):
|
429 |
with col:
|
430 |
-
labal_n_images = f"Download Results for {st.session_state['processing_add_on']} Images"
|
|
|
431 |
with open(zip_filepath, 'rb') as f:
|
432 |
bytes_io = BytesIO(f.read())
|
433 |
st.download_button(
|
@@ -1067,6 +1081,11 @@ def create_private_file():
|
|
1067 |
"client_x509_cert_url": "A LONG URL",
|
1068 |
"universe_domain": "googleapis.com"
|
1069 |
})
|
|
|
|
|
|
|
|
|
|
|
1070 |
google_application_credentials = st.text_input(label = 'Full path to Google Cloud JSON API key file', value = cfg_private['google'].get('GOOGLE_APPLICATION_CREDENTIALS', ''),
|
1071 |
placeholder = 'e.g. C:/Documents/Secret_Files/google_API/application_default_credentials.json',
|
1072 |
help ="This API Key is in the form of a JSON file. Please save the JSON file in a safe directory. DO NOT store the JSON key inside of the VoucherVision directory.",
|
@@ -1127,7 +1146,7 @@ def create_private_file():
|
|
1127 |
|
1128 |
st.write("---")
|
1129 |
st.subheader("MistralAI")
|
1130 |
-
st.markdown('Follow these [instructions](https://
|
1131 |
mistral_API_KEY = st.text_input("MistralAI API Key", cfg_private['mistral'].get('MISTRAL_API_KEY', ''),
|
1132 |
help='e.g. a 32-character string',
|
1133 |
placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
@@ -1360,7 +1379,7 @@ def get_all_cost_tables():
|
|
1360 |
cost_openai[key] = cost_data.get(value,'')
|
1361 |
elif 'PALM2' in parts or 'GEMINI' in parts:
|
1362 |
cost_google[key] = cost_data.get(value,'')
|
1363 |
-
elif 'MISTRAL' in parts:
|
1364 |
cost_mistral[key] = cost_data.get(value,'')
|
1365 |
|
1366 |
styled_cost_openai = convert_cost_dict_to_table(cost_openai, "OpenAI")
|
@@ -1403,9 +1422,9 @@ def content_header():
|
|
1403 |
N_STEPS = 6
|
1404 |
|
1405 |
if check_if_usable(is_hf=st.session_state['is_hf']):
|
1406 |
-
b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
|
1407 |
-
if st.session_state['processing_add_on'] == 0:
|
1408 |
-
|
1409 |
if st.button(b_text, type='primary',use_container_width=True):
|
1410 |
st.session_state['formatted_json'] = {}
|
1411 |
st.session_state['formatted_json_WFO'] = {}
|
@@ -1466,7 +1485,7 @@ def content_header():
|
|
1466 |
if st.session_state['zip_filepath']:
|
1467 |
create_download_button(st.session_state['zip_filepath'], col_run_1,key=97863332)
|
1468 |
else:
|
1469 |
-
st.button("Start
|
1470 |
with col_run_4:
|
1471 |
st.error(":heavy_exclamation_mark: Required API keys not set. Please visit the 'API Keys' tab and set the Google Vision OCR API key and at least one LLM key.")
|
1472 |
|
@@ -1482,11 +1501,11 @@ def content_header():
|
|
1482 |
ct_left, ct_right = st.columns([1,1])
|
1483 |
with ct_left:
|
1484 |
st.button("Refresh", on_click=refresh, use_container_width=True)
|
1485 |
-
|
1486 |
-
|
1487 |
-
|
1488 |
-
|
1489 |
-
|
1490 |
|
1491 |
|
1492 |
|
@@ -1687,12 +1706,12 @@ def content_prompt_and_llm_version():
|
|
1687 |
selected_version = default_version
|
1688 |
st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version),label_visibility='collapsed')
|
1689 |
|
1690 |
-
|
1691 |
-
|
1692 |
-
|
1693 |
-
|
1694 |
-
|
1695 |
-
|
1696 |
|
1697 |
|
1698 |
st.header('LLM Version')
|
@@ -1703,18 +1722,18 @@ def content_prompt_and_llm_version():
|
|
1703 |
st.session_state.config['leafmachine']['LLM_version'] = st.selectbox("LLM version", GUI_MODEL_LIST, index=GUI_MODEL_LIST.index(st.session_state.config['leafmachine'].get('LLM_version', ModelMaps.MODELS_GUI_DEFAULT)))
|
1704 |
st.markdown("""
|
1705 |
Based on preliminary results, the following models perform the best. We are currently running tests of all possible OCR + LLM + Prompt combinations to create recipes for different workflows.
|
1706 |
-
- `Mistral
|
1707 |
-
- `Mistral Small`
|
1708 |
-
- `Mistral Tiny`
|
1709 |
- `PaLM 2 text-bison@001`
|
1710 |
- `GPT 4 Turbo 1106-preview`
|
1711 |
-
- `GPT 3.5
|
1712 |
- `LOCAL Mixtral 7Bx8 Instruct`
|
1713 |
- `LOCAL Mixtral 7B Instruct`
|
1714 |
|
1715 |
Larger models (e.g., `GPT 4`, `GPT 4 32k`, `Gemini Pro`) do not necessarily perform better for these tasks. MistralAI models exceeded our expectations and perform extremely well. PaLM 2 text-bison@001 also seems to consistently out-perform Gemini Pro.
|
1716 |
|
1717 |
-
The `SLTPvA_short.yaml` prompt also seems to work better with smaller LLMs (e.g., Mistral Tiny). Alternatively, enable double OCR to help the LLM focus on the OCR text given a longer prompt.
|
|
|
|
|
1718 |
|
1719 |
|
1720 |
def content_api_check():
|
@@ -1927,6 +1946,8 @@ def content_ocr_method():
|
|
1927 |
# st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
|
1928 |
|
1929 |
def is_valid_huggingface_model_path(model_path):
|
|
|
|
|
1930 |
try:
|
1931 |
# Attempt to load the model configuration from Hugging Face Model Hub
|
1932 |
config = AutoConfig.from_pretrained(model_path)
|
|
|
7 |
from io import BytesIO
|
8 |
from streamlit_extras.let_it_rain import rain
|
9 |
from annotated_text import annotated_text
|
|
|
10 |
|
11 |
from vouchervision.LeafMachine2_Config_Builder import write_config_file
|
12 |
from vouchervision.VoucherVision_Config_Builder import build_VV_config, TestOptionsGPT, TestOptionsPalm, check_if_usable
|
|
|
17 |
from vouchervision.utils_hf import setup_streamlit_config, save_uploaded_file, save_uploaded_local, save_uploaded_file_local
|
18 |
from vouchervision.data_project import convert_pdf_to_jpg
|
19 |
from vouchervision.utils_LLM import check_system_gpus
|
20 |
+
from vouchervision.OCR_google_cloud_vision import check_for_inappropriate_content
|
21 |
|
22 |
import cProfile
|
23 |
import pstats
|
|
|
250 |
file_path_small = save_uploaded_file(st.session_state['dir_uploaded_images_small'], uploaded_file, img)
|
251 |
st.session_state['input_list_small'].append(file_path_small)
|
252 |
|
253 |
+
|
254 |
+
|
255 |
+
|
256 |
@st.cache_data
|
257 |
def handle_image_upload_and_gallery_hf(uploaded_files):
|
258 |
if uploaded_files:
|
259 |
+
|
260 |
# Clear input image gallery and input list
|
261 |
clear_image_uploads()
|
262 |
|
263 |
ind_small = 0
|
264 |
for uploaded_file in uploaded_files:
|
265 |
+
|
266 |
+
if check_for_inappropriate_content(uploaded_file):
|
267 |
+
clear_image_uploads()
|
268 |
+
st.error("Warning: You have uploaded an inappropriate image")
|
269 |
+
return True
|
270 |
+
|
271 |
+
|
272 |
# Determine the file type
|
273 |
if uploaded_file.name.lower().endswith('.pdf'):
|
274 |
# Handle PDF files
|
|
|
316 |
# If there are less than 100 images, take them all
|
317 |
images_to_display = st.session_state['input_list_small']
|
318 |
show_gallery_small_hf(images_to_display)
|
319 |
+
|
320 |
+
return False
|
321 |
|
322 |
|
323 |
@st.cache_data
|
|
|
391 |
|
392 |
with col_right:
|
393 |
if st.session_state.is_hf:
|
394 |
+
result = handle_image_upload_and_gallery_hf(uploaded_files)
|
395 |
|
396 |
else:
|
397 |
st.session_state['view_local_gallery'] = st.toggle("View Image Gallery",)
|
|
|
440 |
|
441 |
def create_download_button(zip_filepath, col, key):
|
442 |
with col:
|
443 |
+
# labal_n_images = f"Download Results for {st.session_state['processing_add_on']} Images"
|
444 |
+
labal_n_images = f"Download Results"
|
445 |
with open(zip_filepath, 'rb') as f:
|
446 |
bytes_io = BytesIO(f.read())
|
447 |
st.download_button(
|
|
|
1081 |
"client_x509_cert_url": "A LONG URL",
|
1082 |
"universe_domain": "googleapis.com"
|
1083 |
})
|
1084 |
+
|
1085 |
+
blog_text('Google project ID', ': The project ID is the "project_id" value from the JSON file.')
|
1086 |
+
blog_text('Google project location', ': The project location specifies the location of the Google server that your project resources will utilize. It should not really make a difference which location you choose. We use `us-central1`, but you might want to choose a location closer to where you live. [please see this page for a list of available regions](https://cloud.google.com/vertex-ai/docs/general/locations)')
|
1087 |
+
|
1088 |
+
|
1089 |
google_application_credentials = st.text_input(label = 'Full path to Google Cloud JSON API key file', value = cfg_private['google'].get('GOOGLE_APPLICATION_CREDENTIALS', ''),
|
1090 |
placeholder = 'e.g. C:/Documents/Secret_Files/google_API/application_default_credentials.json',
|
1091 |
help ="This API Key is in the form of a JSON file. Please save the JSON file in a safe directory. DO NOT store the JSON key inside of the VoucherVision directory.",
|
|
|
1146 |
|
1147 |
st.write("---")
|
1148 |
st.subheader("MistralAI")
|
1149 |
+
st.markdown('Follow these [instructions](https://console.mistral.ai/) to generate an API key for MistralAI.')
|
1150 |
mistral_API_KEY = st.text_input("MistralAI API Key", cfg_private['mistral'].get('MISTRAL_API_KEY', ''),
|
1151 |
help='e.g. a 32-character string',
|
1152 |
placeholder='e.g. SATgthsykuE64FgrrrrEervr3S4455t_geyDeGq',
|
|
|
1379 |
cost_openai[key] = cost_data.get(value,'')
|
1380 |
elif 'PALM2' in parts or 'GEMINI' in parts:
|
1381 |
cost_google[key] = cost_data.get(value,'')
|
1382 |
+
elif ('MISTRAL' in parts) or ('MIXTRAL' in parts):
|
1383 |
cost_mistral[key] = cost_data.get(value,'')
|
1384 |
|
1385 |
styled_cost_openai = convert_cost_dict_to_table(cost_openai, "OpenAI")
|
|
|
1422 |
N_STEPS = 6
|
1423 |
|
1424 |
if check_if_usable(is_hf=st.session_state['is_hf']):
|
1425 |
+
# b_text = f"Start Processing {st.session_state['processing_add_on']} Images" if st.session_state['processing_add_on'] > 1 else f"Start Processing {st.session_state['processing_add_on']} Image"
|
1426 |
+
# if st.session_state['processing_add_on'] == 0:
|
1427 |
+
b_text = f"Start Transcription"
|
1428 |
if st.button(b_text, type='primary',use_container_width=True):
|
1429 |
st.session_state['formatted_json'] = {}
|
1430 |
st.session_state['formatted_json_WFO'] = {}
|
|
|
1485 |
if st.session_state['zip_filepath']:
|
1486 |
create_download_button(st.session_state['zip_filepath'], col_run_1,key=97863332)
|
1487 |
else:
|
1488 |
+
st.button("Start Transcription", type='primary', disabled=True)
|
1489 |
with col_run_4:
|
1490 |
st.error(":heavy_exclamation_mark: Required API keys not set. Please visit the 'API Keys' tab and set the Google Vision OCR API key and at least one LLM key.")
|
1491 |
|
|
|
1501 |
ct_left, ct_right = st.columns([1,1])
|
1502 |
with ct_left:
|
1503 |
st.button("Refresh", on_click=refresh, use_container_width=True)
|
1504 |
+
with ct_right:
|
1505 |
+
try:
|
1506 |
+
st.page_link(os.path.join("pages","faqs.py"), label="FAQs", icon="❔")
|
1507 |
+
except:
|
1508 |
+
st.page_link(os.path.join(os.path.dirname(__file__),"pages","faqs.py"), label="FAQs", icon="❔")
|
1509 |
|
1510 |
|
1511 |
|
|
|
1706 |
selected_version = default_version
|
1707 |
st.session_state.config['leafmachine']['project']['prompt_version'] = st.selectbox("Prompt Version", available_prompts, index=available_prompts.index(selected_version),label_visibility='collapsed')
|
1708 |
|
1709 |
+
with col_prompt_2:
|
1710 |
+
# if st.button("Build Custom LLM Prompt"):
|
1711 |
+
try:
|
1712 |
+
st.page_link(os.path.join("pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
|
1713 |
+
except:
|
1714 |
+
st.page_link(os.path.join(os.path.dirname(__file__),"pages","prompt_builder.py"), label="Prompt Builder", icon="🚧")
|
1715 |
|
1716 |
|
1717 |
st.header('LLM Version')
|
|
|
1722 |
st.session_state.config['leafmachine']['LLM_version'] = st.selectbox("LLM version", GUI_MODEL_LIST, index=GUI_MODEL_LIST.index(st.session_state.config['leafmachine'].get('LLM_version', ModelMaps.MODELS_GUI_DEFAULT)))
|
1723 |
st.markdown("""
|
1724 |
Based on preliminary results, the following models perform the best. We are currently running tests of all possible OCR + LLM + Prompt combinations to create recipes for different workflows.
|
1725 |
+
- Any Mistral model e.g., `Mistral Large`
|
|
|
|
|
1726 |
- `PaLM 2 text-bison@001`
|
1727 |
- `GPT 4 Turbo 1106-preview`
|
1728 |
+
- `GPT 3.5 Turbo`
|
1729 |
- `LOCAL Mixtral 7Bx8 Instruct`
|
1730 |
- `LOCAL Mixtral 7B Instruct`
|
1731 |
|
1732 |
Larger models (e.g., `GPT 4`, `GPT 4 32k`, `Gemini Pro`) do not necessarily perform better for these tasks. MistralAI models exceeded our expectations and perform extremely well. PaLM 2 text-bison@001 also seems to consistently out-perform Gemini Pro.
|
1733 |
|
1734 |
+
The `SLTPvA_short.yaml` prompt also seems to work better with smaller LLMs (e.g., Mistral Tiny). Alternatively, enable double OCR to help the LLM focus on the OCR text given a longer prompt.
|
1735 |
+
|
1736 |
+
Models `GPT 3.5 Turbo` and `GPT 4 Turbo 0125-preview` enable OpenAI's [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode), which helps prevent JSON errors. All models implement Langchain JSON parsing too, so JSON errors are rare for most models.""")
|
1737 |
|
1738 |
|
1739 |
def content_api_check():
|
|
|
1946 |
# st.text_area(label='Handwritten/Printed + trOCR',placeholder=demo_text_trh,disabled=True, label_visibility='visible', height=150)
|
1947 |
|
1948 |
def is_valid_huggingface_model_path(model_path):
|
1949 |
+
from transformers import AutoConfig
|
1950 |
+
|
1951 |
try:
|
1952 |
# Attempt to load the model configuration from Hugging Face Model Hub
|
1953 |
config = AutoConfig.from_pretrained(model_path)
|
custom_prompts/SLTPvA_long.yaml
CHANGED
@@ -28,10 +28,7 @@ rules:
|
|
28 |
scientificNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
29 |
genus: Taxonomic determination to genus. Genus must be capitalized. If
|
30 |
genus is not present use the taxonomic family name followed by the word 'indet'.
|
31 |
-
subgenus: The full scientific name of the subgenus in which the taxon is classified.
|
32 |
-
Values should include the genus to avoid homonym confusion.
|
33 |
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
34 |
-
infraspecificEpithet: The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation.
|
35 |
identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
|
36 |
recordedBy: A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen.
|
37 |
The primary collector or observer should be listed first.
|
@@ -63,7 +60,7 @@ rules:
|
|
63 |
the exact origin or location of the specimen.
|
64 |
degreeOfEstablishment: Cultivated plants are intentionally grown by humans. In text descriptions,
|
65 |
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
66 |
-
or farm to indicate cultivated plant.
|
67 |
decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
|
68 |
with the decimal degrees GPS coordinate format.
|
69 |
decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
|
@@ -78,35 +75,33 @@ rules:
|
|
78 |
are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
|
79 |
or "m." or "meters"). Round to integer.
|
80 |
mapping:
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
- occurrenceRemarks
|
112 |
-
MISC:
|
|
|
28 |
scientificNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
29 |
genus: Taxonomic determination to genus. Genus must be capitalized. If
|
30 |
genus is not present use the taxonomic family name followed by the word 'indet'.
|
|
|
|
|
31 |
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
|
|
32 |
identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
|
33 |
recordedBy: A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen.
|
34 |
The primary collector or observer should be listed first.
|
|
|
60 |
the exact origin or location of the specimen.
|
61 |
degreeOfEstablishment: Cultivated plants are intentionally grown by humans. In text descriptions,
|
62 |
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
63 |
+
or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
|
64 |
decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
|
65 |
with the decimal degrees GPS coordinate format.
|
66 |
decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
|
|
|
75 |
are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
|
76 |
or "m." or "meters"). Round to integer.
|
77 |
mapping:
|
78 |
+
TAXONOMY:
|
79 |
+
- catalogNumber
|
80 |
+
- order
|
81 |
+
- family
|
82 |
+
- scientificName
|
83 |
+
- scientificNameAuthorship
|
84 |
+
- genus
|
85 |
+
- specificEpithet
|
86 |
+
GEOGRAPHY:
|
87 |
+
- country
|
88 |
+
- stateProvince
|
89 |
+
- county
|
90 |
+
- municipality
|
91 |
+
- decimalLatitude
|
92 |
+
- decimalLongitude
|
93 |
+
- verbatimCoordinates
|
94 |
+
LOCALITY:
|
95 |
+
- locality
|
96 |
+
- habitat
|
97 |
+
- minimumElevationInMeters
|
98 |
+
- maximumElevationInMeters
|
99 |
+
COLLECTING:
|
100 |
+
- identifiedBy
|
101 |
+
- recordedBy
|
102 |
+
- recordNumber
|
103 |
+
- verbatimEventDate
|
104 |
+
- eventDate
|
105 |
+
- degreeOfEstablishment
|
106 |
+
- occurrenceRemarks
|
107 |
+
MISC: []
|
|
|
|
custom_prompts/SLTPvA_medium.yaml
CHANGED
@@ -27,9 +27,7 @@ rules:
|
|
27 |
and any lower classifications.
|
28 |
scientificNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
29 |
genus: Taxonomic determination to genus. Genus must be capitalized.
|
30 |
-
subgenus: The full scientific name of the subgenus in which the taxon is classified.
|
31 |
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
32 |
-
infraspecificEpithet: The name of the lowest or terminal infraspecific epithet of the scientificName, excluding any rank designation.
|
33 |
identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
|
34 |
recordedBy: A comma separated list of names of people, groups, or organizations
|
35 |
recordNumber: An identifier given to the specimen at the time it was recorded.
|
@@ -46,7 +44,7 @@ rules:
|
|
46 |
the exact origin or location of the specimen.
|
47 |
degreeOfEstablishment: Cultivated plants are intentionally grown by humans. In text descriptions,
|
48 |
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
49 |
-
or farm to indicate cultivated plant.
|
50 |
decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
|
51 |
decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
|
52 |
verbatimCoordinates: Verbatim location coordinates as they appear on the label.
|
@@ -60,9 +58,7 @@ mapping:
|
|
60 |
- scientificName
|
61 |
- scientificNameAuthorship
|
62 |
- genus
|
63 |
-
- subgenus
|
64 |
- specificEpithet
|
65 |
-
- infraspecificEpithet
|
66 |
GEOGRAPHY:
|
67 |
- country
|
68 |
- stateProvince
|
@@ -84,4 +80,4 @@ mapping:
|
|
84 |
- eventDate
|
85 |
- degreeOfEstablishment
|
86 |
- occurrenceRemarks
|
87 |
-
MISC:
|
|
|
27 |
and any lower classifications.
|
28 |
scientificNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
29 |
genus: Taxonomic determination to genus. Genus must be capitalized.
|
|
|
30 |
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
|
|
31 |
identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
|
32 |
recordedBy: A comma separated list of names of people, groups, or organizations
|
33 |
recordNumber: An identifier given to the specimen at the time it was recorded.
|
|
|
44 |
the exact origin or location of the specimen.
|
45 |
degreeOfEstablishment: Cultivated plants are intentionally grown by humans. In text descriptions,
|
46 |
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
47 |
+
or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
|
48 |
decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
|
49 |
decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
|
50 |
verbatimCoordinates: Verbatim location coordinates as they appear on the label.
|
|
|
58 |
- scientificName
|
59 |
- scientificNameAuthorship
|
60 |
- genus
|
|
|
61 |
- specificEpithet
|
|
|
62 |
GEOGRAPHY:
|
63 |
- country
|
64 |
- stateProvince
|
|
|
80 |
- eventDate
|
81 |
- degreeOfEstablishment
|
82 |
- occurrenceRemarks
|
83 |
+
MISC: []
|
custom_prompts/SLTPvA_short.yaml
CHANGED
@@ -26,9 +26,7 @@ rules:
|
|
26 |
scientificName: scientific name of the taxon including Genus, specific epithet, and any lower classifications.
|
27 |
scientificNameAuthorship: authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
28 |
genus: taxonomic determination to Genus, Genus must be capitalized.
|
29 |
-
subgenus: name of the subgenus.
|
30 |
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
31 |
-
infraspecificEpithet: lowest or terminal infraspecific epithet of the scientificName.
|
32 |
identifiedBy: list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.
|
33 |
recordedBy: list of names of people, doctors, professors, groups, or organizations.
|
34 |
recordNumber: identifier given to the specimen at the time it was recorded.
|
@@ -41,7 +39,7 @@ rules:
|
|
41 |
county: county, shire, department, parish etc.
|
42 |
municipality: city, municipality, etc.
|
43 |
locality: description of geographic information aiding in pinpointing the exact origin or location of the specimen.
|
44 |
-
degreeOfEstablishment: cultivated plants are intentionally grown by humans.
|
45 |
decimalLatitude: latitude decimal coordinate.
|
46 |
decimalLongitude: longitude decimal coordinate.
|
47 |
verbatimCoordinates: verbatim location coordinates.
|
@@ -55,9 +53,7 @@ mapping:
|
|
55 |
- scientificName
|
56 |
- scientificNameAuthorship
|
57 |
- genus
|
58 |
-
- subgenus
|
59 |
- specificEpithet
|
60 |
-
- infraspecificEpithet
|
61 |
GEOGRAPHY:
|
62 |
- country
|
63 |
- stateProvince
|
@@ -79,4 +75,4 @@ mapping:
|
|
79 |
- eventDate
|
80 |
- degreeOfEstablishment
|
81 |
- occurrenceRemarks
|
82 |
-
MISC:
|
|
|
26 |
scientificName: scientific name of the taxon including Genus, specific epithet, and any lower classifications.
|
27 |
scientificNameAuthorship: authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
28 |
genus: taxonomic determination to Genus, Genus must be capitalized.
|
|
|
29 |
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
|
|
30 |
identifiedBy: list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.
|
31 |
recordedBy: list of names of people, doctors, professors, groups, or organizations.
|
32 |
recordNumber: identifier given to the specimen at the time it was recorded.
|
|
|
39 |
county: county, shire, department, parish etc.
|
40 |
municipality: city, municipality, etc.
|
41 |
locality: description of geographic information aiding in pinpointing the exact origin or location of the specimen.
|
42 |
+
degreeOfEstablishment: cultivated plants are intentionally grown by humans. Set to 'cultivated' if cultivated, otherwise use an empty string.
|
43 |
decimalLatitude: latitude decimal coordinate.
|
44 |
decimalLongitude: longitude decimal coordinate.
|
45 |
verbatimCoordinates: verbatim location coordinates.
|
|
|
53 |
- scientificName
|
54 |
- scientificNameAuthorship
|
55 |
- genus
|
|
|
56 |
- specificEpithet
|
|
|
57 |
GEOGRAPHY:
|
58 |
- country
|
59 |
- stateProvince
|
|
|
75 |
- eventDate
|
76 |
- degreeOfEstablishment
|
77 |
- occurrenceRemarks
|
78 |
+
MISC: []
|
custom_prompts/SLTPvB_long.yaml
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompt_author: Will Weaver
|
2 |
+
prompt_author_institution: University of Michigan
|
3 |
+
prompt_name: SLTPvB_long
|
4 |
+
prompt_version: v-1-0
|
5 |
+
prompt_description: Prompt developed by the University of Michigan.
|
6 |
+
SLTPvB prompts all have standardized column headers (fields) that were chosen due to their reliability and prevalence in herbarium records.
|
7 |
+
All field descriptions are based on the official Darwin Core guidelines.
|
8 |
+
SLTPvB_long - The most verbose prompt option. Descriptions closely follow DwC guides. Detailed rules for the LLM to follow. Works best with double or triple OCR to increase attention back to the OCR (select 'use both OCR models' or 'handwritten + printed' along with trOCR).
|
9 |
+
SLTPvB_medium - Shorter version of _long.
|
10 |
+
SLTPvB_short - The least verbose possible prompt while still providing rules and DwC descriptions.
|
11 |
+
LLM: General Purpose
|
12 |
+
instructions: 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
|
13 |
+
2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
|
14 |
+
3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
|
15 |
+
4. Duplicate dictionary fields are not allowed.
|
16 |
+
5. Ensure all JSON keys are in camel case.
|
17 |
+
6. Ensure new JSON field values follow sentence case capitalization.
|
18 |
+
7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
|
19 |
+
8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
|
20 |
+
9. Only return a JSON dictionary represented as a string. You should not explain your answer.
|
21 |
+
json_formatting_instructions: This section provides rules for formatting each JSON value organized by the JSON key.
|
22 |
+
rules:
|
23 |
+
catalogNumber: Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits.
|
24 |
+
order: The full scientific name of the order in which the taxon is classified. Order must be capitalized.
|
25 |
+
family: The full scientific name of the family in which the taxon is classified. Family must be capitalized.
|
26 |
+
speciesBinomialName: The scientific name of the taxon including genus, specific epithet,
|
27 |
+
and any lower classifications.
|
28 |
+
genus: Taxonomic determination to genus. Genus must be capitalized. If
|
29 |
+
genus is not present use the taxonomic family name followed by the word 'indet'.
|
30 |
+
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
31 |
+
speciesBinomialNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
32 |
+
collector: A comma separated list of names of people, groups, or organizations responsible for observing, recording, collecting, or presenting the original specimen.
|
33 |
+
The primary collector or observer should be listed first.
|
34 |
+
recordNumber: An identifier given to the occurrence at the time it was recorded.
|
35 |
+
Often serves as a link between field notes and an occurrence record, such as a specimen collector's number.
|
36 |
+
identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
|
37 |
+
verbatimCollectionDate: The verbatim original representation of the date and time information for when the specimen was collected.
|
38 |
+
Date of collection exactly as it appears on the label. Do not change
|
39 |
+
the format or correct typos.
|
40 |
+
collectionDate: Date the specimen was collected formatted as year-month-day, YYYY-MM-DD. If
|
41 |
+
specific components of the date are unknown, they should be replaced with
|
42 |
+
zeros. Examples "0000-00-00" if the entire date is unknown, "YYYY-00-00"
|
43 |
+
if only the year is known, and "YYYY-MM-00" if year and month are known
|
44 |
+
but day is not.
|
45 |
+
occurrenceRemarks: Text describing the specimen's geographic location. Text describing the appearance of the specimen.
|
46 |
+
A statement about the presence or absence of a taxon at the collection location.
|
47 |
+
Text describing the significance of the specimen, such as a specific expedition or notable collection.
|
48 |
+
Description of plant features such as leaf shape, size, color,
|
49 |
+
stem texture, height, flower structure, scent, fruit or seed characteristics,
|
50 |
+
root system type, overall growth habit and form, any notable aroma or secretions,
|
51 |
+
presence of hairs or bristles, and any other distinguishing morphological
|
52 |
+
or physiological characteristics.
|
53 |
+
habitat: A category or description of the habitat in which the specimen collection event occurred.
|
54 |
+
locality: Description of geographic location, landscape, landmarks, regional
|
55 |
+
features, nearby places, or any contextual information aiding in pinpointing
|
56 |
+
the exact origin or location of the specimen.
|
57 |
+
isCultivated: Cultivated plants are intentionally grown by humans. In text descriptions,
|
58 |
+
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
59 |
+
or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
|
60 |
+
country: The name of the country or major administrative unit in which the specimen was originally collected.
|
61 |
+
stateProvince: The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected.
|
62 |
+
county: The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected.
|
63 |
+
municipality: The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected.
|
64 |
+
verbatimCoordinates: Verbatim location coordinates as they appear on the label. Do not
|
65 |
+
convert formats. Possible coordinate types include [Lat, Long, UTM, TRS].
|
66 |
+
decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
|
67 |
+
with the decimal degrees GPS coordinate format.
|
68 |
+
decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform
|
69 |
+
with the decimal degrees GPS coordinate format.
|
70 |
+
minimumElevationInMeters: Minimum elevation or altitude in meters. Only if units are explicit
|
71 |
+
then convert from feet ("ft" or "ft." or "feet") to meters ("m" or "m." or
|
72 |
+
"meters"). Round to integer.
|
73 |
+
maximumElevationInMeters: Maximum elevation or altitude in meters. If only one elevation
|
74 |
+
is present, then max_elevation should be set to the null_value. Only if units
|
75 |
+
are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
|
76 |
+
or "m." or "meters"). Round to integer.
|
77 |
+
mapping:
|
78 |
+
TAXONOMY:
|
79 |
+
- catalogNumber
|
80 |
+
- order
|
81 |
+
- family
|
82 |
+
- speciesBinomialName
|
83 |
+
- genus
|
84 |
+
- specificEpithet
|
85 |
+
- speciesBinomialNameAuthorship
|
86 |
+
GEOGRAPHY:
|
87 |
+
- country
|
88 |
+
- stateProvince
|
89 |
+
- county
|
90 |
+
- municipality
|
91 |
+
- verbatimCoordinates
|
92 |
+
- decimalLatitude
|
93 |
+
- decimalLongitude
|
94 |
+
- minimumElevationInMeters
|
95 |
+
- maximumElevationInMeters
|
96 |
+
LOCALITY:
|
97 |
+
- occurrenceRemarks
|
98 |
+
- habitat
|
99 |
+
- locality
|
100 |
+
- isCultivated
|
101 |
+
COLLECTING:
|
102 |
+
- collector
|
103 |
+
- recordNumber
|
104 |
+
- identifiedBy
|
105 |
+
- verbatimCollectionDate
|
106 |
+
- collectionDate
|
107 |
+
MISC: []
|
custom_prompts/SLTPvB_medium.yaml
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompt_author: Will Weaver
|
2 |
+
prompt_author_institution: University of Michigan
|
3 |
+
prompt_name: SLTPvB_medium
|
4 |
+
prompt_version: v-1-0
|
5 |
+
prompt_description: Prompt developed by the University of Michigan.
|
6 |
+
SLTPvB prompts all have standardized column headers (fields) that were chosen due to their reliability and prevalence in herbarium records.
|
7 |
+
All field descriptions are based on the official Darwin Core guidelines.
|
8 |
+
SLTPvB_long - The most verbose prompt option. Descriptions closely follow DwC guides. Detailed rules for the LLM to follow. Works best with double or triple OCR to increase attention back to the OCR (select 'use both OCR models' or 'handwritten + printed' along with trOCR).
|
9 |
+
SLTPvB_medium - Shorter version of _long.
|
10 |
+
SLTPvB_short - The least verbose possible prompt while still providing rules and DwC descriptions.
|
11 |
+
LLM: General Purpose
|
12 |
+
instructions: 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
|
13 |
+
2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
|
14 |
+
3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
|
15 |
+
4. Duplicate dictionary fields are not allowed.
|
16 |
+
5. Ensure all JSON keys are in camel case.
|
17 |
+
6. Ensure new JSON field values follow sentence case capitalization.
|
18 |
+
7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
|
19 |
+
8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
|
20 |
+
9. Only return a JSON dictionary represented as a string. You should not explain your answer.
|
21 |
+
json_formatting_instructions: This section provides rules for formatting each JSON value organized by the JSON key.
|
22 |
+
rules:
|
23 |
+
catalogNumber: Barcode identifier, typically a number with at least 6 digits, but fewer than 30 digits.
|
24 |
+
order: The full scientific name of the order in which the taxon is classified. Order must be capitalized.
|
25 |
+
family: The full scientific name of the family in which the taxon is classified. Family must be capitalized.
|
26 |
+
speciesBinomialName: The scientific name of the taxon including genus, specific epithet,
|
27 |
+
and any lower classifications.
|
28 |
+
genus: Taxonomic determination to genus. Genus must be capitalized.
|
29 |
+
specificEpithet: The name of the first or species epithet of the scientificName. Only include the species epithet.
|
30 |
+
speciesBinomialNameAuthorship: The authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
31 |
+
collector: A comma separated list of names of people, groups, or organizations
|
32 |
+
recordNumber: An identifier given to the specimen at the time it was recorded.
|
33 |
+
identifiedBy: A comma separated list of names of people, groups, or organizations who assigned the taxon to the subject organism. This is not the specimen collector.
|
34 |
+
verbatimCollectionDate: The verbatim original representation of the date and time information for when the specimen was collected.
|
35 |
+
collectionDate: Date the specimen was collected formatted as year-month-day YYYY-MM-DD.
|
36 |
+
occurrenceRemarks: Text describing the specimen's geographic location, appearance of the specimen, presence or absence of a taxon at the collection location, the significance of the specimen, such as a specific expedition or notable collection, plant features and descriptions.
|
37 |
+
habitat: A category or description of the habitat in which the specimen collection event occurred.
|
38 |
+
locality: Description of geographic location, landscape, landmarks, regional
|
39 |
+
features, nearby places, or any contextual information aiding in pinpointing
|
40 |
+
the exact origin or location of the specimen.
|
41 |
+
isCultivated: Cultivated plants are intentionally grown by humans. In text descriptions,
|
42 |
+
look for planting dates, garden locations, ornamental, cultivar names, garden,
|
43 |
+
or farm to indicate cultivated plant. Set to 'cultivated' if cultivated, otherwise use an empty string.
|
44 |
+
country: The name of the country or major administrative unit in which the specimen was originally collected.
|
45 |
+
stateProvince: The name of the next smaller administrative region than country (state, province, canton, department, region, etc.) in which the specimen was originally collected.
|
46 |
+
county: The full, unabbreviated name of the next smaller administrative region than stateProvince (county, shire, department, parish etc.) in which the specimen was originally collected.
|
47 |
+
municipality: The full, unabbreviated name of the next smaller administrative region than county (city, municipality, etc.) in which the specimen was originally collected.
|
48 |
+
verbatimCoordinates: Verbatim location coordinates as they appear on the label.
|
49 |
+
decimalLatitude: Latitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
|
50 |
+
decimalLongitude: Longitude decimal coordinate. Correct and convert the verbatim location coordinates to conform with the decimal degrees GPS coordinate format.
|
51 |
+
minimumElevationInMeters: Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m" or "m." or "meters"). Round to integer.
|
52 |
+
maximumElevationInMeters: Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m" or "m." or "meters"). Round to integer.
|
53 |
+
mapping:
|
54 |
+
TAXONOMY:
|
55 |
+
- catalogNumber
|
56 |
+
- order
|
57 |
+
- family
|
58 |
+
- speciesBinomialName
|
59 |
+
- genus
|
60 |
+
- specificEpithet
|
61 |
+
- speciesBinomialNameAuthorship
|
62 |
+
GEOGRAPHY:
|
63 |
+
- country
|
64 |
+
- stateProvince
|
65 |
+
- county
|
66 |
+
- municipality
|
67 |
+
- verbatimCoordinates
|
68 |
+
- decimalLatitude
|
69 |
+
- decimalLongitude
|
70 |
+
- minimumElevationInMeters
|
71 |
+
- maximumElevationInMeters
|
72 |
+
LOCALITY:
|
73 |
+
- occurrenceRemarks
|
74 |
+
- habitat
|
75 |
+
- locality
|
76 |
+
- isCultivated
|
77 |
+
COLLECTING:
|
78 |
+
- collector
|
79 |
+
- recordNumber
|
80 |
+
- identifiedBy
|
81 |
+
- verbatimCollectionDate
|
82 |
+
- collectionDate
|
83 |
+
MISC: []
|
custom_prompts/SLTPvB_short.yaml
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
prompt_author: Will Weaver
|
2 |
+
prompt_author_institution: University of Michigan
|
3 |
+
prompt_name: SLTPvB_short
|
4 |
+
prompt_version: v-1-0
|
5 |
+
prompt_description: Prompt developed by the University of Michigan.
|
6 |
+
SLTPvB prompts all have standardized column headers (fields) that were chosen due to their reliability and prevalence in herbarium records.
|
7 |
+
All field descriptions are based on the official Darwin Core guidelines.
|
8 |
+
SLTPvB_long - The most verbose prompt option. Descriptions closely follow DwC guides. Detailed rules for the LLM to follow. Works best with double or triple OCR to increase attention back to the OCR (select 'use both OCR models' or 'handwritten + printed' along with trOCR).
|
9 |
+
SLTPvB_medium - Shorter version of _long.
|
10 |
+
SLTPvB_short - The least verbose possible prompt while still providing rules and DwC descriptions.
|
11 |
+
LLM: General Purpose
|
12 |
+
instructions: 1. Refactor the unstructured OCR text into a dictionary based on the JSON structure outlined below.
|
13 |
+
2. Map the unstructured OCR text to the appropriate JSON key and populate the field given the user-defined rules.
|
14 |
+
3. JSON key values are permitted to remain empty strings if the corresponding information is not found in the unstructured OCR text.
|
15 |
+
4. Duplicate dictionary fields are not allowed.
|
16 |
+
5. Ensure all JSON keys are in camel case.
|
17 |
+
6. Ensure new JSON field values follow sentence case capitalization.
|
18 |
+
7. Ensure all key-value pairs in the JSON dictionary strictly adhere to the format and data types specified in the template.
|
19 |
+
8. Ensure output JSON string is valid JSON format. It should not have trailing commas or unquoted keys.
|
20 |
+
9. Only return a JSON dictionary represented as a string. You should not explain your answer.
|
21 |
+
json_formatting_instructions: This section provides rules for formatting each JSON value organized by the JSON key.
|
22 |
+
rules:
|
23 |
+
catalogNumber: barcode identifier, at least 6 digits, fewer than 30 digits.
|
24 |
+
order: full scientific name of the Order in which the taxon is classified. Order must be capitalized.
|
25 |
+
family: full scientific name of the Family in which the taxon is classified. Family must be capitalized.
|
26 |
+
speciesBinomialName: scientific name of the taxon including Genus, specific epithet, and any lower classifications.
|
27 |
+
genus: taxonomic determination to Genus, Genus must be capitalized.
|
28 |
+
specificEpithet: The name of the first or species epithet of the scientificBinomial. Only include the species epithet.
|
29 |
+
speciesBinomialNameAuthorship: authorship information for the scientificName formatted according to the conventions of the applicable Darwin Core nomenclaturalCode.
|
30 |
+
collector: list of names of people, doctors, professors, groups, or organizations.
|
31 |
+
recordNumber: identifier given to the specimen at the time it was recorded.
|
32 |
+
identifiedBy: list of names of people, doctors, professors, groups, or organizations who identified, determined the taxon name to the subject organism. This is not the specimen collector.
|
33 |
+
verbatimCollectionDate: The verbatim original representation of the date and time information for when the specimen was collected.
|
34 |
+
collectionDate: collection date formatted as year-month-day YYYY-MM-DD.
|
35 |
+
occurrenceRemarks: all descriptive text in the OCR rearranged into sensible sentences or sentence fragments.
|
36 |
+
habitat: habitat description.
|
37 |
+
locality: description of geographic information aiding in pinpointing the exact origin or location of the specimen.
|
38 |
+
isCultivated: cultivated plants are intentionally grown by humans. Set to 'cultivated' if cultivated, otherwise use an empty string.
|
39 |
+
country: country or major administrative unit.
|
40 |
+
stateProvince: state, province, canton, department, region, etc.
|
41 |
+
county: county, shire, department, parish etc.
|
42 |
+
municipality: city, municipality, etc.
|
43 |
+
verbatimCoordinates: verbatim location coordinates.
|
44 |
+
decimalLatitude: latitude decimal coordinate.
|
45 |
+
decimalLongitude: longitude decimal coordinate.
|
46 |
+
minimumElevationInMeters: minimum elevation or altitude in meters.
|
47 |
+
maximumElevationInMeters: maximum elevation or altitude in meters.
|
48 |
+
mapping:
|
49 |
+
TAXONOMY:
|
50 |
+
- catalogNumber
|
51 |
+
- order
|
52 |
+
- family
|
53 |
+
- speciesBinomialName
|
54 |
+
- genus
|
55 |
+
- specificEpithet
|
56 |
+
- speciesBinomialNameAuthorship
|
57 |
+
GEOGRAPHY:
|
58 |
+
- country
|
59 |
+
- stateProvince
|
60 |
+
- county
|
61 |
+
- municipality
|
62 |
+
- verbatimCoordinates
|
63 |
+
- decimalLatitude
|
64 |
+
- decimalLongitude
|
65 |
+
- minimumElevationInMeters
|
66 |
+
- maximumElevationInMeters
|
67 |
+
LOCALITY:
|
68 |
+
- occurrenceRemarks
|
69 |
+
- habitat
|
70 |
+
- locality
|
71 |
+
- isCultivated
|
72 |
+
COLLECTING:
|
73 |
+
- collector
|
74 |
+
- recordNumber
|
75 |
+
- identifiedBy
|
76 |
+
- verbatimCollectionDate
|
77 |
+
- collectionDate
|
78 |
+
MISC: []
|
pages/prompt_builder.py
CHANGED
@@ -76,7 +76,9 @@ def load_prompt_yaml(filename):
|
|
76 |
st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
|
77 |
st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
|
78 |
|
79 |
-
#
|
|
|
|
|
80 |
st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
|
81 |
|
82 |
|
|
|
76 |
st.session_state['mapping'] = st.session_state['prompt_info'].get('mapping', {})
|
77 |
st.session_state['LLM'] = st.session_state['prompt_info'].get('LLM', 'General Purpose')
|
78 |
|
79 |
+
# print(st.session_state['mapping'].values())
|
80 |
+
# print(chain.from_iterable(st.session_state['mapping'].values()))
|
81 |
+
# print(list(chain.from_iterable(st.session_state['mapping'].values())))
|
82 |
st.session_state['assigned_columns'] = list(chain.from_iterable(st.session_state['mapping'].values()))
|
83 |
|
84 |
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
requirements_conda.txt
ADDED
Binary file (1.97 kB). View file
|
|
requirements_with_versions.txt
ADDED
Binary file (11.1 kB). View file
|
|
run_VoucherVision.py
CHANGED
@@ -31,7 +31,7 @@ def resolve_path(path):
|
|
31 |
if __name__ == "__main__":
|
32 |
dir_home = os.path.dirname(__file__)
|
33 |
|
34 |
-
start_port =
|
35 |
try:
|
36 |
free_port = find_available_port(start_port)
|
37 |
sys.argv = [
|
@@ -42,6 +42,7 @@ if __name__ == "__main__":
|
|
42 |
"--global.developmentMode=false",
|
43 |
# "--server.port=8545",
|
44 |
f"--server.port={free_port}",
|
|
|
45 |
# Toggle below for HF vs Local
|
46 |
# "--is_hf=1",
|
47 |
# "--is_hf=0",
|
|
|
31 |
if __name__ == "__main__":
|
32 |
dir_home = os.path.dirname(__file__)
|
33 |
|
34 |
+
start_port = 8530
|
35 |
try:
|
36 |
free_port = find_available_port(start_port)
|
37 |
sys.argv = [
|
|
|
42 |
"--global.developmentMode=false",
|
43 |
# "--server.port=8545",
|
44 |
f"--server.port={free_port}",
|
45 |
+
f"--server.maxUploadSize=51200",
|
46 |
# Toggle below for HF vs Local
|
47 |
# "--is_hf=1",
|
48 |
# "--is_hf=0",
|
vouchervision/LLM_GoogleGemini.py
CHANGED
@@ -20,7 +20,7 @@ class GoogleGeminiHandler:
|
|
20 |
VENDOR = 'google'
|
21 |
STARTING_TEMP = 0.5
|
22 |
|
23 |
-
def __init__(self, cfg, logger, model_name, JSON_dict_structure):
|
24 |
self.cfg = cfg
|
25 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
26 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
@@ -30,10 +30,8 @@ class GoogleGeminiHandler:
|
|
30 |
self.model_name = model_name
|
31 |
self.JSON_dict_structure = JSON_dict_structure
|
32 |
|
33 |
-
self.
|
34 |
-
|
35 |
-
self.adjust_temp = self.starting_temp
|
36 |
-
|
37 |
self.monitor = SystemLoadMonitor(logger)
|
38 |
|
39 |
self.parser = JsonOutputParser()
|
@@ -50,11 +48,24 @@ class GoogleGeminiHandler:
|
|
50 |
def _set_config(self):
|
51 |
# os.environ['GOOGLE_API_KEY'] # Must be set too for the retry call, set in VoucherVision class along with other API Keys
|
52 |
# vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
|
53 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
"max_output_tokens": 1024,
|
55 |
"temperature": self.starting_temp,
|
56 |
-
"top_p": 1
|
57 |
}
|
|
|
|
|
|
|
|
|
58 |
self.safety_settings = {
|
59 |
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
60 |
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
@@ -65,22 +76,26 @@ class GoogleGeminiHandler:
|
|
65 |
|
66 |
def _adjust_config(self):
|
67 |
new_temp = self.adjust_temp + self.temp_increment
|
68 |
-
self.json_report
|
|
|
69 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
70 |
self.adjust_temp += self.temp_increment
|
71 |
self.config['temperature'] = self.adjust_temp
|
72 |
|
73 |
def _reset_config(self):
|
74 |
-
self.json_report
|
|
|
75 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
76 |
self.adjust_temp = self.starting_temp
|
77 |
self.config['temperature'] = self.starting_temp
|
78 |
|
79 |
def _build_model_chain_parser(self):
|
80 |
# Instantiate the LLM class for Google Gemini
|
81 |
-
self.llm_model = ChatGoogleGenerativeAI(model=self.model_name
|
82 |
-
|
83 |
-
|
|
|
|
|
84 |
# self.llm_model = VertexAI(model='gemini-1.0-pro',
|
85 |
# max_output_tokens=self.config.get('max_output_tokens'),
|
86 |
# top_p=self.config.get('top_p'))
|
@@ -101,7 +116,8 @@ class GoogleGeminiHandler:
|
|
101 |
def call_llm_api_GoogleGemini(self, prompt_template, json_report, paths):
|
102 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
103 |
self.json_report = json_report
|
104 |
-
self.json_report
|
|
|
105 |
self.monitor.start_monitoring_usage()
|
106 |
nt_in = 0
|
107 |
nt_out = 0
|
@@ -110,9 +126,9 @@ class GoogleGeminiHandler:
|
|
110 |
while ind < self.MAX_RETRIES:
|
111 |
ind += 1
|
112 |
try:
|
113 |
-
model_kwargs = {"temperature": self.adjust_temp}
|
114 |
# Invoke the chain to generate prompt text
|
115 |
-
response = self.chain.invoke({"query": prompt_template
|
116 |
|
117 |
# Use retry_parser to parse the response with retry logic
|
118 |
output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
|
@@ -131,7 +147,8 @@ class GoogleGeminiHandler:
|
|
131 |
else:
|
132 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
133 |
|
134 |
-
json_report
|
|
|
135 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
136 |
|
137 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
@@ -143,7 +160,8 @@ class GoogleGeminiHandler:
|
|
143 |
if self.adjust_temp != self.starting_temp:
|
144 |
self._reset_config()
|
145 |
|
146 |
-
json_report
|
|
|
147 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
148 |
|
149 |
except Exception as e:
|
@@ -153,14 +171,16 @@ class GoogleGeminiHandler:
|
|
153 |
time.sleep(self.RETRY_DELAY)
|
154 |
|
155 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
156 |
-
self.json_report
|
|
|
157 |
|
158 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
159 |
|
160 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
161 |
self._reset_config()
|
162 |
|
163 |
-
json_report
|
|
|
164 |
return None, nt_in, nt_out, None, None, usage_report
|
165 |
|
166 |
|
|
|
20 |
VENDOR = 'google'
|
21 |
STARTING_TEMP = 0.5
|
22 |
|
23 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
|
24 |
self.cfg = cfg
|
25 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
26 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
|
|
30 |
self.model_name = model_name
|
31 |
self.JSON_dict_structure = JSON_dict_structure
|
32 |
|
33 |
+
self.config_vals_for_permutation = config_vals_for_permutation
|
34 |
+
|
|
|
|
|
35 |
self.monitor = SystemLoadMonitor(logger)
|
36 |
|
37 |
self.parser = JsonOutputParser()
|
|
|
48 |
def _set_config(self):
|
49 |
# os.environ['GOOGLE_API_KEY'] # Must be set too for the retry call, set in VoucherVision class along with other API Keys
|
50 |
# vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
|
51 |
+
if self.config_vals_for_permutation:
|
52 |
+
self.starting_temp = float(self.config_vals_for_permutation.get('google').get('temperature'))
|
53 |
+
self.config = {
|
54 |
+
'max_output_tokens': self.config_vals_for_permutation.get('google').get('max_output_tokens'),
|
55 |
+
'temperature': self.starting_temp,
|
56 |
+
'top_p': self.config_vals_for_permutation.get('google').get('top_p'),
|
57 |
+
}
|
58 |
+
else:
|
59 |
+
self.starting_temp = float(self.STARTING_TEMP)
|
60 |
+
self.config = {
|
61 |
"max_output_tokens": 1024,
|
62 |
"temperature": self.starting_temp,
|
63 |
+
"top_p": 1.0,
|
64 |
}
|
65 |
+
|
66 |
+
self.temp_increment = float(0.2)
|
67 |
+
self.adjust_temp = self.starting_temp
|
68 |
+
|
69 |
self.safety_settings = {
|
70 |
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
71 |
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
|
|
76 |
|
77 |
def _adjust_config(self):
|
78 |
new_temp = self.adjust_temp + self.temp_increment
|
79 |
+
if self.json_report:
|
80 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
81 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
82 |
self.adjust_temp += self.temp_increment
|
83 |
self.config['temperature'] = self.adjust_temp
|
84 |
|
85 |
def _reset_config(self):
|
86 |
+
if self.json_report:
|
87 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
88 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
89 |
self.adjust_temp = self.starting_temp
|
90 |
self.config['temperature'] = self.starting_temp
|
91 |
|
92 |
def _build_model_chain_parser(self):
|
93 |
# Instantiate the LLM class for Google Gemini
|
94 |
+
self.llm_model = ChatGoogleGenerativeAI(model=self.model_name,
|
95 |
+
max_output_tokens=self.config.get('max_output_tokens'),
|
96 |
+
top_p=self.config.get('top_p'),
|
97 |
+
temperature=self.config.get('temperature')
|
98 |
+
)
|
99 |
# self.llm_model = VertexAI(model='gemini-1.0-pro',
|
100 |
# max_output_tokens=self.config.get('max_output_tokens'),
|
101 |
# top_p=self.config.get('top_p'))
|
|
|
116 |
def call_llm_api_GoogleGemini(self, prompt_template, json_report, paths):
|
117 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
118 |
self.json_report = json_report
|
119 |
+
if self.json_report:
|
120 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
121 |
self.monitor.start_monitoring_usage()
|
122 |
nt_in = 0
|
123 |
nt_out = 0
|
|
|
126 |
while ind < self.MAX_RETRIES:
|
127 |
ind += 1
|
128 |
try:
|
129 |
+
# model_kwargs = {"temperature": self.adjust_temp}
|
130 |
# Invoke the chain to generate prompt text
|
131 |
+
response = self.chain.invoke({"query": prompt_template})#, "model_kwargs": model_kwargs})
|
132 |
|
133 |
# Use retry_parser to parse the response with retry logic
|
134 |
output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
|
|
|
147 |
else:
|
148 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
149 |
|
150 |
+
if self.json_report:
|
151 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
152 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
153 |
|
154 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
|
|
160 |
if self.adjust_temp != self.starting_temp:
|
161 |
self._reset_config()
|
162 |
|
163 |
+
if self.json_report:
|
164 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
165 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
166 |
|
167 |
except Exception as e:
|
|
|
171 |
time.sleep(self.RETRY_DELAY)
|
172 |
|
173 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
174 |
+
if self.json_report:
|
175 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
176 |
|
177 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
178 |
|
179 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
180 |
self._reset_config()
|
181 |
|
182 |
+
if self.json_report:
|
183 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
184 |
return None, nt_in, nt_out, None, None, usage_report
|
185 |
|
186 |
|
vouchervision/LLM_GooglePalm2.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os, time, json
|
2 |
# import vertexai
|
3 |
from vertexai.language_models import TextGenerationModel
|
4 |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
@@ -10,6 +10,8 @@ from langchain.prompts import PromptTemplate
|
|
10 |
from langchain_core.output_parsers import JsonOutputParser
|
11 |
# from langchain_google_genai import ChatGoogleGenerativeAI
|
12 |
from langchain_google_vertexai import VertexAI
|
|
|
|
|
13 |
|
14 |
from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
|
15 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
@@ -31,7 +33,7 @@ class GooglePalm2Handler:
|
|
31 |
VENDOR = 'google'
|
32 |
STARTING_TEMP = 0.5
|
33 |
|
34 |
-
def __init__(self, cfg, logger, model_name, JSON_dict_structure):
|
35 |
self.cfg = cfg
|
36 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
37 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
@@ -41,9 +43,9 @@ class GooglePalm2Handler:
|
|
41 |
self.model_name = model_name
|
42 |
self.JSON_dict_structure = JSON_dict_structure
|
43 |
|
44 |
-
self.
|
45 |
-
|
46 |
-
|
47 |
|
48 |
self.monitor = SystemLoadMonitor(logger)
|
49 |
|
@@ -59,12 +61,26 @@ class GooglePalm2Handler:
|
|
59 |
|
60 |
def _set_config(self):
|
61 |
# vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
|
62 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
"max_output_tokens": 1024,
|
64 |
"temperature": self.starting_temp,
|
|
|
65 |
"top_p": 1.0,
|
66 |
-
"top_k": 40,
|
67 |
}
|
|
|
|
|
|
|
|
|
68 |
self.safety_settings = {
|
69 |
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
70 |
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
@@ -75,13 +91,15 @@ class GooglePalm2Handler:
|
|
75 |
|
76 |
def _adjust_config(self):
|
77 |
new_temp = self.adjust_temp + self.temp_increment
|
78 |
-
self.json_report
|
|
|
79 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
80 |
self.adjust_temp += self.temp_increment
|
81 |
self.config['temperature'] = self.adjust_temp
|
82 |
|
83 |
def _reset_config(self):
|
84 |
-
self.json_report
|
|
|
85 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
86 |
self.adjust_temp = self.starting_temp
|
87 |
self.config['temperature'] = self.starting_temp
|
@@ -89,7 +107,11 @@ class GooglePalm2Handler:
|
|
89 |
def _build_model_chain_parser(self):
|
90 |
# Instantiate the parser and the retry parser
|
91 |
# self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
|
92 |
-
self.llm_model = VertexAI(model=self.model_name
|
|
|
|
|
|
|
|
|
93 |
|
94 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
95 |
parser=self.parser,
|
@@ -105,6 +127,7 @@ class GooglePalm2Handler:
|
|
105 |
response = model.predict(prompt_text.text,
|
106 |
max_output_tokens=self.config.get('max_output_tokens'),
|
107 |
temperature=self.config.get('temperature'),
|
|
|
108 |
top_p=self.config.get('top_p'))
|
109 |
# model = GenerativeModel(self.model_name)
|
110 |
|
@@ -115,7 +138,8 @@ class GooglePalm2Handler:
|
|
115 |
def call_llm_api_GooglePalm2(self, prompt_template, json_report, paths):
|
116 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
117 |
self.json_report = json_report
|
118 |
-
|
|
|
119 |
self.monitor.start_monitoring_usage()
|
120 |
nt_in = 0
|
121 |
nt_out = 0
|
@@ -124,12 +148,23 @@ class GooglePalm2Handler:
|
|
124 |
while ind < self.MAX_RETRIES:
|
125 |
ind += 1
|
126 |
try:
|
127 |
-
model_kwargs = {"temperature": self.adjust_temp}
|
128 |
# Invoke the chain to generate prompt text
|
129 |
-
response = self.chain.invoke({"query": prompt_template
|
130 |
|
131 |
# Use retry_parser to parse the response with retry logic
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
if output is None:
|
135 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
@@ -144,8 +179,9 @@ class GooglePalm2Handler:
|
|
144 |
self._adjust_config()
|
145 |
else:
|
146 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
147 |
-
|
148 |
-
json_report
|
|
|
149 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
150 |
|
151 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
@@ -157,7 +193,8 @@ class GooglePalm2Handler:
|
|
157 |
if self.adjust_temp != self.starting_temp:
|
158 |
self._reset_config()
|
159 |
|
160 |
-
json_report
|
|
|
161 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
162 |
|
163 |
except Exception as e:
|
@@ -167,11 +204,19 @@ class GooglePalm2Handler:
|
|
167 |
time.sleep(self.RETRY_DELAY)
|
168 |
|
169 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
170 |
-
self.json_report
|
|
|
171 |
|
172 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
173 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
174 |
self._reset_config()
|
175 |
|
176 |
-
json_report
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, time, json, typing
|
2 |
# import vertexai
|
3 |
from vertexai.language_models import TextGenerationModel
|
4 |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
|
|
10 |
from langchain_core.output_parsers import JsonOutputParser
|
11 |
# from langchain_google_genai import ChatGoogleGenerativeAI
|
12 |
from langchain_google_vertexai import VertexAI
|
13 |
+
from langchain_core.messages import BaseMessage, HumanMessage
|
14 |
+
from langchain_core.prompt_values import PromptValue as BasePromptValue
|
15 |
|
16 |
from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
|
17 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
|
|
33 |
VENDOR = 'google'
|
34 |
STARTING_TEMP = 0.5
|
35 |
|
36 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
|
37 |
self.cfg = cfg
|
38 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
39 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
|
|
43 |
self.model_name = model_name
|
44 |
self.JSON_dict_structure = JSON_dict_structure
|
45 |
|
46 |
+
self.config_vals_for_permutation = config_vals_for_permutation
|
47 |
+
|
48 |
+
|
49 |
|
50 |
self.monitor = SystemLoadMonitor(logger)
|
51 |
|
|
|
61 |
|
62 |
def _set_config(self):
|
63 |
# vertexai.init(project=os.environ['PALM_PROJECT_ID'], location=os.environ['PALM_LOCATION'])
|
64 |
+
if self.config_vals_for_permutation:
|
65 |
+
self.starting_temp = float(self.config_vals_for_permutation.get('google').get('temperature'))
|
66 |
+
self.config = {
|
67 |
+
'max_output_tokens': self.config_vals_for_permutation.get('google').get('max_output_tokens'),
|
68 |
+
'temperature': self.starting_temp,
|
69 |
+
'top_k': self.config_vals_for_permutation.get('google').get('top_k'),
|
70 |
+
'top_p': self.config_vals_for_permutation.get('google').get('top_p'),
|
71 |
+
}
|
72 |
+
else:
|
73 |
+
self.starting_temp = float(self.STARTING_TEMP)
|
74 |
+
self.config = {
|
75 |
"max_output_tokens": 1024,
|
76 |
"temperature": self.starting_temp,
|
77 |
+
"top_k": 1,
|
78 |
"top_p": 1.0,
|
|
|
79 |
}
|
80 |
+
|
81 |
+
self.temp_increment = float(0.2)
|
82 |
+
self.adjust_temp = self.starting_temp
|
83 |
+
|
84 |
self.safety_settings = {
|
85 |
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
86 |
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
|
|
91 |
|
92 |
def _adjust_config(self):
|
93 |
new_temp = self.adjust_temp + self.temp_increment
|
94 |
+
if self.json_report:
|
95 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
96 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
97 |
self.adjust_temp += self.temp_increment
|
98 |
self.config['temperature'] = self.adjust_temp
|
99 |
|
100 |
def _reset_config(self):
|
101 |
+
if self.json_report:
|
102 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
103 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
104 |
self.adjust_temp = self.starting_temp
|
105 |
self.config['temperature'] = self.starting_temp
|
|
|
107 |
def _build_model_chain_parser(self):
|
108 |
# Instantiate the parser and the retry parser
|
109 |
# self.llm_model = ChatGoogleGenerativeAI(model=self.model_name)
|
110 |
+
self.llm_model = VertexAI(model=self.model_name,
|
111 |
+
max_output_tokens=self.config.get('max_output_tokens'),
|
112 |
+
temperature=self.config.get('temperature'),
|
113 |
+
top_k=self.config.get('top_k'),
|
114 |
+
top_p=self.config.get('top_p'))
|
115 |
|
116 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
117 |
parser=self.parser,
|
|
|
127 |
response = model.predict(prompt_text.text,
|
128 |
max_output_tokens=self.config.get('max_output_tokens'),
|
129 |
temperature=self.config.get('temperature'),
|
130 |
+
top_k=self.config.get('top_k'),
|
131 |
top_p=self.config.get('top_p'))
|
132 |
# model = GenerativeModel(self.model_name)
|
133 |
|
|
|
138 |
def call_llm_api_GooglePalm2(self, prompt_template, json_report, paths):
|
139 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
140 |
self.json_report = json_report
|
141 |
+
if json_report:
|
142 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
143 |
self.monitor.start_monitoring_usage()
|
144 |
nt_in = 0
|
145 |
nt_out = 0
|
|
|
148 |
while ind < self.MAX_RETRIES:
|
149 |
ind += 1
|
150 |
try:
|
151 |
+
# model_kwargs = {"temperature": self.adjust_temp}
|
152 |
# Invoke the chain to generate prompt text
|
153 |
+
response = self.chain.invoke({"query": prompt_template})#, "model_kwargs": model_kwargs})
|
154 |
|
155 |
# Use retry_parser to parse the response with retry logic
|
156 |
+
try:
|
157 |
+
output = self.retry_parser.parse_with_prompt(response, prompt_value=PromptValue(prompt_template))
|
158 |
+
except:
|
159 |
+
try:
|
160 |
+
output = self.retry_parser.parse_with_prompt(response, prompt_value=prompt_template)
|
161 |
+
except:
|
162 |
+
try:
|
163 |
+
output = json.loads(response)
|
164 |
+
except Exception as e:
|
165 |
+
print(e)
|
166 |
+
output = None
|
167 |
+
|
168 |
|
169 |
if output is None:
|
170 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
|
|
179 |
self._adjust_config()
|
180 |
else:
|
181 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
182 |
+
|
183 |
+
if self.json_report:
|
184 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
185 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
186 |
|
187 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
|
|
193 |
if self.adjust_temp != self.starting_temp:
|
194 |
self._reset_config()
|
195 |
|
196 |
+
if self.json_report:
|
197 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
198 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
199 |
|
200 |
except Exception as e:
|
|
|
204 |
time.sleep(self.RETRY_DELAY)
|
205 |
|
206 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
207 |
+
if self.json_report:
|
208 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
209 |
|
210 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
211 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
212 |
self._reset_config()
|
213 |
|
214 |
+
if self.json_report:
|
215 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
216 |
+
return None, nt_in, nt_out, None, None, usage_report
|
217 |
+
|
218 |
+
class PromptValue(BasePromptValue):
|
219 |
+
prompt_str: str
|
220 |
+
|
221 |
+
def to_string(self) -> str:
|
222 |
+
return self.prompt_str
|
vouchervision/LLM_MistralAI.py
CHANGED
@@ -11,12 +11,12 @@ from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys
|
|
11 |
class MistralHandler:
|
12 |
RETRY_DELAY = 2 # Wait 10 seconds before retrying
|
13 |
MAX_RETRIES = 5 # Maximum number of retries
|
14 |
-
STARTING_TEMP = 0.
|
15 |
TOKENIZER_NAME = None
|
16 |
VENDOR = 'mistral'
|
17 |
RANDOM_SEED = 2023
|
18 |
|
19 |
-
def __init__(self, cfg, logger, model_name, JSON_dict_structure):
|
20 |
self.cfg = cfg
|
21 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
22 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
@@ -27,10 +27,9 @@ class MistralHandler:
|
|
27 |
self.has_GPU = torch.cuda.is_available()
|
28 |
self.model_name = model_name
|
29 |
self.JSON_dict_structure = JSON_dict_structure
|
30 |
-
self.starting_temp = float(self.STARTING_TEMP)
|
31 |
-
self.temp_increment = float(0.2)
|
32 |
-
self.adjust_temp = self.starting_temp
|
33 |
|
|
|
|
|
34 |
# Set up a parser
|
35 |
self.parser = JsonOutputParser()
|
36 |
|
@@ -44,25 +43,45 @@ class MistralHandler:
|
|
44 |
self._set_config()
|
45 |
|
46 |
def _set_config(self):
|
47 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
'temperature': self.starting_temp,
|
49 |
'random_seed': self.RANDOM_SEED,
|
50 |
'safe_mode': False,
|
51 |
-
'top_p':
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
self._build_model_chain_parser()
|
54 |
|
55 |
|
56 |
def _adjust_config(self):
|
57 |
new_temp = self.adjust_temp + self.temp_increment
|
58 |
self.config['random_seed'] = random.randint(1, 1000)
|
59 |
-
self.json_report
|
|
|
60 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
|
61 |
self.adjust_temp += self.temp_increment
|
62 |
self.config['temperature'] = self.adjust_temp
|
63 |
|
64 |
def _reset_config(self):
|
65 |
-
self.json_report
|
|
|
66 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
|
67 |
self.adjust_temp = self.starting_temp
|
68 |
self.config['temperature'] = self.starting_temp
|
@@ -74,7 +93,9 @@ class MistralHandler:
|
|
74 |
model=self.model_name,
|
75 |
max_tokens=self.config.get('max_tokens'),
|
76 |
safe_mode=self.config.get('safe_mode'),
|
77 |
-
top_p=self.config.get('top_p')
|
|
|
|
|
78 |
|
79 |
# Set up the retry parser with the runnable
|
80 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
|
@@ -85,7 +106,8 @@ class MistralHandler:
|
|
85 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
86 |
|
87 |
self.json_report = json_report
|
88 |
-
self.json_report
|
|
|
89 |
self.monitor.start_monitoring_usage()
|
90 |
nt_in = 0
|
91 |
nt_out = 0
|
@@ -94,10 +116,10 @@ class MistralHandler:
|
|
94 |
while ind < self.MAX_RETRIES:
|
95 |
ind += 1
|
96 |
try:
|
97 |
-
model_kwargs = {"temperature": self.adjust_temp, "random_seed": self.config.get("random_seed")}
|
98 |
|
99 |
# Invoke the chain to generate prompt text
|
100 |
-
response = self.chain.invoke({"query": prompt_template
|
101 |
|
102 |
# Use retry_parser to parse the response with retry logic
|
103 |
output = self.retry_parser.parse_with_prompt(response.content, prompt_value=prompt_template)
|
@@ -115,8 +137,9 @@ class MistralHandler:
|
|
115 |
self._adjust_config()
|
116 |
else:
|
117 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
118 |
-
|
119 |
-
json_report
|
|
|
120 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
121 |
|
122 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
@@ -128,7 +151,8 @@ class MistralHandler:
|
|
128 |
if self.adjust_temp != self.starting_temp:
|
129 |
self._reset_config()
|
130 |
|
131 |
-
json_report
|
|
|
132 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
133 |
|
134 |
except Exception as e:
|
@@ -138,11 +162,13 @@ class MistralHandler:
|
|
138 |
time.sleep(self.RETRY_DELAY)
|
139 |
|
140 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
141 |
-
self.json_report
|
|
|
142 |
|
143 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
144 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
145 |
self._reset_config()
|
146 |
-
json_report
|
|
|
147 |
|
148 |
return None, nt_in, nt_out, None, None, usage_report
|
|
|
11 |
class MistralHandler:
|
12 |
RETRY_DELAY = 2 # Wait 10 seconds before retrying
|
13 |
MAX_RETRIES = 5 # Maximum number of retries
|
14 |
+
STARTING_TEMP = 0.5 #0.01
|
15 |
TOKENIZER_NAME = None
|
16 |
VENDOR = 'mistral'
|
17 |
RANDOM_SEED = 2023
|
18 |
|
19 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
|
20 |
self.cfg = cfg
|
21 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
22 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
|
|
27 |
self.has_GPU = torch.cuda.is_available()
|
28 |
self.model_name = model_name
|
29 |
self.JSON_dict_structure = JSON_dict_structure
|
|
|
|
|
|
|
30 |
|
31 |
+
self.config_vals_for_permutation = config_vals_for_permutation
|
32 |
+
|
33 |
# Set up a parser
|
34 |
self.parser = JsonOutputParser()
|
35 |
|
|
|
43 |
self._set_config()
|
44 |
|
45 |
def _set_config(self):
|
46 |
+
if self.config_vals_for_permutation:
|
47 |
+
self.starting_temp = float(self.config_vals_for_permutation.get('mistral').get('temperature'))
|
48 |
+
self.config = {
|
49 |
+
'max_tokens': self.config_vals_for_permutation.get('mistral').get('max_tokens'),
|
50 |
+
'temperature': self.starting_temp,
|
51 |
+
'top_p': self.config_vals_for_permutation.get('mistral').get('top_p'),
|
52 |
+
'top_k': self.config_vals_for_permutation.get('mistral').get('top_k'),
|
53 |
+
'safe_mode': self.config_vals_for_permutation.get('mistral').get('safe_mode'),
|
54 |
+
'random_seed': self.config_vals_for_permutation.get('mistral').get('random_seed'),
|
55 |
+
}
|
56 |
+
else:
|
57 |
+
self.starting_temp = float(self.STARTING_TEMP)
|
58 |
+
self.config = {
|
59 |
+
'max_tokens': 1024,
|
60 |
'temperature': self.starting_temp,
|
61 |
'random_seed': self.RANDOM_SEED,
|
62 |
'safe_mode': False,
|
63 |
+
'top_p': 0.5,
|
64 |
+
'top_k': 0.5,
|
65 |
+
}
|
66 |
+
|
67 |
+
self.temp_increment = float(0.2)
|
68 |
+
self.adjust_temp = self.starting_temp
|
69 |
+
|
70 |
self._build_model_chain_parser()
|
71 |
|
72 |
|
73 |
def _adjust_config(self):
|
74 |
new_temp = self.adjust_temp + self.temp_increment
|
75 |
self.config['random_seed'] = random.randint(1, 1000)
|
76 |
+
if self.json_report:
|
77 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
|
78 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp} and random_seed to {self.config.get("random_seed")}')
|
79 |
self.adjust_temp += self.temp_increment
|
80 |
self.config['temperature'] = self.adjust_temp
|
81 |
|
82 |
def _reset_config(self):
|
83 |
+
if self.json_report:
|
84 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
|
85 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {self.starting_temp} and random_seed to {self.RANDOM_SEED}')
|
86 |
self.adjust_temp = self.starting_temp
|
87 |
self.config['temperature'] = self.starting_temp
|
|
|
93 |
model=self.model_name,
|
94 |
max_tokens=self.config.get('max_tokens'),
|
95 |
safe_mode=self.config.get('safe_mode'),
|
96 |
+
top_p=self.config.get('top_p'),
|
97 |
+
top_k=self.config.get('top_k'),
|
98 |
+
)
|
99 |
|
100 |
# Set up the retry parser with the runnable
|
101 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(parser=self.parser, llm=self.llm_model, max_retries=self.MAX_RETRIES)
|
|
|
106 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
107 |
|
108 |
self.json_report = json_report
|
109 |
+
if self.json_report:
|
110 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
111 |
self.monitor.start_monitoring_usage()
|
112 |
nt_in = 0
|
113 |
nt_out = 0
|
|
|
116 |
while ind < self.MAX_RETRIES:
|
117 |
ind += 1
|
118 |
try:
|
119 |
+
# model_kwargs = {"temperature": self.adjust_temp, "random_seed": self.config.get("random_seed")}
|
120 |
|
121 |
# Invoke the chain to generate prompt text
|
122 |
+
response = self.chain.invoke({"query": prompt_template})#, "model_kwargs": model_kwargs})
|
123 |
|
124 |
# Use retry_parser to parse the response with retry logic
|
125 |
output = self.retry_parser.parse_with_prompt(response.content, prompt_value=prompt_template)
|
|
|
137 |
self._adjust_config()
|
138 |
else:
|
139 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
140 |
+
|
141 |
+
if self.json_report:
|
142 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
143 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
144 |
|
145 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
|
|
151 |
if self.adjust_temp != self.starting_temp:
|
152 |
self._reset_config()
|
153 |
|
154 |
+
if self.json_report:
|
155 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
156 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
157 |
|
158 |
except Exception as e:
|
|
|
162 |
time.sleep(self.RETRY_DELAY)
|
163 |
|
164 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
165 |
+
if self.json_report:
|
166 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
167 |
|
168 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
169 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
170 |
self._reset_config()
|
171 |
+
if self.json_report:
|
172 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
173 |
|
174 |
return None, nt_in, nt_out, None, None, usage_report
|
vouchervision/LLM_OpenAI.py
CHANGED
@@ -11,11 +11,11 @@ from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys
|
|
11 |
class OpenAIHandler:
|
12 |
RETRY_DELAY = 10 # Wait 10 seconds before retrying
|
13 |
MAX_RETRIES = 3 # Maximum number of retries
|
14 |
-
STARTING_TEMP = 0.5
|
15 |
TOKENIZER_NAME = 'gpt-4'
|
16 |
VENDOR = 'openai'
|
17 |
|
18 |
-
def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object):
|
19 |
self.cfg = cfg
|
20 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
21 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
@@ -26,14 +26,13 @@ class OpenAIHandler:
|
|
26 |
self.JSON_dict_structure = JSON_dict_structure
|
27 |
self.is_azure = is_azure
|
28 |
self.llm_object = llm_object
|
29 |
-
self.name_parts = self.model_name.split('-')
|
30 |
|
31 |
self.monitor = SystemLoadMonitor(logger)
|
32 |
self.has_GPU = torch.cuda.is_available()
|
33 |
|
34 |
-
|
35 |
-
self.
|
36 |
-
self.adjust_temp = self.starting_temp
|
37 |
|
38 |
# Set up a parser
|
39 |
self.parser = JsonOutputParser()
|
@@ -45,12 +44,44 @@ class OpenAIHandler:
|
|
45 |
)
|
46 |
self._set_config()
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def _set_config(self):
|
49 |
-
self.
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
# Adjusting the LLM settings based on whether Azure is used
|
55 |
if self.is_azure:
|
56 |
self.llm_object.deployment_name = self.model_name
|
@@ -68,43 +99,84 @@ class OpenAIHandler:
|
|
68 |
|
69 |
def _adjust_config(self):
|
70 |
new_temp = self.adjust_temp + self.temp_increment
|
71 |
-
self.json_report
|
|
|
72 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
73 |
self.adjust_temp += self.temp_increment
|
74 |
-
self.
|
75 |
|
76 |
def _reset_config(self):
|
77 |
-
self.json_report
|
|
|
78 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
79 |
self.adjust_temp = self.starting_temp
|
80 |
-
self.
|
81 |
|
82 |
def _build_model_chain_parser(self):
|
83 |
if not self.is_azure and ('instruct' in self.name_parts):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
# Set up the retry parser with 3 retries
|
85 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
86 |
-
|
87 |
-
|
|
|
88 |
)
|
89 |
else:
|
90 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
self.retry_parser = RetryWithErrorOutputParser.from_llm(
|
92 |
-
|
93 |
-
|
|
|
94 |
)
|
|
|
95 |
# Prepare the chain
|
96 |
-
if
|
97 |
-
|
98 |
-
self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else OpenAI(model=self.model_name))
|
99 |
else:
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
|
104 |
def call_llm_api_OpenAI(self, prompt_template, json_report, paths):
|
105 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
106 |
self.json_report = json_report
|
107 |
-
self.json_report
|
|
|
108 |
self.monitor.start_monitoring_usage()
|
109 |
nt_in = 0
|
110 |
nt_out = 0
|
@@ -113,14 +185,20 @@ class OpenAIHandler:
|
|
113 |
while ind < self.MAX_RETRIES:
|
114 |
ind += 1
|
115 |
try:
|
116 |
-
|
117 |
# Invoke the chain to generate prompt text
|
118 |
-
response = self.chain.invoke({"query": prompt_template
|
119 |
|
120 |
response_text = response.content if not isinstance(response, str) else response
|
121 |
|
122 |
# Use retry_parser to parse the response with retry logic
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
if output is None:
|
126 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
|
@@ -136,14 +214,11 @@ class OpenAIHandler:
|
|
136 |
else:
|
137 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
138 |
|
139 |
-
json_report
|
|
|
140 |
|
141 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
142 |
|
143 |
-
# output1, WFO_record = validate_taxonomy_WFO(self.tool_WFO, output, replace_if_success_wfo=False)
|
144 |
-
# output2, GEO_record = validate_coordinates_here(self.tool_GEO, output, replace_if_success_geo=False)
|
145 |
-
# validate_wikipedia(self.tool_wikipedia, json_file_path_wiki, output)
|
146 |
-
|
147 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
148 |
|
149 |
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
@@ -153,7 +228,8 @@ class OpenAIHandler:
|
|
153 |
if self.adjust_temp != self.starting_temp:
|
154 |
self._reset_config()
|
155 |
|
156 |
-
json_report
|
|
|
157 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
158 |
|
159 |
except Exception as e:
|
@@ -163,11 +239,15 @@ class OpenAIHandler:
|
|
163 |
time.sleep(self.RETRY_DELAY)
|
164 |
|
165 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
166 |
-
self.json_report
|
|
|
167 |
|
168 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
169 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
170 |
self._reset_config()
|
171 |
|
172 |
-
json_report
|
|
|
173 |
return None, nt_in, nt_out, None, None, usage_report
|
|
|
|
|
|
11 |
class OpenAIHandler:
|
12 |
RETRY_DELAY = 10 # Wait 10 seconds before retrying
|
13 |
MAX_RETRIES = 3 # Maximum number of retries
|
14 |
+
STARTING_TEMP = 0.5 # 0.5, config_vals_for_permutation
|
15 |
TOKENIZER_NAME = 'gpt-4'
|
16 |
VENDOR = 'openai'
|
17 |
|
18 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object, config_vals_for_permutation):
|
19 |
self.cfg = cfg
|
20 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
21 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
|
|
26 |
self.JSON_dict_structure = JSON_dict_structure
|
27 |
self.is_azure = is_azure
|
28 |
self.llm_object = llm_object
|
29 |
+
self.name_parts = self.model_name.lower().split('-')
|
30 |
|
31 |
self.monitor = SystemLoadMonitor(logger)
|
32 |
self.has_GPU = torch.cuda.is_available()
|
33 |
|
34 |
+
### Config
|
35 |
+
self.config_vals_for_permutation = config_vals_for_permutation
|
|
|
36 |
|
37 |
# Set up a parser
|
38 |
self.parser = JsonOutputParser()
|
|
|
44 |
)
|
45 |
self._set_config()
|
46 |
|
47 |
+
def _can_use_json_mode(self):
|
48 |
+
if self.is_azure:
|
49 |
+
return False
|
50 |
+
# gpt-4-turbo-preview (gpt-4-0125-preview)
|
51 |
+
if ('0125' in self.name_parts) and ('4' in self.name_parts):
|
52 |
+
return True
|
53 |
+
# gpt-3.5-turbo-0125
|
54 |
+
elif ('0125' in self.name_parts) and ('3.5' in self.name_parts) and ('turbo' in self.name_parts):
|
55 |
+
return True
|
56 |
+
else:
|
57 |
+
return False
|
58 |
+
|
59 |
+
|
60 |
def _set_config(self):
|
61 |
+
if self.config_vals_for_permutation:
|
62 |
+
self.starting_temp = float(self.config_vals_for_permutation.get('openai').get('temperature'))
|
63 |
+
self.model_kwargs = {
|
64 |
+
'max_tokens': self.config_vals_for_permutation.get('openai').get('max_tokens'),
|
65 |
+
'temperature': self.starting_temp,
|
66 |
+
# 'seed': self.config_vals_for_permutation.get('openai').get('seed'),
|
67 |
+
'top_p': self.config_vals_for_permutation.get('openai').get('top_p'),
|
68 |
+
}
|
69 |
+
else:
|
70 |
+
self.starting_temp = float(self.STARTING_TEMP)
|
71 |
+
self.model_kwargs = {
|
72 |
+
'max_tokens': 1024,
|
73 |
+
'temperature': self.starting_temp,
|
74 |
+
# 'seed': 2023,
|
75 |
+
'top_p': 1, # Set to 1, change temp only
|
76 |
+
}
|
77 |
+
|
78 |
+
### Not all openai models support json mode
|
79 |
+
if self._can_use_json_mode():
|
80 |
+
self.model_kwargs.update({"response_format": {"type": "json_object"}})
|
81 |
+
|
82 |
+
self.temp_increment = float(0.2)
|
83 |
+
self.adjust_temp = self.starting_temp
|
84 |
+
|
85 |
# Adjusting the LLM settings based on whether Azure is used
|
86 |
if self.is_azure:
|
87 |
self.llm_object.deployment_name = self.model_name
|
|
|
99 |
|
100 |
def _adjust_config(self):
|
101 |
new_temp = self.adjust_temp + self.temp_increment
|
102 |
+
if self.json_report:
|
103 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
104 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
105 |
self.adjust_temp += self.temp_increment
|
106 |
+
self.model_kwargs['temperature'] = self.adjust_temp
|
107 |
|
108 |
def _reset_config(self):
|
109 |
+
if self.json_report:
|
110 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
111 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
112 |
self.adjust_temp = self.starting_temp
|
113 |
+
self.model_kwargs['temperature'] = self.starting_temp
|
114 |
|
115 |
def _build_model_chain_parser(self):
    """Create ``self.retry_parser`` and ``self.chain`` for the current settings.

    Three cases are handled:

    * Azure: the injected ``self.llm_object`` backs the retry parser after its
      ``temperature``, ``max_tokens`` and ``model_kwargs`` are refreshed from
      ``self.model_kwargs``; the chain runs through
      ``self.format_input_for_azure``.
    * Non-Azure "instruct" models: completion-style ``OpenAI`` clients.
    * Every other non-Azure model: ``ChatOpenAI`` clients.

    For the non-Azure cases two separate client instances are created (one
    for the retry parser, one for the chain), and only ``temperature``,
    ``top_p`` and ``max_tokens`` are forwarded from ``self.model_kwargs``.
    """
    # NOTE(review): any other entry in self.model_kwargs (e.g. the
    # 'response_format' added when JSON mode is available) is not passed to
    # the non-Azure constructors here -- confirm whether that is intended.
    sampling = {
        'temperature': self.model_kwargs.get('temperature'),
        'top_p': self.model_kwargs.get('top_p'),
        'max_tokens': self.model_kwargs.get('max_tokens'),
    }

    if self.is_azure:
        # Reuse the caller-supplied Azure client, refreshed with the current settings.
        self.llm_object.temperature = sampling['temperature']
        self.llm_object.max_tokens = sampling['max_tokens']
        self.llm_object.model_kwargs = self.model_kwargs
        retry_llm = self.llm_object
        chain_llm = self.format_input_for_azure
    else:
        # Instruct models use the completion client, everything else the chat client.
        llm_cls = OpenAI if 'instruct' in self.name_parts else ChatOpenAI
        retry_llm = llm_cls(model=self.model_name, **sampling)
        chain_llm = llm_cls(model=self.model_name, **sampling)

    # Parser that re-asks the LLM (up to MAX_RETRIES times) when output fails to parse.
    self.retry_parser = RetryWithErrorOutputParser.from_llm(
        parser=self.parser,
        llm=retry_llm,
        max_retries=self.MAX_RETRIES,
    )

    # Prepare the chain
    self.chain = self.prompt | chain_llm
|
173 |
|
174 |
|
175 |
def call_llm_api_OpenAI(self, prompt_template, json_report, paths):
|
176 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
177 |
self.json_report = json_report
|
178 |
+
if self.json_report:
|
179 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
180 |
self.monitor.start_monitoring_usage()
|
181 |
nt_in = 0
|
182 |
nt_out = 0
|
|
|
185 |
while ind < self.MAX_RETRIES:
|
186 |
ind += 1
|
187 |
try:
|
188 |
+
self.logger.info(str(self.model_kwargs))
|
189 |
# Invoke the chain to generate prompt text
|
190 |
+
response = self.chain.invoke(input={"query": prompt_template})#, **self.model_kwargs)# "model_kwargs": self.model_kwargs})
|
191 |
|
192 |
response_text = response.content if not isinstance(response, str) else response
|
193 |
|
194 |
# Use retry_parser to parse the response with retry logic
|
195 |
+
try:
|
196 |
+
output = self.retry_parser.parse_with_prompt(response_text, prompt_value=prompt_template)
|
197 |
+
except:
|
198 |
+
try:
|
199 |
+
output = json.loads(response_text)
|
200 |
+
except:
|
201 |
+
output = None
|
202 |
|
203 |
if output is None:
|
204 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
|
|
|
214 |
else:
|
215 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
216 |
|
217 |
+
if self.json_report:
|
218 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
219 |
|
220 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
221 |
|
|
|
|
|
|
|
|
|
222 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
223 |
|
224 |
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
|
|
228 |
if self.adjust_temp != self.starting_temp:
|
229 |
self._reset_config()
|
230 |
|
231 |
+
if self.json_report:
|
232 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
233 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
234 |
|
235 |
except Exception as e:
|
|
|
239 |
time.sleep(self.RETRY_DELAY)
|
240 |
|
241 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
242 |
+
if self.json_report:
|
243 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
244 |
|
245 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
246 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
247 |
self._reset_config()
|
248 |
|
249 |
+
if self.json_report:
|
250 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
251 |
return None, nt_in, nt_out, None, None, usage_report
|
252 |
+
|
253 |
+
|
vouchervision/LLM_local_MistralAI.py
CHANGED
@@ -22,7 +22,7 @@ class LocalMistralHandler:
|
|
22 |
VENDOR = 'mistral'
|
23 |
MAX_GPU_MONITORING_INTERVAL = 2 # seconds
|
24 |
|
25 |
-
def __init__(self, cfg, logger, model_name, JSON_dict_structure):
|
26 |
self.cfg = cfg
|
27 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
28 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
@@ -122,13 +122,15 @@ class LocalMistralHandler:
|
|
122 |
|
123 |
def _adjust_config(self):
|
124 |
new_temp = self.adjust_temp + self.temp_increment
|
125 |
-
self.json_report
|
|
|
126 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
127 |
self.adjust_temp += self.temp_increment
|
128 |
|
129 |
|
130 |
def _reset_config(self):
|
131 |
-
self.json_report
|
|
|
132 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
133 |
self.adjust_temp = self.starting_temp
|
134 |
|
@@ -153,7 +155,8 @@ class LocalMistralHandler:
|
|
153 |
def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
|
154 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
155 |
self.json_report = json_report
|
156 |
-
self.json_report
|
|
|
157 |
self.monitor.start_monitoring_usage()
|
158 |
|
159 |
nt_in = 0
|
@@ -188,8 +191,9 @@ class LocalMistralHandler:
|
|
188 |
self._adjust_config()
|
189 |
else:
|
190 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
191 |
-
|
192 |
-
json_report
|
|
|
193 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
194 |
|
195 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
@@ -201,7 +205,8 @@ class LocalMistralHandler:
|
|
201 |
if self.adjust_temp != self.starting_temp:
|
202 |
self._reset_config()
|
203 |
|
204 |
-
json_report
|
|
|
205 |
del results
|
206 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
207 |
|
@@ -210,11 +215,13 @@ class LocalMistralHandler:
|
|
210 |
self._adjust_config()
|
211 |
|
212 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
213 |
-
self.json_report
|
|
|
214 |
|
215 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
216 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
217 |
-
json_report
|
|
|
218 |
|
219 |
self._reset_config()
|
220 |
return None, nt_in, nt_out, None, None, usage_report
|
|
|
22 |
VENDOR = 'mistral'
|
23 |
MAX_GPU_MONITORING_INTERVAL = 2 # seconds
|
24 |
|
25 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
|
26 |
self.cfg = cfg
|
27 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
28 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
|
|
122 |
|
123 |
def _adjust_config(self):
|
124 |
new_temp = self.adjust_temp + self.temp_increment
|
125 |
+
if self.json_report:
|
126 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
127 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
128 |
self.adjust_temp += self.temp_increment
|
129 |
|
130 |
|
131 |
def _reset_config(self):
|
132 |
+
if self.json_report:
|
133 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
134 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
135 |
self.adjust_temp = self.starting_temp
|
136 |
|
|
|
155 |
def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
|
156 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
157 |
self.json_report = json_report
|
158 |
+
if self.json_report:
|
159 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
160 |
self.monitor.start_monitoring_usage()
|
161 |
|
162 |
nt_in = 0
|
|
|
191 |
self._adjust_config()
|
192 |
else:
|
193 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
194 |
+
|
195 |
+
if self.json_report:
|
196 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
197 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
198 |
|
199 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
|
|
205 |
if self.adjust_temp != self.starting_temp:
|
206 |
self._reset_config()
|
207 |
|
208 |
+
if self.json_report:
|
209 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
210 |
del results
|
211 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
212 |
|
|
|
215 |
self._adjust_config()
|
216 |
|
217 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
218 |
+
if self.json_report:
|
219 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
220 |
|
221 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
222 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
223 |
+
if self.json_report:
|
224 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
225 |
|
226 |
self._reset_config()
|
227 |
return None, nt_in, nt_out, None, None, usage_report
|
vouchervision/LLM_local_cpu_MistralAI.py
CHANGED
@@ -30,7 +30,7 @@ class LocalCPUMistralHandler:
|
|
30 |
SEED = 2023
|
31 |
|
32 |
|
33 |
-
def __init__(self, cfg, logger, model_name, JSON_dict_structure):
|
34 |
self.cfg = cfg
|
35 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
36 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
@@ -106,13 +106,15 @@ class LocalCPUMistralHandler:
|
|
106 |
|
107 |
def _adjust_config(self):
|
108 |
new_temp = self.adjust_temp + self.temp_increment
|
109 |
-
self.json_report
|
|
|
110 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
111 |
self.adjust_temp += self.temp_increment
|
112 |
self.config['temperature'] = self.adjust_temp
|
113 |
|
114 |
def _reset_config(self):
|
115 |
-
self.json_report
|
|
|
116 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
117 |
self.adjust_temp = self.starting_temp
|
118 |
self.config['temperature'] = self.starting_temp
|
@@ -140,7 +142,8 @@ class LocalCPUMistralHandler:
|
|
140 |
def call_llm_local_cpu_MistralAI(self, prompt_template, json_report, paths):
|
141 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
142 |
self.json_report = json_report
|
143 |
-
self.json_report
|
|
|
144 |
self.monitor.start_monitoring_usage()
|
145 |
|
146 |
nt_in = 0
|
@@ -180,7 +183,8 @@ class LocalCPUMistralHandler:
|
|
180 |
else:
|
181 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
182 |
|
183 |
-
json_report
|
|
|
184 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
185 |
|
186 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
@@ -192,7 +196,8 @@ class LocalCPUMistralHandler:
|
|
192 |
if self.adjust_temp != self.starting_temp:
|
193 |
self._reset_config()
|
194 |
|
195 |
-
json_report
|
|
|
196 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
197 |
|
198 |
except Exception as e:
|
@@ -200,13 +205,15 @@ class LocalCPUMistralHandler:
|
|
200 |
self._adjust_config()
|
201 |
|
202 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
203 |
-
self.json_report
|
|
|
204 |
|
205 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
206 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
207 |
self._reset_config()
|
208 |
|
209 |
-
json_report
|
|
|
210 |
return None, nt_in, nt_out, None, None, usage_report
|
211 |
|
212 |
|
|
|
30 |
SEED = 2023
|
31 |
|
32 |
|
33 |
+
def __init__(self, cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation):
|
34 |
self.cfg = cfg
|
35 |
self.tool_WFO = self.cfg['leafmachine']['project']['tool_WFO']
|
36 |
self.tool_GEO = self.cfg['leafmachine']['project']['tool_GEO']
|
|
|
106 |
|
107 |
def _adjust_config(self):
|
108 |
new_temp = self.adjust_temp + self.temp_increment
|
109 |
+
if self.json_report:
|
110 |
+
self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
111 |
self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
|
112 |
self.adjust_temp += self.temp_increment
|
113 |
self.config['temperature'] = self.adjust_temp
|
114 |
|
115 |
def _reset_config(self):
|
116 |
+
if self.json_report:
|
117 |
+
self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
118 |
self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
|
119 |
self.adjust_temp = self.starting_temp
|
120 |
self.config['temperature'] = self.starting_temp
|
|
|
142 |
def call_llm_local_cpu_MistralAI(self, prompt_template, json_report, paths):
|
143 |
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
144 |
self.json_report = json_report
|
145 |
+
if self.json_report:
|
146 |
+
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
147 |
self.monitor.start_monitoring_usage()
|
148 |
|
149 |
nt_in = 0
|
|
|
183 |
else:
|
184 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
185 |
|
186 |
+
if self.json_report:
|
187 |
+
self.json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
188 |
output_WFO, WFO_record, output_GEO, GEO_record = run_tools(output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki)
|
189 |
|
190 |
save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)
|
|
|
196 |
if self.adjust_temp != self.starting_temp:
|
197 |
self._reset_config()
|
198 |
|
199 |
+
if self.json_report:
|
200 |
+
self.json_report.set_text(text_main=f'LLM call successful')
|
201 |
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
202 |
|
203 |
except Exception as e:
|
|
|
205 |
self._adjust_config()
|
206 |
|
207 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
208 |
+
if self.json_report:
|
209 |
+
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
210 |
|
211 |
self.monitor.stop_inference_timer() # Starts tool timer too
|
212 |
usage_report = self.monitor.stop_monitoring_report_usage()
|
213 |
self._reset_config()
|
214 |
|
215 |
+
if self.json_report:
|
216 |
+
self.json_report.set_text(text_main=f'LLM call failed')
|
217 |
return None, nt_in, nt_out, None, None, usage_report
|
218 |
|
219 |
|
vouchervision/LM2_logger.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import logging, os, psutil, torch, platform, cpuinfo, yaml #py-cpuinfo
|
|
|
2 |
from vouchervision.general_utils import get_datetime, print_main_warn, print_main_info
|
3 |
|
4 |
class SanitizingFileHandler(logging.FileHandler):
|
@@ -17,7 +18,7 @@ def start_logging(Dirs, cfg):
|
|
17 |
path_log = os.path.join(Dirs.path_log, '__'.join(['LM2-log', str(get_datetime()), run_name]) + '.log')
|
18 |
|
19 |
# Disable default StreamHandler
|
20 |
-
logging.getLogger().handlers = []
|
21 |
|
22 |
# create logger
|
23 |
logger = logging.getLogger('Hardware Components')
|
@@ -27,20 +28,25 @@ def start_logging(Dirs, cfg):
|
|
27 |
sanitizing_fh = SanitizingFileHandler(path_log, encoding='utf-8')
|
28 |
sanitizing_fh.setLevel(logging.DEBUG)
|
29 |
|
|
|
|
|
|
|
30 |
# create console handler and set level to debug
|
31 |
-
ch = logging.StreamHandler()
|
32 |
-
ch.setLevel(logging.DEBUG)
|
33 |
|
34 |
# create formatter
|
35 |
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
|
36 |
|
37 |
# add formatter to handlers
|
38 |
sanitizing_fh.setFormatter(formatter)
|
39 |
-
|
|
|
40 |
|
41 |
# add handlers to logger
|
42 |
logger.addHandler(sanitizing_fh)
|
43 |
-
logger.addHandler(
|
|
|
44 |
|
45 |
# Create a logger for the file handler
|
46 |
file_logger = logging.getLogger('file_logger')
|
@@ -110,6 +116,17 @@ def find_cpu_info():
|
|
110 |
except:
|
111 |
return "CPU: UNKNOWN"
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
def LM2_banner():
|
115 |
logo = """
|
|
|
1 |
import logging, os, psutil, torch, platform, cpuinfo, yaml #py-cpuinfo
|
2 |
+
from tqdm import tqdm
|
3 |
from vouchervision.general_utils import get_datetime, print_main_warn, print_main_info
|
4 |
|
5 |
class SanitizingFileHandler(logging.FileHandler):
|
|
|
18 |
path_log = os.path.join(Dirs.path_log, '__'.join(['LM2-log', str(get_datetime()), run_name]) + '.log')
|
19 |
|
20 |
# Disable default StreamHandler
|
21 |
+
logging.getLogger().handlers = []
|
22 |
|
23 |
# create logger
|
24 |
logger = logging.getLogger('Hardware Components')
|
|
|
28 |
sanitizing_fh = SanitizingFileHandler(path_log, encoding='utf-8')
|
29 |
sanitizing_fh.setLevel(logging.DEBUG)
|
30 |
|
31 |
+
tqdm_handler = TqdmLoggingHandler()
|
32 |
+
tqdm_handler.setLevel(logging.DEBUG)
|
33 |
+
|
34 |
# create console handler and set level to debug
|
35 |
+
# ch = logging.StreamHandler()
|
36 |
+
# ch.setLevel(logging.DEBUG)
|
37 |
|
38 |
# create formatter
|
39 |
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
|
40 |
|
41 |
# add formatter to handlers
|
42 |
sanitizing_fh.setFormatter(formatter)
|
43 |
+
tqdm_handler.setFormatter(formatter)
|
44 |
+
# ch.setFormatter(formatter)
|
45 |
|
46 |
# add handlers to logger
|
47 |
logger.addHandler(sanitizing_fh)
|
48 |
+
logger.addHandler(tqdm_handler)
|
49 |
+
# logger.addHandler(ch)
|
50 |
|
51 |
# Create a logger for the file handler
|
52 |
file_logger = logging.getLogger('file_logger')
|
|
|
116 |
except:
|
117 |
return "CPU: UNKNOWN"
|
118 |
|
119 |
+
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints each formatted record through ``tqdm.write``.

    Construction is inherited unchanged from ``logging.Handler`` (an optional
    ``level`` argument defaulting to ``logging.NOTSET``).
    """

    def emit(self, record):
        """Format ``record`` and write it with ``tqdm.write``, then flush.

        Any ``Exception`` raised while formatting, writing or flushing is
        passed to ``handleError`` instead of propagating.
        """
        try:
            # Use tqdm's write function to ensure correct output
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)
|
130 |
|
131 |
def LM2_banner():
|
132 |
logo = """
|
vouchervision/OCR_google_cloud_vision.py
CHANGED
@@ -123,8 +123,9 @@ class OCREngine:
|
|
123 |
|
124 |
self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
|
125 |
self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
|
126 |
-
|
127 |
-
self.json_report
|
|
|
128 |
|
129 |
if self.model_quant == '4bit':
|
130 |
use_4bit = True
|
@@ -191,7 +192,8 @@ class OCREngine:
|
|
191 |
# Process each detected text region
|
192 |
for box in self.prediction_result["boxes"]:
|
193 |
i+=1
|
194 |
-
self.json_report
|
|
|
195 |
|
196 |
vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
|
197 |
|
@@ -283,7 +285,8 @@ class OCREngine:
|
|
283 |
i=0
|
284 |
for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
|
285 |
i+=1
|
286 |
-
self.json_report
|
|
|
287 |
|
288 |
vertices = bound["vertices"]
|
289 |
|
@@ -688,7 +691,8 @@ class OCREngine:
|
|
688 |
# logger.info(f"CRAFT trOCR:\n{self.OCR}")
|
689 |
|
690 |
if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
|
691 |
-
self.json_report
|
|
|
692 |
|
693 |
image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
|
694 |
self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
|
@@ -786,4 +790,20 @@ class OCREngine:
|
|
786 |
from craft_text_detector import empty_cuda_cache
|
787 |
empty_cuda_cache()
|
788 |
except:
|
789 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
self.model_path = "liuhaotian/" + self.cfg['leafmachine']['project']['OCR_option_llava']
|
125 |
self.model_quant = self.cfg['leafmachine']['project']['OCR_option_llava_bit']
|
126 |
+
|
127 |
+
if self.json_report:
|
128 |
+
self.json_report.set_text(text_main=f'Loading LLaVA model: {self.model_path} Quantization: {self.model_quant}')
|
129 |
|
130 |
if self.model_quant == '4bit':
|
131 |
use_4bit = True
|
|
|
192 |
# Process each detected text region
|
193 |
for box in self.prediction_result["boxes"]:
|
194 |
i+=1
|
195 |
+
if self.json_report:
|
196 |
+
self.json_report.set_text(text_main=f'Locating text using CRAFT --- {i}/{total_b}')
|
197 |
|
198 |
vertices = [{"x": int(vertex[0]), "y": int(vertex[1])} for vertex in box]
|
199 |
|
|
|
285 |
i=0
|
286 |
for bound in tqdm(available_bounds, desc="Processing words using Google Vision bboxes"):
|
287 |
i+=1
|
288 |
+
if self.json_report:
|
289 |
+
self.json_report.set_text(text_main=f'Working on trOCR :construction: {i}/{total_b}')
|
290 |
|
291 |
vertices = bound["vertices"]
|
292 |
|
|
|
691 |
# logger.info(f"CRAFT trOCR:\n{self.OCR}")
|
692 |
|
693 |
if 'LLaVA' in self.OCR_option: # This option does not produce an OCR helper image
|
694 |
+
if self.json_report:
|
695 |
+
self.json_report.set_text(text_main=f'Working on LLaVA {self.Llava.model_path} transcription :construction:')
|
696 |
|
697 |
image, json_output, direct_output, str_output, usage_report = self.Llava.transcribe_image(self.path, self.multimodal_prompt)
|
698 |
self.logger.info(f"LLaVA Usage Report for Model {self.Llava.model_path}:\n{usage_report}")
|
|
|
790 |
from craft_text_detector import empty_cuda_cache
|
791 |
empty_cuda_cache()
|
792 |
except:
|
793 |
+
pass
|
794 |
+
|
795 |
+
def check_for_inappropriate_content(file_stream):
    """Run Google Cloud Vision SafeSearch on an image stream.

    Args:
        file_stream: Object whose ``read()`` result is passed as the image
            content to ``vision.Image``.

    Returns:
        bool: ``True`` when the adult, violence or racy likelihood is above
        ``vision.Likelihood.POSSIBLE`` (the image violates safe search
        guidelines), otherwise ``False`` (the image is considered safe).
    """
    annotator = vision.ImageAnnotatorClient()
    payload = vision.Image(content=file_stream.read())
    # NOTE(review): the response is not inspected for an API error before the
    # annotation is read -- confirm how a failed request should be treated.
    annotation = annotator.safe_search_detection(image=payload).safe_search_annotation

    # Check the levels of adult, violence, racy, etc. content.
    threshold = vision.Likelihood.POSSIBLE
    return any(
        getattr(annotation, category) > threshold
        for category in ('adult', 'violence', 'racy')
    )
|
vouchervision/VoucherVision_Config_Builder.py
CHANGED
@@ -49,7 +49,7 @@ def build_VV_config(loaded_cfg=None):
|
|
49 |
|
50 |
check_for_illegal_filenames = False
|
51 |
|
52 |
-
LLM_version_user = 'Azure GPT 3.5
|
53 |
prompt_version = 'SLTPvA_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
|
54 |
use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
|
55 |
do_create_OCR_helper_image = True
|
|
|
49 |
|
50 |
check_for_illegal_filenames = False
|
51 |
|
52 |
+
LLM_version_user = 'Azure GPT 3.5 Turbo' #'Azure GPT 4 Turbo 1106-preview'
|
53 |
prompt_version = 'SLTPvA_long.yaml' # from ["Version 1", "Version 1 No Domain Knowledge", "Version 2"]
|
54 |
use_LeafMachine2_collage_images = True # Use LeafMachine2 collage images
|
55 |
do_create_OCR_helper_image = True
|
vouchervision/model_maps.py
CHANGED
@@ -20,9 +20,11 @@ class ModelMaps:
|
|
20 |
'AZURE_GPT_3_5_INSTRUCT': '#9400D3', # Dark Violet
|
21 |
'AZURE_GPT_3_5': '#9932CC', # Dark Orchid
|
22 |
|
23 |
-
'
|
24 |
-
'
|
|
|
25 |
'MISTRAL_MEDIUM': '#FF4500', # Orange Red
|
|
|
26 |
|
27 |
'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01': '#000000', # Black
|
28 |
'LOCAL_MISTRAL_7B_INSTRUCT_V02': '#4a4a4a', # Gray
|
@@ -34,14 +36,14 @@ class ModelMaps:
|
|
34 |
"GPT 4 32k",
|
35 |
"GPT 4 Turbo 0125-preview",
|
36 |
"GPT 4 Turbo 1106-preview",
|
37 |
-
"GPT 3.5",
|
38 |
"GPT 3.5 Instruct",
|
39 |
|
40 |
"Azure GPT 4",
|
41 |
"Azure GPT 4 32k",
|
42 |
"Azure GPT 4 Turbo 0125-preview",
|
43 |
"Azure GPT 4 Turbo 1106-preview",
|
44 |
-
"Azure GPT 3.5",
|
45 |
"Azure GPT 3.5 Instruct",]
|
46 |
|
47 |
MODELS_GOOGLE = ["PaLM 2 text-bison@001",
|
@@ -49,15 +51,18 @@ class ModelMaps:
|
|
49 |
"PaLM 2 text-unicorn@001",
|
50 |
"Gemini Pro"]
|
51 |
|
52 |
-
MODELS_MISTRAL = ["Mistral
|
53 |
-
"Mistral
|
54 |
-
"Mistral
|
|
|
|
|
|
|
55 |
|
56 |
MODELS_LOCAL = ["LOCAL Mixtral 8x7B Instruct v0.1",
|
57 |
"LOCAL Mistral 7B Instruct v0.2",
|
58 |
"LOCAL CPU Mistral 7B Instruct v0.2 GGUF",]
|
59 |
|
60 |
-
MODELS_GUI_DEFAULT = "Azure GPT 3.5
|
61 |
|
62 |
version_mapping_cost = {
|
63 |
'GPT 4 32k': 'GPT_4_32K',
|
@@ -65,23 +70,25 @@ class ModelMaps:
|
|
65 |
'GPT 4 Turbo 0125-preview': 'GPT_4_TURBO_0125',
|
66 |
'GPT 4 Turbo 1106-preview': 'GPT_4_TURBO_1106',
|
67 |
'GPT 3.5 Instruct': 'GPT_3_5_INSTRUCT',
|
68 |
-
'GPT 3.5': 'GPT_3_5',
|
69 |
|
70 |
'Azure GPT 4 32k': 'AZURE_GPT_4_32K',
|
71 |
'Azure GPT 4': 'AZURE_GPT_4',
|
72 |
'Azure GPT 4 Turbo 0125-preview': 'AZURE_GPT_4_TURBO_0125',
|
73 |
'Azure GPT 4 Turbo 1106-preview': 'AZURE_GPT_4_TURBO_1106',
|
74 |
'Azure GPT 3.5 Instruct': 'AZURE_GPT_3_5_INSTRUCT',
|
75 |
-
'Azure GPT 3.5': 'AZURE_GPT_3_5',
|
76 |
|
77 |
'Gemini Pro': 'GEMINI_PRO',
|
78 |
'PaLM 2 text-unicorn@001': 'PALM2_TU_1',
|
79 |
'PaLM 2 text-bison@001': 'PALM2_TB_1',
|
80 |
'PaLM 2 text-bison@002': 'PALM2_TB_2',
|
81 |
|
|
|
82 |
'Mistral Medium': 'MISTRAL_MEDIUM',
|
83 |
'Mistral Small': 'MISTRAL_SMALL',
|
84 |
-
'
|
|
|
85 |
|
86 |
'LOCAL Mixtral 8x7B Instruct v0.1': 'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01',
|
87 |
'LOCAL Mistral 7B Instruct v0.2': 'LOCAL_MISTRAL_7B_INSTRUCT_V02',
|
@@ -97,10 +104,10 @@ class ModelMaps:
|
|
97 |
'GPT 4 Turbo 0125-preview': has_key_openai,
|
98 |
'GPT 4': has_key_openai,
|
99 |
'GPT 4 32k': has_key_openai,
|
100 |
-
'GPT 3.5': has_key_openai,
|
101 |
'GPT 3.5 Instruct': has_key_openai,
|
102 |
|
103 |
-
'Azure GPT 3.5': has_key_azure_openai,
|
104 |
'Azure GPT 3.5 Instruct': has_key_azure_openai,
|
105 |
'Azure GPT 4': has_key_azure_openai,
|
106 |
'Azure GPT 4 Turbo 1106-preview': has_key_azure_openai,
|
@@ -112,9 +119,11 @@ class ModelMaps:
|
|
112 |
'PaLM 2 text-unicorn@001': has_key_google_application_credentials,
|
113 |
'Gemini Pro': has_key_google_application_credentials,
|
114 |
|
115 |
-
'Mistral Tiny': has_key_mistral,
|
116 |
'Mistral Small': has_key_mistral,
|
117 |
'Mistral Medium': has_key_mistral,
|
|
|
|
|
|
|
118 |
|
119 |
'LOCAL Mixtral 8x7B Instruct v0.1': True,
|
120 |
'LOCAL Mistral 7B Instruct v0.2': True,
|
@@ -127,15 +136,17 @@ class ModelMaps:
|
|
127 |
def get_version_mapping_is_azure(cls, key):
|
128 |
version_mapping_is_azure = {
|
129 |
"GPT 4 Turbo 1106-preview": False,
|
|
|
130 |
'GPT 4': False,
|
131 |
'GPT 4 32k': False,
|
132 |
-
'GPT 3.5': False,
|
133 |
'GPT 3.5 Instruct': False,
|
134 |
|
135 |
-
'Azure GPT 3.5': True,
|
136 |
'Azure GPT 3.5 Instruct': True,
|
137 |
'Azure GPT 4': True,
|
138 |
'Azure GPT 4 Turbo 1106-preview': True,
|
|
|
139 |
'Azure GPT 4 32k': True,
|
140 |
|
141 |
'PaLM 2 text-bison@001': False,
|
@@ -143,9 +154,11 @@ class ModelMaps:
|
|
143 |
'PaLM 2 text-unicorn@001': False,
|
144 |
'Gemini Pro': False,
|
145 |
|
146 |
-
'Mistral Tiny': False,
|
147 |
'Mistral Small': False,
|
148 |
'Mistral Medium': False,
|
|
|
|
|
|
|
149 |
|
150 |
'LOCAL Mixtral 8x7B Instruct v0.1': False,
|
151 |
'LOCAL Mistral 7B Instruct v0.2': False,
|
@@ -159,7 +172,7 @@ class ModelMaps:
|
|
159 |
|
160 |
### OpenAI
|
161 |
if key == 'GPT_3_5':
|
162 |
-
return 'gpt-3.5-turbo-1106'
|
163 |
|
164 |
elif key == 'GPT_3_5_INSTRUCT':
|
165 |
return 'gpt-3.5-turbo-instruct'
|
@@ -178,7 +191,7 @@ class ModelMaps:
|
|
178 |
|
179 |
### Azure
|
180 |
elif key == 'AZURE_GPT_3_5':
|
181 |
-
return 'gpt-35-turbo-
|
182 |
|
183 |
elif key == 'AZURE_GPT_3_5_INSTRUCT':
|
184 |
return 'gpt-35-turbo-instruct'
|
@@ -209,14 +222,20 @@ class ModelMaps:
|
|
209 |
return "gemini-1.0-pro"
|
210 |
|
211 |
### Mistral
|
212 |
-
elif key == '
|
213 |
-
return "mistral-
|
|
|
|
|
|
|
214 |
|
215 |
elif key == 'MISTRAL_SMALL':
|
216 |
-
return 'mistral-small'
|
217 |
|
218 |
elif key == 'MISTRAL_MEDIUM':
|
219 |
-
return 'mistral-medium'
|
|
|
|
|
|
|
220 |
|
221 |
|
222 |
### Mistral LOCAL
|
|
|
20 |
'AZURE_GPT_3_5_INSTRUCT': '#9400D3', # Dark Violet
|
21 |
'AZURE_GPT_3_5': '#9932CC', # Dark Orchid
|
22 |
|
23 |
+
'OPEN_MISTRAL_7B': '#FFA07A', # Light Salmon
|
24 |
+
'OPEN_MIXTRAL_8X7B': '#FF8C00', # Dark Orange
|
25 |
+
'MISTRAL_SMALL': '#FF6347', # Tomato
|
26 |
'MISTRAL_MEDIUM': '#FF4500', # Orange Red
|
27 |
+
'MISTRAL_LARGE': '#800000', # Maroon
|
28 |
|
29 |
'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01': '#000000', # Black
|
30 |
'LOCAL_MISTRAL_7B_INSTRUCT_V02': '#4a4a4a', # Gray
|
|
|
36 |
"GPT 4 32k",
|
37 |
"GPT 4 Turbo 0125-preview",
|
38 |
"GPT 4 Turbo 1106-preview",
|
39 |
+
"GPT 3.5 Turbo",
|
40 |
"GPT 3.5 Instruct",
|
41 |
|
42 |
"Azure GPT 4",
|
43 |
"Azure GPT 4 32k",
|
44 |
"Azure GPT 4 Turbo 0125-preview",
|
45 |
"Azure GPT 4 Turbo 1106-preview",
|
46 |
+
"Azure GPT 3.5 Turbo",
|
47 |
"Azure GPT 3.5 Instruct",]
|
48 |
|
49 |
MODELS_GOOGLE = ["PaLM 2 text-bison@001",
|
|
|
51 |
"PaLM 2 text-unicorn@001",
|
52 |
"Gemini Pro"]
|
53 |
|
54 |
+
MODELS_MISTRAL = ["Mistral Small",
|
55 |
+
"Mistral Medium",
|
56 |
+
"Mistral Large",
|
57 |
+
"Open Mixtral 8x7B",
|
58 |
+
"Open Mistral 7B",
|
59 |
+
]
|
60 |
|
61 |
MODELS_LOCAL = ["LOCAL Mixtral 8x7B Instruct v0.1",
|
62 |
"LOCAL Mistral 7B Instruct v0.2",
|
63 |
"LOCAL CPU Mistral 7B Instruct v0.2 GGUF",]
|
64 |
|
65 |
+
MODELS_GUI_DEFAULT = "Azure GPT 3.5 Turbo" # "GPT 4 Turbo 1106-preview"
|
66 |
|
67 |
version_mapping_cost = {
|
68 |
'GPT 4 32k': 'GPT_4_32K',
|
|
|
70 |
'GPT 4 Turbo 0125-preview': 'GPT_4_TURBO_0125',
|
71 |
'GPT 4 Turbo 1106-preview': 'GPT_4_TURBO_1106',
|
72 |
'GPT 3.5 Instruct': 'GPT_3_5_INSTRUCT',
|
73 |
+
'GPT 3.5 Turbo': 'GPT_3_5',
|
74 |
|
75 |
'Azure GPT 4 32k': 'AZURE_GPT_4_32K',
|
76 |
'Azure GPT 4': 'AZURE_GPT_4',
|
77 |
'Azure GPT 4 Turbo 0125-preview': 'AZURE_GPT_4_TURBO_0125',
|
78 |
'Azure GPT 4 Turbo 1106-preview': 'AZURE_GPT_4_TURBO_1106',
|
79 |
'Azure GPT 3.5 Instruct': 'AZURE_GPT_3_5_INSTRUCT',
|
80 |
+
'Azure GPT 3.5 Turbo': 'AZURE_GPT_3_5',
|
81 |
|
82 |
'Gemini Pro': 'GEMINI_PRO',
|
83 |
'PaLM 2 text-unicorn@001': 'PALM2_TU_1',
|
84 |
'PaLM 2 text-bison@001': 'PALM2_TB_1',
|
85 |
'PaLM 2 text-bison@002': 'PALM2_TB_2',
|
86 |
|
87 |
+
'Mistral Large': 'MISTRAL_LARGE',
|
88 |
'Mistral Medium': 'MISTRAL_MEDIUM',
|
89 |
'Mistral Small': 'MISTRAL_SMALL',
|
90 |
+
'Open Mixtral 8x7B': 'OPEN_MIXTRAL_8X7B',
|
91 |
+
'Open Mistral 7B': 'OPEN_MISTRAL_7B',
|
92 |
|
93 |
'LOCAL Mixtral 8x7B Instruct v0.1': 'LOCAL_MIXTRAL_8X7B_INSTRUCT_V01',
|
94 |
'LOCAL Mistral 7B Instruct v0.2': 'LOCAL_MISTRAL_7B_INSTRUCT_V02',
|
|
|
104 |
'GPT 4 Turbo 0125-preview': has_key_openai,
|
105 |
'GPT 4': has_key_openai,
|
106 |
'GPT 4 32k': has_key_openai,
|
107 |
+
'GPT 3.5 Turbo': has_key_openai,
|
108 |
'GPT 3.5 Instruct': has_key_openai,
|
109 |
|
110 |
+
'Azure GPT 3.5 Turbo': has_key_azure_openai,
|
111 |
'Azure GPT 3.5 Instruct': has_key_azure_openai,
|
112 |
'Azure GPT 4': has_key_azure_openai,
|
113 |
'Azure GPT 4 Turbo 1106-preview': has_key_azure_openai,
|
|
|
119 |
'PaLM 2 text-unicorn@001': has_key_google_application_credentials,
|
120 |
'Gemini Pro': has_key_google_application_credentials,
|
121 |
|
|
|
122 |
'Mistral Small': has_key_mistral,
|
123 |
'Mistral Medium': has_key_mistral,
|
124 |
+
'Mistral Large': has_key_mistral,
|
125 |
+
'Open Mixtral 8x7B': has_key_mistral,
|
126 |
+
'Open Mistral 7B': has_key_mistral,
|
127 |
|
128 |
'LOCAL Mixtral 8x7B Instruct v0.1': True,
|
129 |
'LOCAL Mistral 7B Instruct v0.2': True,
|
|
|
136 |
def get_version_mapping_is_azure(cls, key):
|
137 |
version_mapping_is_azure = {
|
138 |
"GPT 4 Turbo 1106-preview": False,
|
139 |
+
"GPT 4 Turbo 0125-preview": False,
|
140 |
'GPT 4': False,
|
141 |
'GPT 4 32k': False,
|
142 |
+
'GPT 3.5 Turbo': False,
|
143 |
'GPT 3.5 Instruct': False,
|
144 |
|
145 |
+
'Azure GPT 3.5 Turbo': True,
|
146 |
'Azure GPT 3.5 Instruct': True,
|
147 |
'Azure GPT 4': True,
|
148 |
'Azure GPT 4 Turbo 1106-preview': True,
|
149 |
+
'Azure GPT 4 Turbo 0125-preview': True,
|
150 |
'Azure GPT 4 32k': True,
|
151 |
|
152 |
'PaLM 2 text-bison@001': False,
|
|
|
154 |
'PaLM 2 text-unicorn@001': False,
|
155 |
'Gemini Pro': False,
|
156 |
|
|
|
157 |
'Mistral Small': False,
|
158 |
'Mistral Medium': False,
|
159 |
+
'Mistral Large': False,
|
160 |
+
'Open Mixtral 8x7B': False,
|
161 |
+
'Open Mistral 7B': False,
|
162 |
|
163 |
'LOCAL Mixtral 8x7B Instruct v0.1': False,
|
164 |
'LOCAL Mistral 7B Instruct v0.2': False,
|
|
|
172 |
|
173 |
### OpenAI
|
174 |
if key == 'GPT_3_5':
|
175 |
+
return 'gpt-3.5-turbo-0125' #'gpt-3.5-turbo-1106'
|
176 |
|
177 |
elif key == 'GPT_3_5_INSTRUCT':
|
178 |
return 'gpt-3.5-turbo-instruct'
|
|
|
191 |
|
192 |
### Azure
|
193 |
elif key == 'AZURE_GPT_3_5':
|
194 |
+
return 'gpt-35-turbo-0125'
|
195 |
|
196 |
elif key == 'AZURE_GPT_3_5_INSTRUCT':
|
197 |
return 'gpt-35-turbo-instruct'
|
|
|
222 |
return "gemini-1.0-pro"
|
223 |
|
224 |
### Mistral
|
225 |
+
elif key == 'OPEN_MISTRAL_7B':
|
226 |
+
return "open-mistral-7b"
|
227 |
+
|
228 |
+
elif key == 'OPEN_MIXTRAL_8X7B':
|
229 |
+
return 'open-mixtral-8x7b'
|
230 |
|
231 |
elif key == 'MISTRAL_SMALL':
|
232 |
+
return 'mistral-small-latest'
|
233 |
|
234 |
elif key == 'MISTRAL_MEDIUM':
|
235 |
+
return 'mistral-medium-latest'
|
236 |
+
|
237 |
+
elif key == 'MISTRAL_LARGE':
|
238 |
+
return 'mistral-large-latest'
|
239 |
|
240 |
|
241 |
### Mistral LOCAL
|
vouchervision/prompt_catalog.py
CHANGED
@@ -18,7 +18,7 @@ class PromptCatalog:
|
|
18 |
|
19 |
|
20 |
def prompt_SLTP(self, rules_config_path, OCR=None, is_palm=False):
|
21 |
-
self.OCR = OCR
|
22 |
|
23 |
self.rules_config_path = rules_config_path
|
24 |
self.rules_config = self.load_rules_config()
|
@@ -48,9 +48,9 @@ class PromptCatalog:
|
|
48 |
The unstructured OCR text is:
|
49 |
{self.OCR}
|
50 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
51 |
-
{self.
|
52 |
-
{self.
|
53 |
-
{self.
|
54 |
"""
|
55 |
else:
|
56 |
prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
|
@@ -62,13 +62,16 @@ class PromptCatalog:
|
|
62 |
The unstructured OCR text is:
|
63 |
{self.OCR}
|
64 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
65 |
-
{self.
|
66 |
"""
|
67 |
# xlsx_headers = self.generate_xlsx_headers(is_palm)
|
68 |
|
69 |
# return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
|
|
|
70 |
return prompt, self.dictionary_structure
|
71 |
|
|
|
|
|
72 |
|
73 |
def copy_prompt_template_to_new_dir(self, new_directory_path, rules_config_path):
|
74 |
# Ensure the target directory exists, create it if it doesn't
|
@@ -102,22 +105,31 @@ class PromptCatalog:
|
|
102 |
return structure_json_str
|
103 |
|
104 |
def create_structure(self, is_palm=False):
|
105 |
-
# Create fields for the Pydantic model dynamically
|
106 |
-
fields = {key: (str, Field(default=value, description=value)) for key, value in self.rules_list.items()}
|
107 |
|
108 |
-
# Dynamically create the Pydantic model
|
109 |
-
DynamicJSONParsingModel = create_model('SLTPvA', **fields)
|
110 |
-
DynamicJSONParsingModel_use = DynamicJSONParsingModel()
|
111 |
|
112 |
-
# Define the structure for the "Dictionary" section
|
113 |
-
dictionary_fields = {key: (str, Field(default='', description="")) for key in self.rules_list.keys()}
|
114 |
|
115 |
-
# Dynamically create the "Dictionary" Pydantic model
|
116 |
-
PromptJSONModel = create_model('PromptJSONModel', **dictionary_fields)
|
117 |
|
118 |
-
# Convert the model to JSON string (for demonstration)
|
119 |
-
dictionary_structure = PromptJSONModel().dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
|
|
|
|
|
|
|
121 |
return structure_json_str, dictionary_structure
|
122 |
|
123 |
|
|
|
18 |
|
19 |
|
20 |
def prompt_SLTP(self, rules_config_path, OCR=None, is_palm=False):
|
21 |
+
self.OCR = self.remove_colons_and_double_apostrophes(OCR)
|
22 |
|
23 |
self.rules_config_path = rules_config_path
|
24 |
self.rules_config = self.load_rules_config()
|
|
|
48 |
The unstructured OCR text is:
|
49 |
{self.OCR}
|
50 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
51 |
+
{self.dictionary_structure}
|
52 |
+
{self.dictionary_structure}
|
53 |
+
{self.dictionary_structure}
|
54 |
"""
|
55 |
else:
|
56 |
prompt = f"""Please help me complete this text parsing task given the following rules and unstructured OCR text. Your task is to refactor the OCR text into a structured JSON dictionary that matches the structure specified in the following rules. Please follow the rules strictly.
|
|
|
62 |
The unstructured OCR text is:
|
63 |
{self.OCR}
|
64 |
Please populate the following JSON dictionary based on the rules and the unformatted OCR text:
|
65 |
+
{self.dictionary_structure}
|
66 |
"""
|
67 |
# xlsx_headers = self.generate_xlsx_headers(is_palm)
|
68 |
|
69 |
# return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
|
70 |
+
# print(prompt)
|
71 |
return prompt, self.dictionary_structure
|
72 |
|
73 |
+
def remove_colons_and_double_apostrophes(self, text):
    """Strip every colon and double-quote character from ``text``.

    Args:
        text: The string to clean. ``None`` is accepted because
            ``prompt_SLTP`` declares ``OCR=None`` as its default and passes
            that value straight to this method.

    Returns:
        ``text`` with all ``:`` and ``"`` characters removed, or ``None``
        when ``text`` is ``None``.
    """
    if text is None:
        # Calling .replace() on None would raise AttributeError; hand the
        # value back untouched so the caller sees the same "no OCR" state.
        return None
    return text.replace(":", "").replace("\"", "")
|
75 |
|
76 |
def copy_prompt_template_to_new_dir(self, new_directory_path, rules_config_path):
|
77 |
# Ensure the target directory exists, create it if it doesn't
|
|
|
105 |
return structure_json_str
|
106 |
|
107 |
def create_structure(self, is_palm=False):
    """Build an empty-valued template dictionary from ``self.rules_list``.

    Every key of ``self.rules_list`` becomes a key of the template, in the
    same order, mapped to an empty string.

    Args:
        is_palm: Not read by this method; kept so existing callers that pass
            it keep working.

    Returns:
        A tuple ``(structure_json_str, dictionary_structure)`` where
        ``dictionary_structure`` is the template ``dict`` and
        ``structure_json_str`` is that dict serialized with ``json.dumps``
        (4-space indent, insertion order preserved).
    """
    # The values are immutable empty strings, so sharing one default across
    # all keys via dict.fromkeys is safe.
    dictionary_structure = dict.fromkeys(self.rules_list, '')
    structure_json_str = json.dumps(dictionary_structure, sort_keys=False, indent=4)
    return structure_json_str, dictionary_structure
|
134 |
|
135 |
|
vouchervision/tool_taxonomy_WFO.py
CHANGED
@@ -19,12 +19,19 @@ class WFONameMatcher:
|
|
19 |
self.is_enabled = tool_WFO
|
20 |
|
21 |
def extract_input_string(self, record):
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
return primary_input, secondary_input
|
29 |
|
30 |
def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
|
@@ -46,6 +53,8 @@ class WFONameMatcher:
|
|
46 |
|
47 |
def query_and_process(self, record):
|
48 |
primary_input, secondary_input = self.extract_input_string(record)
|
|
|
|
|
49 |
|
50 |
# Query with primary input
|
51 |
primary_result = self.query_wfo_name_matching(primary_input)
|
|
|
19 |
self.is_enabled = tool_WFO
|
20 |
|
21 |
def extract_input_string(self, record):
    """Derive the primary and secondary name strings from one record.

    The primary string is "<name> <authorship>", taken from the
    ``scientificName``/``scientificNameAuthorship`` pair when both keys are
    present, otherwise from the
    ``speciesBinomialName``/``speciesBinomialNameAuthorship`` pair. The
    secondary string is "<genus> <specificEpithet>" with empty parts
    dropped.

    Args:
        record: Mapping of field names to values. Values that are ``None``
            are treated as empty strings; other non-string values are
            converted with ``str()`` before stripping.

    Returns:
        ``(primary_input, secondary_input)`` as stripped strings, or
        ``(None, None)`` when neither name/authorship key pair is present or
        when ``genus``/``specificEpithet`` are not both present.
    """
    def _text(key):
        # A key can be present with a null value; .strip() on None would
        # raise AttributeError, so normalise before stripping.
        value = record.get(key)
        return '' if value is None else str(value).strip()

    if 'scientificName' in record and 'scientificNameAuthorship' in record:
        name_key, author_key = 'scientificName', 'scientificNameAuthorship'
    elif 'speciesBinomialName' in record and 'speciesBinomialNameAuthorship' in record:
        name_key, author_key = 'speciesBinomialName', 'speciesBinomialNameAuthorship'
    else:
        return None, None

    # Both halves are required: without genus + specificEpithet the record
    # is reported as unusable, exactly as before.
    if 'genus' not in record or 'specificEpithet' not in record:
        return None, None

    primary_input = f"{_text(name_key)} {_text(author_key)}".strip()
    secondary_input = ' '.join(filter(None, [_text('genus'), _text('specificEpithet')])).strip()
    return primary_input, secondary_input
|
36 |
|
37 |
def query_wfo_name_matching(self, input_string, check_homonyms=True, check_rank=True, accept_single_candidate=True):
|
|
|
53 |
|
54 |
def query_and_process(self, record):
|
55 |
primary_input, secondary_input = self.extract_input_string(record)
|
56 |
+
if primary_input is None and secondary_input is None:
|
57 |
+
return self.NULL_DICT
|
58 |
|
59 |
# Query with primary input
|
60 |
primary_result = self.query_wfo_name_matching(primary_input)
|
vouchervision/utils_LLM.py
CHANGED
@@ -63,16 +63,13 @@ def run_tools(output, tool_WFO, tool_GEO, tool_wikipedia, json_file_path_wiki):
|
|
63 |
return output_WFO, WFO_record, output_GEO, GEO_record
|
64 |
|
65 |
|
|
|
66 |
def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
|
67 |
with open(txt_file_path_ind_prompt, 'w',encoding='utf-8') as file:
|
68 |
file.write(prompt_template)
|
69 |
|
70 |
|
71 |
|
72 |
-
def remove_colons_and_double_apostrophes(text):
|
73 |
-
return text.replace(":", "").replace("\"", "")
|
74 |
-
|
75 |
-
|
76 |
def sanitize_prompt(data):
|
77 |
if isinstance(data, dict):
|
78 |
return {sanitize_prompt(key): sanitize_prompt(value) for key, value in data.items()}
|
|
|
63 |
return output_WFO, WFO_record, output_GEO, GEO_record
|
64 |
|
65 |
|
66 |
+
|
67 |
def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
    """Write ``prompt_template`` to ``txt_file_path_ind_prompt`` as UTF-8 text.

    The file is opened in ``'w'`` mode, so an existing file at that path is
    overwritten.
    """
    with open(txt_file_path_ind_prompt, mode='w', encoding='utf-8') as prompt_file:
        prompt_file.write(prompt_template)
|
70 |
|
71 |
|
72 |
|
|
|
|
|
|
|
|
|
73 |
def sanitize_prompt(data):
|
74 |
if isinstance(data, dict):
|
75 |
return {sanitize_prompt(key): sanitize_prompt(value) for key, value in data.items()}
|
vouchervision/utils_LLM_JSON_validation.py
CHANGED
@@ -11,7 +11,8 @@ def validate_and_align_JSON_keys_with_template(data, JSON_dict_structure):
|
|
11 |
if value is None:
|
12 |
data[key] = ''
|
13 |
elif isinstance(value, str):
|
14 |
-
if value.lower() in ['unknown',
|
|
|
15 |
'not provided in the text', 'not found in the text',
|
16 |
'not in the text', 'not provided', 'not found',
|
17 |
'not provided in the ocr', 'not found in the ocr',
|
|
|
11 |
if value is None:
|
12 |
data[key] = ''
|
13 |
elif isinstance(value, str):
|
14 |
+
if value.lower() in ['unknown','not provided', 'missing', 'na', 'none', 'n/a', 'null', 'unspecified',
|
15 |
+
'TBD',
|
16 |
'not provided in the text', 'not found in the text',
|
17 |
'not in the text', 'not provided', 'not found',
|
18 |
'not provided in the ocr', 'not found in the ocr',
|
vouchervision/utils_VoucherVision.py
CHANGED
@@ -14,7 +14,6 @@ from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
|
|
14 |
from vouchervision.LLM_MistralAI import MistralHandler
|
15 |
from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
|
16 |
from vouchervision.LLM_local_MistralAI import LocalMistralHandler
|
17 |
-
from vouchervision.utils_LLM import remove_colons_and_double_apostrophes
|
18 |
from vouchervision.prompt_catalog import PromptCatalog
|
19 |
from vouchervision.model_maps import ModelMaps
|
20 |
from vouchervision.general_utils import get_cfg_from_full_path
|
@@ -32,7 +31,7 @@ from vouchervision.OCR_google_cloud_vision import OCREngine
|
|
32 |
|
33 |
class VoucherVision():
|
34 |
|
35 |
-
def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf):
|
36 |
self.cfg = cfg
|
37 |
self.logger = logger
|
38 |
self.dir_home = dir_home
|
@@ -43,6 +42,9 @@ class VoucherVision():
|
|
43 |
self.prompt_version = None
|
44 |
self.is_hf = is_hf
|
45 |
|
|
|
|
|
|
|
46 |
# self.trOCR_model_version = "microsoft/trocr-large-handwritten"
|
47 |
# self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
48 |
# self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
|
@@ -686,9 +688,10 @@ class VoucherVision():
|
|
686 |
Copy_Prompt = PromptCatalog()
|
687 |
Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)
|
688 |
|
689 |
-
json_report
|
690 |
-
|
691 |
-
|
|
|
692 |
|
693 |
for i, path_to_crop in enumerate(self.img_paths):
|
694 |
self.update_progress_report_batch(progress_report, i)
|
@@ -701,9 +704,11 @@ class VoucherVision():
|
|
701 |
self.path_to_crop = path_to_crop
|
702 |
|
703 |
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
704 |
-
json_report
|
|
|
705 |
self.perform_OCR_and_save_results(i, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
|
706 |
-
json_report
|
|
|
707 |
|
708 |
if not self.OCR:
|
709 |
self.n_failed_OCR += 1
|
@@ -713,7 +718,7 @@ class VoucherVision():
|
|
713 |
else:
|
714 |
### Format prompt
|
715 |
prompt = self.setup_prompt()
|
716 |
-
prompt = remove_colons_and_double_apostrophes(prompt)
|
717 |
|
718 |
### Send prompt to chosen LLM
|
719 |
self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
|
@@ -747,8 +752,9 @@ class VoucherVision():
|
|
747 |
final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
|
748 |
|
749 |
self.logger.info(f'Finished LLM call')
|
750 |
-
|
751 |
-
json_report
|
|
|
752 |
|
753 |
self.update_progress_report_final(progress_report)
|
754 |
final_JSON_response = self.parse_final_json_response(final_JSON_response)
|
@@ -758,22 +764,22 @@ class VoucherVision():
|
|
758 |
##################################################################################################################################
|
759 |
################################################## LLM Helper Funcs ##############################################################
|
760 |
##################################################################################################################################
|
761 |
-
def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None):
|
762 |
if 'LOCAL'in name_parts:
|
763 |
if ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts):
|
764 |
if 'CPU' in name_parts:
|
765 |
-
return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure)
|
766 |
else:
|
767 |
-
return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure)
|
768 |
else:
|
769 |
if 'PALM2' in name_parts:
|
770 |
-
return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure)
|
771 |
elif 'GEMINI' in name_parts:
|
772 |
-
return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure)
|
773 |
elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
|
774 |
-
return MistralHandler(cfg, logger, model_name, JSON_dict_structure)
|
775 |
else:
|
776 |
-
return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object)
|
777 |
|
778 |
def setup_prompt(self):
|
779 |
Catalog = PromptCatalog()
|
|
|
14 |
from vouchervision.LLM_MistralAI import MistralHandler
|
15 |
from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
|
16 |
from vouchervision.LLM_local_MistralAI import LocalMistralHandler
|
|
|
17 |
from vouchervision.prompt_catalog import PromptCatalog
|
18 |
from vouchervision.model_maps import ModelMaps
|
19 |
from vouchervision.general_utils import get_cfg_from_full_path
|
|
|
31 |
|
32 |
class VoucherVision():
|
33 |
|
34 |
+
def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf, config_vals_for_permutation=None):
|
35 |
self.cfg = cfg
|
36 |
self.logger = logger
|
37 |
self.dir_home = dir_home
|
|
|
42 |
self.prompt_version = None
|
43 |
self.is_hf = is_hf
|
44 |
|
45 |
+
### config_vals_for_permutation allows you to set the starting temp, top_k, top_p, seed....
|
46 |
+
self.config_vals_for_permutation = config_vals_for_permutation
|
47 |
+
|
48 |
# self.trOCR_model_version = "microsoft/trocr-large-handwritten"
|
49 |
# self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
50 |
# self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask" # NOPE
|
|
|
688 |
Copy_Prompt = PromptCatalog()
|
689 |
Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)
|
690 |
|
691 |
+
if json_report:
|
692 |
+
json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
|
693 |
+
json_report.set_JSON({}, {}, {})
|
694 |
+
llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm, self.config_vals_for_permutation)
|
695 |
|
696 |
for i, path_to_crop in enumerate(self.img_paths):
|
697 |
self.update_progress_report_batch(progress_report, i)
|
|
|
704 |
self.path_to_crop = path_to_crop
|
705 |
|
706 |
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
707 |
+
if json_report:
|
708 |
+
json_report.set_text(text_main='Starting OCR')
|
709 |
self.perform_OCR_and_save_results(i, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
|
710 |
+
if json_report:
|
711 |
+
json_report.set_text(text_main='Finished OCR')
|
712 |
|
713 |
if not self.OCR:
|
714 |
self.n_failed_OCR += 1
|
|
|
718 |
else:
|
719 |
### Format prompt
|
720 |
prompt = self.setup_prompt()
|
721 |
+
# prompt = remove_colons_and_double_apostrophes(prompt) # Moved into PromptCatalog.prompt_SLTP, where it is applied to the OCR text only; stripping the whole prompt broke the JSON structure.
|
722 |
|
723 |
### Send prompt to chosen LLM
|
724 |
self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
|
|
|
752 |
final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
|
753 |
|
754 |
self.logger.info(f'Finished LLM call')
|
755 |
+
|
756 |
+
if json_report:
|
757 |
+
json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)
|
758 |
|
759 |
self.update_progress_report_final(progress_report)
|
760 |
final_JSON_response = self.parse_final_json_response(final_JSON_response)
|
|
|
764 |
##################################################################################################################################
|
765 |
################################################## LLM Helper Funcs ##############################################################
|
766 |
##################################################################################################################################
|
767 |
+
def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None, config_vals_for_permutation=None):
    """Instantiate the LLM handler selected by the tokens in ``name_parts``.

    Routing, in order:
      * ``LOCAL`` + (``MIXTRAL`` or ``MISTRAL``): ``LocalCPUMistralHandler``
        when ``CPU`` is also present, otherwise ``LocalMistralHandler``.
      * ``LOCAL`` without a Mistral-family token: ``None`` (no handler).
      * ``PALM2``: ``GooglePalm2Handler``.
      * ``GEMINI``: ``GoogleGeminiHandler``.
      * ``MIXTRAL`` or ``MISTRAL``: ``MistralHandler``.
      * anything else: ``OpenAIHandler``.

    Args:
        cfg, logger, model_name, JSON_dict_structure: Forwarded unchanged to
            the chosen handler.
        name_parts: Container tested with ``in`` for the routing tokens.
            NOTE(review): presumably the formatted model name split on
            ``_`` (e.g. ``OPEN_MIXTRAL_8X7B``) - confirm against the caller.
        is_azure, llm_object: Forwarded only to ``OpenAIHandler``.
        config_vals_for_permutation: Forwarded to every handler.

    Returns:
        The constructed handler, or ``None`` for a ``LOCAL`` model that is
        not in the Mistral family.
    """
    # Hosted and local routing must agree on what counts as Mistral-family.
    # Previously only the LOCAL branch accepted 'MIXTRAL', so a hosted
    # Mixtral name fell through to OpenAIHandler.
    is_mistral_family = ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts)

    if 'LOCAL' in name_parts:
        if not is_mistral_family:
            # Unchanged behavior: no handler exists for other local models.
            return None
        if 'CPU' in name_parts:
            return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
        return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)

    if 'PALM2' in name_parts:
        return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
    if 'GEMINI' in name_parts:
        return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
    if is_mistral_family:
        return MistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
    return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object, config_vals_for_permutation)
|
783 |
|
784 |
def setup_prompt(self):
|
785 |
Catalog = PromptCatalog()
|
vouchervision/utils_VoucherVision_parallel.py
ADDED
@@ -0,0 +1,1022 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import os, json, glob, shutil, yaml, torch, logging
|
3 |
+
import openpyxl
|
4 |
+
from openpyxl import Workbook, load_workbook
|
5 |
+
from tqdm import tqdm
|
6 |
+
import vertexai
|
7 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
8 |
+
from langchain_openai import AzureChatOpenAI
|
9 |
+
from google.oauth2 import service_account
|
10 |
+
from transformers import AutoTokenizer, AutoModel
|
11 |
+
|
12 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
13 |
+
from queue import Queue
|
14 |
+
import threading
|
15 |
+
|
16 |
+
from vouchervision.LLM_OpenAI import OpenAIHandler
|
17 |
+
from vouchervision.LLM_GooglePalm2 import GooglePalm2Handler
|
18 |
+
from vouchervision.LLM_GoogleGemini import GoogleGeminiHandler
|
19 |
+
from vouchervision.LLM_MistralAI import MistralHandler
|
20 |
+
from vouchervision.LLM_local_cpu_MistralAI import LocalCPUMistralHandler
|
21 |
+
from vouchervision.LLM_local_MistralAI import LocalMistralHandler
|
22 |
+
from vouchervision.prompt_catalog import PromptCatalog
|
23 |
+
from vouchervision.model_maps import ModelMaps
|
24 |
+
from vouchervision.general_utils import get_cfg_from_full_path
|
25 |
+
from vouchervision.OCR_google_cloud_vision import OCREngine
|
26 |
+
|
27 |
+
'''
|
28 |
+
* For the prefix_removal, the image names have 'MICH-V-' prior to the barcode, so that is used for matching
|
29 |
+
but removed for output.
|
30 |
+
* There is also code active to replace the LLM-predicted "Catalog Number" with the correct number since it is known.
|
31 |
+
The LLMs do usually assign the barcode to the correct field, but it's not needed since it is already known.
|
32 |
+
- Look for ####################### Catalog Number pre-defined
|
33 |
+
'''
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
class VoucherVision():
|
38 |
+
|
39 |
+
def __init__(self, cfg, logger, dir_home, path_custom_prompts, Project, Dirs, is_hf, config_vals_for_permutation=None):
    """Store the constructor arguments, then run ``set_API_keys()`` and ``setup()``.

    ``headers`` and ``prompt_version`` start as ``None``, as do the trOCR
    processor and model attributes.
    """
    # Caller-supplied configuration and collaborators.
    self.cfg, self.logger = cfg, logger
    self.dir_home, self.path_custom_prompts = dir_home, path_custom_prompts
    self.Project, self.Dirs = Project, Dirs
    self.headers = self.prompt_version = None
    self.is_hf = is_hf

    # Optional starting values (temp, top_k, top_p, seed, ...) for the LLM handlers.
    self.config_vals_for_permutation = config_vals_for_permutation

    # trOCR checkpoints noted by the original author (the last three were marked "NOPE"):
    #   microsoft/trocr-large-handwritten
    #   microsoft/trocr-base-handwritten
    #   dh-unibe/trocr-medieval-escriptmask
    #   dh-unibe/trocr-kurrent
    #   DunnBC22/trocr-base-handwritten-OCR-handwriting_recognition_v2
    self.trOCR_processor = self.trOCR_model = None

    self.set_API_keys()
    self.setup()
|
63 |
+
|
64 |
+
|
65 |
+
def setup(self):
    '''Read the project settings from the config and prepare OCR, prompt and output state.'''
    self.logger.name = '[Transcription]'
    self.logger.info('Setting up OCR and LLM')

    project_cfg = self.cfg['leafmachine']['project']

    self.trOCR_model_version = project_cfg['trOCR_model_path']

    self.db_name = project_cfg['embeddings_database_name']
    self.path_domain_knowledge = project_cfg['path_to_domain_knowledge_xlsx']
    self.build_new_db = project_cfg['build_new_embeddings_database']

    self.continue_run_from_partial_xlsx = project_cfg['continue_run_from_partial_xlsx']

    self.prefix_removal = project_cfg['prefix_removal']
    self.suffix_removal = project_cfg['suffix_removal']
    self.catalog_numerical_only = project_cfg['catalog_numerical_only']

    self.prompt_version0 = project_cfg['prompt_version']
    self.use_domain_knowledge = project_cfg['use_domain_knowledge']

    # Column names that are filled from the file name rather than from the LLM.
    self.catalog_name_options = ["Catalog Number", "catalog_number", "catalogNumber"]

    self.geo_headers = [
        "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
        "GEO_decimal_long", "GEO_city", "GEO_county", "GEO_state",
        "GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",
    ]

    self.usage_headers = [
        "current_time", "inference_time_s", "tool_time_s", "max_cpu", "max_ram_gb",
        "n_gpus", "max_gpu_load", "max_gpu_vram_gb", "total_gpu_vram_gb", "capability_score",
    ]

    self.wfo_headers = [
        "WFO_override_OCR", "WFO_exact_match", "WFO_exact_match_name",
        "WFO_best_match", "WFO_candidate_names", "WFO_placement",
    ]
    # WFO_candidate_names is kept separate because it may be a list.
    self.wfo_headers_no_lists = [name for name in self.wfo_headers if name != "WFO_candidate_names"]

    run_headers = [
        "run_name", "prompt", "LLM", "tokens_in", "tokens_out", "LM2_collage",
        "OCR_method", "OCR_double", "OCR_trOCR",
        "path_to_crop", "path_to_original", "path_to_content", "path_to_helper",
    ]
    self.utility_headers = ["filename", *self.wfo_headers, *self.geo_headers, *self.usage_headers, *run_headers]

    self.do_create_OCR_helper_image = self.cfg['leafmachine']['do_create_OCR_helper_image']

    self.map_prompt_versions()
    self.map_dir_labels()
    self.map_API_options()
    self.init_transcription_xlsx()
    self.init_trOCR_model()

    # Summarise the run in the log.
    log = self.logger.info
    log(f'Transcribing dataset --- {self.dir_labels}')
    log(f'Saving transcription batch to --- {self.path_transcription}')
    log(f'Saving individual transcription files to --- {self.Dirs.transcription_ind}')
    log('Starting transcription...')
    log(f' LLM MODEL --> {self.version_name}')
    log(f' Using Azure API --> {self.is_azure}')
    log(f' Model name passed to API --> {self.model_name}')
    log(f' API access token is found in PRIVATE_DATA.yaml --> {self.has_key}')
|
124 |
+
|
125 |
+
|
126 |
+
def init_trOCR_model(self):
    '''Load the trOCR processor and model and place the model on GPU when one is available.'''
    # Keep the transformers library quiet apart from errors.
    logging.getLogger('transformers').setLevel(logging.ERROR)

    # The processor is always the base handwritten one; only the model weights follow the config.
    self.trOCR_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    self.trOCR_model = VisionEncoderDecoderModel.from_pretrained(self.trOCR_model_version)

    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device_name)
    self.trOCR_model.to(self.device)
|
136 |
+
|
137 |
+
|
138 |
+
def map_API_options(self):
    '''Resolve the configured LLM version into model name, Azure flag and key availability.

    Raises Exception when ModelMaps has no entry for the configured version.
    '''
    llm_choice = self.cfg['leafmachine']['LLM_version']
    self.chat_version = llm_choice

    # Look everything up through ModelMaps.
    self.model_name = ModelMaps.get_version_mapping_cost(llm_choice)
    self.is_azure = ModelMaps.get_version_mapping_is_azure(llm_choice)
    self.has_key = ModelMaps.get_version_has_key(
        llm_choice,
        self.has_key_openai,
        self.has_key_azure_openai,
        self.has_key_google_application_credentials,
        self.has_key_mistral,
    )

    # A missing mapping means the version is not supported.
    if self.model_name is None:
        supported_LLMs = ", ".join(ModelMaps.get_models_gui_list())
        raise Exception(f"Unsupported LLM: {self.chat_version}. Requires one of: {supported_LLMs}")

    self.version_name = self.chat_version
|
152 |
+
|
153 |
+
|
154 |
+
def map_prompt_versions(self):
    '''Pick the prompt: a built-in name when configured, otherwise the custom prompt path.'''
    self.prompt_version_map = {
        "Version 1": "prompt_v1_verbose",
    }
    if self.prompt_version0 in self.prompt_version_map:
        chosen_prompt = self.prompt_version_map[self.prompt_version0]
    else:
        # Anything that is not a known version name is treated as a custom prompt file.
        chosen_prompt = self.path_custom_prompts
    self.prompt_version = chosen_prompt
    self.is_predefined_prompt = self.is_in_prompt_version_map(chosen_prompt)
|
160 |
+
|
161 |
+
|
162 |
+
def is_in_prompt_version_map(self, value):
    '''Return True when value is one of the built-in prompt names.'''
    builtin_prompts = self.prompt_version_map.values()
    return value in builtin_prompts
|
164 |
+
|
165 |
+
|
166 |
+
def map_dir_labels(self):
    '''Choose the image directory to transcribe and list every entry in it.'''
    use_label_crops = self.cfg['leafmachine']['use_RGB_label_images']
    self.dir_labels = (
        os.path.join(self.Dirs.save_per_annotation_class, 'label')
        if use_label_crops
        else self.Dirs.save_original
    )

    # Every entry in the directory is treated as an image path.
    self.img_paths = glob.glob(os.path.join(self.dir_labels, "*"))
|
174 |
+
|
175 |
+
|
176 |
+
def load_rules_config(self):
    '''Load the custom prompt YAML and return its parsed content.

    Returns None when the file is not valid YAML; the parse error is written
    to the run logger so it ends up in the log file, not only on stdout.
    '''
    # Read as UTF-8 explicitly: the platform default (e.g. cp1252 on Windows)
    # can fail on prompts that contain non-ASCII characters.
    with open(self.path_custom_prompts, 'r', encoding='utf-8') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            self.logger.error(f'Could not parse prompt file {self.path_custom_prompts}: {exc}')
            return None
|
183 |
+
|
184 |
+
|
185 |
+
def generate_xlsx_headers(self):
    '''Return the spreadsheet columns: the prompt's rule names followed by the utility headers.'''
    # Iterating the 'rules' mapping yields its keys in definition order.
    rule_names = list(self.rules_config_json['rules'])
    return rule_names + self.utility_headers
|
191 |
+
|
192 |
+
|
193 |
+
def init_transcription_xlsx(self):
    '''Build the header list from the custom prompt and create or resume transcribed.xlsx.

    Raises ValueError for built-in prompts, which have no header definition here.
    '''
    self.path_transcription = os.path.join(self.Dirs.transcription, "transcribed.xlsx")

    if self.is_predefined_prompt:
        # Nothing defines the columns for a predefined prompt in this class.
        raise ValueError("Predefined prompt is not handled in this context.")

    # Custom prompt: the columns come from the rules in the prompt file.
    self.rules_config_json = self.load_rules_config()
    self.headers = self.generate_xlsx_headers()
    self.headers_used = 'CUSTOM'

    xlsx_path = os.path.join(self.Dirs.transcription, "transcribed.xlsx")
    self.create_or_load_excel_with_headers(xlsx_path, self.headers)
|
210 |
+
|
211 |
+
|
212 |
+
def create_or_load_excel_with_headers(self, file_path, headers, show_head=False):
    '''Create the transcription workbook, or resume from a partial one.

    When self.continue_run_from_partial_xlsx points to an existing file, that
    workbook is loaded, the images it already contains are recorded in
    self.completed_specimens, the JSON folders referenced by its first data
    row are copied next to file_path, and every stored path is re-rooted
    under the new run directory. Otherwise a new workbook holding only the
    header row is created. In both cases the workbook is saved to file_path.

    Args:
        file_path: destination of the workbook; expected to contain a
            'Transcription' directory component.
        headers: ordered list of column names.
        show_head: print the first rows of the sheet after saving
            (always enabled when resuming).

    Raises:
        ValueError: when resuming and headers lacks one of the path columns.
    '''
    output_dir_names = ['Archival_Components', 'Config_File', 'Cropped_Images', 'Logs', 'Original_Images', 'Transcription']
    path_columns = ['path_to_crop', 'path_to_original', 'path_to_content', 'path_to_helper']
    self.completed_specimens = []

    # Everything before 'Transcription' in file_path is the root of the new run.
    new_run_root = file_path.split('Transcription')[0]

    def relocate(old_path):
        # Re-root a path from the previous run under the new run directory.
        # Returns None for empty cells or paths without a known output folder.
        if not isinstance(old_path, str):
            return None
        updated_path = None
        for dir_name in output_dir_names:
            if dir_name in old_path:
                # The last matching folder name wins, as it always has.
                updated_path = new_run_root + dir_name + old_path.split(dir_name)[1]
        return updated_path

    partial_xlsx = self.continue_run_from_partial_xlsx
    if partial_xlsx is not None and os.path.isfile(partial_xlsx):
        workbook = load_workbook(filename=partial_xlsx)
        sheet = workbook.active
        show_head = True

        # Previously a missing column only printed a message and then failed
        # with an UnboundLocalError a few lines later; fail clearly instead.
        missing = [name for name in path_columns if name not in headers]
        if missing:
            raise ValueError(f"Cannot resume from {partial_xlsx}: headers are missing {missing}")
        crop_col, original_col, content_col, helper_col = [headers.index(name) + 1 for name in path_columns]

        # Record the images that were already transcribed (once per row).
        crop_paths = next(sheet.iter_cols(min_col=crop_col, max_col=crop_col, min_row=2, values_only=True), ())
        for old_path in crop_paths:
            updated_path = relocate(old_path)
            if updated_path is not None:
                self.completed_specimens.append(os.path.basename(updated_path))
        print(f"{len(self.completed_specimens)} images are already completed")

        ### Copy the JSON files over
        # The folder is taken from the first data row of each JSON column.
        for colu in (content_col, helper_col):
            first_row = next(sheet.iter_rows(min_row=2, min_col=colu, max_col=colu), None)
            if first_row is None:
                continue  # resumed workbook has no data rows yet
            old_path = first_row[0].value
            if not isinstance(old_path, str) or 'Transcription' not in old_path:
                continue  # empty cell or unexpected layout: nothing to copy
            updated_path = new_run_root + 'Transcription' + old_path.split('Transcription')[1]

            old_dir = os.path.dirname(old_path)
            new_dir = os.path.dirname(updated_path)
            if os.path.isdir(old_dir):
                os.makedirs(new_dir, exist_ok=True)
                for filename in os.listdir(old_dir):
                    shutil.copy2(os.path.join(old_dir, filename), new_dir)  # copy2 preserves metadata

        ### Update the stored paths so they point into the new run directory
        for colu in (crop_col, original_col, content_col, helper_col):
            for row in sheet.iter_rows(min_row=2, min_col=colu, max_col=colu):
                for cell in row:
                    updated_path = relocate(cell.value)
                    if updated_path is not None:
                        cell.value = updated_path
    else:
        # Start a fresh workbook holding only the header row.
        workbook = Workbook()
        sheet = workbook.active
        for i, header in enumerate(headers, start=1):
            sheet.cell(row=1, column=i, value=header)

    workbook.save(file_path)

    if show_head:
        print("continue_run_from_partial_xlsx:")
        for i, row in enumerate(sheet.iter_rows(values_only=True)):
            print(row)
            if i == 3:  # stop after the first 4 rows (header + 3 data rows)
                print("\n")
                break
|
308 |
+
|
309 |
+
|
310 |
+
def add_data_to_excel_from_response(self, Dirs, path_transcription, response, WFO_record, GEO_record, usage_report,
                                    MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
    '''Append one row to the transcription workbook and save it.

    For every column in the header row the value is taken, in this order,
    from: the LLM response (except catalog-number columns), the file name
    (catalog-number columns), the paths / token counts / run settings passed
    in, the WFO record, the GEO record, the usage report, or the model name.
    Columns that match none of these are left empty.

    Args:
        Dirs: run directories object; only Dirs.run_name is read.
        path_transcription: path of the workbook to update.
        response: dict, or JSON string, of the transcribed fields. A string
            that is not valid JSON is reported and no row is written.
        WFO_record, GEO_record, usage_report: dicts keyed by header name.
        MODEL_NAME_FORMATTED: value for the "LLM" column.
        filename_without_extension: value for "filename" and catalog-number columns.
        path_to_crop, path_to_content, path_to_helper: file paths for this image.
        nt_in, nt_out: token counts.
    '''
    # Validate the response before touching the workbook, so a malformed
    # response does not cost a full workbook load.
    if isinstance(response, str):
        try:
            response = json.loads(response)
        except json.JSONDecodeError:
            print(f"Failed to parse response: {response}")
            return

    wb = openpyxl.load_workbook(path_transcription)
    sheet = wb.active
    next_row = sheet.max_row + 1  # first empty row

    def to_cell_value(value):
        # Reduce an LLM field to something a single cell can hold.
        if isinstance(value, dict):
            value = value.get('value', '')
        if isinstance(value, (list, tuple)):
            # As before, only the first element of a list-valued field is kept;
            # an empty list now becomes '' instead of raising IndexError.
            value = value[0] if value else ''
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        return str(value)

    def original_image_path():
        # Label crops sit four levels below the run directory, originals two.
        levels_up = 4 if self.cfg['leafmachine']['use_RGB_label_images'] else 2
        base = path_to_crop
        for _ in range(levels_up):
            base = os.path.dirname(base)
        return os.path.join(base, 'Original_Images', os.path.basename(path_to_crop))

    for i, header in enumerate(sheet[1], start=1):
        name = header.value

        if (name in response) and (name not in self.catalog_name_options):
            value = to_cell_value(response[name])
        elif name in self.catalog_name_options:
            ####################### Catalog Number pre-defined
            # The catalog number is known from the file name, so the LLM value is ignored.
            value = filename_without_extension
        elif name == "path_to_crop":
            value = path_to_crop
        elif name == "path_to_original":
            value = original_image_path()
        elif name == "path_to_content":
            value = path_to_content
        elif name == "path_to_helper":
            value = path_to_helper
        elif name == "tokens_in":
            value = nt_in
        elif name == "tokens_out":
            value = nt_out
        elif name == "filename":
            value = filename_without_extension
        elif name == "prompt":
            value = os.path.basename(self.path_custom_prompts)
        elif name == "run_name":
            value = Dirs.run_name
        elif name == "LM2_collage":
            value = self.cfg['leafmachine']['use_RGB_label_images']
        elif name == "OCR_method":
            value = self.cfg['leafmachine']['project']['OCR_option']
            if isinstance(value, list):
                value = '|'.join(map(str, value))
        elif name == "OCR_double":
            value = self.cfg['leafmachine']['project']['double_OCR']
        elif name == "OCR_trOCR":
            value = self.cfg['leafmachine']['project']['do_use_trOCR']
        elif name in self.wfo_headers_no_lists:
            value = WFO_record.get(name, '')
        elif name == "WFO_candidate_names":
            value = WFO_record.get("WFO_candidate_names", '')
            # May be a list of names; flatten it into one cell.
            if isinstance(value, list):
                value = '|'.join(map(str, value))
        elif name in self.geo_headers:
            value = GEO_record.get(name, '')
        elif name in self.usage_headers:
            value = usage_report.get(name, '')
        elif name == "LLM":
            value = MODEL_NAME_FORMATTED
        else:
            continue  # unknown column: leave the cell empty

        sheet.cell(row=next_row, column=i, value=value)

    wb.save(path_transcription)
|
424 |
+
|
425 |
+
|
426 |
+
def has_API_key(self, val):
    '''Return True only for a string that contains something other than whitespace.'''
    if not isinstance(val, str):
        return False
    return val.strip() != ''
|
432 |
+
|
433 |
+
|
434 |
+
def get_google_credentials(self): # Also used for google drive
    '''Build Google service-account credentials.

    On Hugging Face the service-account JSON is read from the
    GOOGLE_APPLICATION_CREDENTIALS environment variable. Otherwise it is read
    from the file named in PRIVATE_DATA.yaml, and its JSON text is then stored
    in that same environment variable.
    '''
    if self.is_hf:
        account_info = json.loads(os.getenv('GOOGLE_APPLICATION_CREDENTIALS'))
        return service_account.Credentials.from_service_account_info(account_info)

    key_file = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']
    with open(key_file, 'r') as file:
        account_info = json.load(file)
    creds_json_str = json.dumps(account_info)
    credentials = service_account.Credentials.from_service_account_info(account_info)
    # The variable receives the JSON content itself, matching what the HF branch reads.
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = creds_json_str
    return credentials
|
446 |
+
|
447 |
+
|
448 |
+
def set_API_keys(self):
    '''Collect the API keys, record which are present, and initialise the clients.

    On Hugging Face (self.is_hf) keys come from environment variables;
    otherwise they come from PRIVATE_DATA.yaml in the project root and are
    exported to the environment for the libraries that read them there.

    Sets self.dir_home, self.path_cfg_private, self.cfg_private, the
    self.has_key_* flags and self.llm (an AzureChatOpenAI client, or None
    when no Azure key is available).
    '''
    self.dir_home = os.path.dirname(os.path.dirname(__file__))

    if self.is_hf:
        self.path_cfg_private = None
        self.cfg_private = None

        k_openai = os.getenv('OPENAI_API_KEY')
        # Fixed: the presence check must look at the key that is passed to
        # the Azure client below, not at AZURE_API_VERSION.
        k_openai_azure = os.getenv('AZURE_API_KEY')

        k_google_project_id = os.getenv('GOOGLE_PROJECT_ID')
        k_google_location = os.getenv('GOOGLE_LOCATION')
        k_google_application_credentials = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

        k_mistral = os.getenv('MISTRAL_API_KEY')
        k_here = os.getenv('HERE_API_KEY')
        # OPENCAGE_API_KEY is the name exported in the non-HF branch; the
        # older 'open_cage_geocode' name is still accepted as a fallback.
        k_opencage = os.getenv('OPENCAGE_API_KEY') or os.getenv('open_cage_geocode')
    else:
        self.path_cfg_private = os.path.join(self.dir_home, 'PRIVATE_DATA.yaml')
        self.cfg_private = get_cfg_from_full_path(self.path_cfg_private)

        k_openai = self.cfg_private['openai']['OPENAI_API_KEY']
        k_openai_azure = self.cfg_private['openai_azure']['OPENAI_API_KEY_AZURE']

        k_google_project_id = self.cfg_private['google']['GOOGLE_PROJECT_ID']
        k_google_location = self.cfg_private['google']['GOOGLE_LOCATION']
        k_google_application_credentials = self.cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']

        k_mistral = self.cfg_private['mistral']['MISTRAL_API_KEY']
        k_here = self.cfg_private['here']['API_KEY']
        k_opencage = self.cfg_private['open_cage_geocode']['API_KEY']

    self.has_key_openai = self.has_API_key(k_openai)
    self.has_key_azure_openai = self.has_API_key(k_openai_azure)
    self.llm = None

    self.has_key_google_project_id = self.has_API_key(k_google_project_id)
    self.has_key_google_location = self.has_API_key(k_google_location)
    self.has_key_google_application_credentials = self.has_API_key(k_google_application_credentials)

    self.has_key_mistral = self.has_API_key(k_mistral)
    self.has_key_here = self.has_API_key(k_here)
    self.has_key_open_cage_geocode = self.has_API_key(k_opencage)

    ### Google - OCR, Palm2, Gemini
    if self.has_key_google_application_credentials and self.has_key_google_project_id and self.has_key_google_location:
        vertexai.init(project=k_google_project_id, location=k_google_location, credentials=self.get_google_credentials())
        if not self.is_hf:
            os.environ['GOOGLE_API_KEY'] = self.cfg_private['google']['GOOGLE_PALM_API']

    ### OpenAI
    if self.has_key_openai:
        openai.api_key = k_openai
        if not self.is_hf:
            os.environ["OPENAI_API_KEY"] = k_openai

    ### OpenAI - Azure
    if self.has_key_azure_openai:
        if self.is_hf:
            azure_version = os.getenv('AZURE_API_VERSION')
            azure_endpoint = os.getenv('AZURE_API_BASE')
            azure_organization = os.getenv('AZURE_ORGANIZATION')
        else:
            azure_cfg = self.cfg_private['openai_azure']
            azure_version = azure_cfg['OPENAI_API_VERSION']
            azure_endpoint = azure_cfg['OPENAI_API_BASE']
            azure_organization = azure_cfg['OPENAI_ORGANIZATION']
        self.llm = AzureChatOpenAI(
            deployment_name='gpt-35-turbo',
            openai_api_version=azure_version,
            openai_api_key=k_openai_azure,
            azure_endpoint=azure_endpoint,
            openai_organization=azure_organization,
        )

    ### Mistral (on HF the variable is already in the environment)
    if self.has_key_mistral and not self.is_hf:
        os.environ['MISTRAL_API_KEY'] = k_mistral

    ### HERE (on HF the variables are already in the environment)
    if self.has_key_here and not self.is_hf:
        os.environ['HERE_APP_ID'] = self.cfg_private['here']['APP_ID']
        os.environ['HERE_API_KEY'] = k_here

    ### OpenCage
    if self.has_key_open_cage_geocode:
        if self.is_hf:
            # Ensure the name exported in the non-HF branch also exists when
            # only the older secret name was provided.
            # NOTE(review): assumes consumers read OPENCAGE_API_KEY - confirm.
            os.environ.setdefault('OPENCAGE_API_KEY', k_opencage)
        else:
            os.environ['OPENCAGE_API_KEY'] = k_opencage
|
560 |
+
|
561 |
+
|
562 |
+
|
563 |
+
def clean_catalog_number(self, data, filename_without_extension):
    '''Overwrite the catalog-number field of data with the value derived from the file name.

    The file name has the configured prefix and suffix removed and, when
    catalog_numerical_only is set, is reduced to its digits. The field is
    "Catalog Number" for HEADERS_v1_n22 and "filename" for HEADERS_v2_n26
    and CUSTOM headers.

    Returns data itself when the field already existed; otherwise a new dict
    with the field inserted as the first key.

    Raises:
        TypeError: data is not a dict.
        ValueError: self.headers_used is not a recognised header set.
    '''
    if not isinstance(data, dict):
        raise TypeError("Data is not of type dict.")

    if self.headers_used == 'HEADERS_v1_n22':
        catalog_key = "Catalog Number"
    elif self.headers_used in ['HEADERS_v2_n26', 'CUSTOM']:
        catalog_key = "filename"
    else:
        raise ValueError("Invalid headers used.")

    catalog_number = filename_without_extension
    if self.prefix_removal:
        catalog_number = catalog_number.replace(self.prefix_removal, "")
    if self.suffix_removal:
        catalog_number = catalog_number.replace(self.suffix_removal, "")
    if self.catalog_numerical_only:
        # Fixed: take the digits from the cleaned file name. The previous code
        # read data[catalog_key], which is None when the key was just added
        # (TypeError) and otherwise the LLM's guess rather than the known number.
        catalog_number = self.remove_non_numbers(catalog_number)

    if catalog_key not in data:
        # Insert as the first key without mutating the caller's dict.
        data = {catalog_key: None, **data}
    data[catalog_key] = catalog_number
    return data
|
590 |
+
|
591 |
+
|
592 |
+
def write_json_to_file(self, filepath, data):
    '''Write data to filepath as UTF-8 text.

    dict and list payloads are serialised as indented JSON; a str is written
    unchanged. Serialisation happens before the file is opened, so a payload
    that cannot be serialised does not leave a truncated file behind.
    '''
    if isinstance(data, (dict, list)):
        data = json.dumps(data, indent=4, sort_keys=False)
    # Explicit encoding: the platform default can fail on non-ASCII OCR text.
    with open(filepath, 'w', encoding='utf-8') as txt_file:
        txt_file.write(data)
|
598 |
+
|
599 |
+
|
600 |
+
# def create_null_json(self):
|
601 |
+
# return {}
|
602 |
+
|
603 |
+
|
604 |
+
def remove_non_numbers(self, s):
    '''Return s with every non-digit character dropped.'''
    return ''.join(filter(str.isdigit, s))
|
606 |
+
|
607 |
+
|
608 |
+
def create_null_row(self, filename_without_extension, path_to_crop, path_to_content, path_to_helper):
    '''Return a row dict with every header empty except the file name and path columns.

    Keys follow the order of self.headers. path_to_original is derived from
    path_to_crop in the same way as add_data_to_excel_from_response does:
    four directory levels up when label crops are used
    (use_RGB_label_images), two otherwise.
    '''
    known_values = {
        "path_to_crop": path_to_crop,
        "path_to_content": path_to_content,
        "path_to_helper": path_to_helper,
        "filename": filename_without_extension,
    }

    # Only derive the original path when the column exists.
    if "path_to_original" in self.headers:
        # Fixed: this used to go four levels up unconditionally, which is
        # wrong when the original images themselves are transcribed.
        levels_up = 4 if self.cfg['leafmachine']['use_RGB_label_images'] else 2
        base = path_to_crop
        for _ in range(levels_up):
            base = os.path.dirname(base)
        known_values["path_to_original"] = os.path.join(base, 'Original_Images', os.path.basename(path_to_crop))

    # Every other column (LLM fields, WFO, GEO, usage, ...) stays empty.
    return {header: known_values.get(header, '') for header in self.headers}
|
637 |
+
|
638 |
+
|
639 |
+
##################################################################################################################################
|
640 |
+
################################################## OCR ##################################################################
|
641 |
+
##################################################################################################################################
|
642 |
+
def perform_OCR_and_save_results(self, image_index, json_report, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds):
    '''Run OCR on self.path_to_crop, store the text in self.OCR and write the OCR artefacts.

    The OCR JSON is always written to txt_file_path_OCR. The overlay image
    and the bounding-box JSON are written only when some text was found.
    '''
    progress = f'Working on {image_index + 1}/{len(self.img_paths)}'
    self.logger.info(f'{progress} --- Starting OCR')

    ### process_image() runs the OCR for text, handwriting, trOCR AND creates the overlay image
    engine = OCREngine(self.logger, json_report, self.dir_home, self.is_hf, self.path_to_crop, self.cfg,
                       self.trOCR_model_version, self.trOCR_model, self.trOCR_processor, self.device)
    engine.process_image(self.do_create_OCR_helper_image, self.logger)
    self.OCR = engine.OCR
    self.logger.info(f"Complete OCR text for LLM prompt:\n\n{self.OCR}\n\n")

    self.write_json_to_file(txt_file_path_OCR, engine.OCR_JSON_to_file)

    self.logger.info(f'{progress} --- Finished OCR')

    if len(self.OCR) == 0:
        # TODO: the no-OCR case still needs real handling; nothing else is saved for now.
        return

    engine.overlay_image.save(jpg_file_path_OCR_helper)

    # Keep only the bounding-box mappings that the engine actually produced.
    bound_sources = (
        ('OCR_bounds_handwritten', engine.hand_text_to_box_mapping),
        ('OCR_bounds_printed', engine.normal_text_to_box_mapping),
        ('OCR_bounds_trOCR', engine.trOCR_text_to_box_mapping),
    )
    OCR_bounds = {key: mapping for key, mapping in bound_sources if mapping is not None}

    self.write_json_to_file(txt_file_path_OCR_bounds, OCR_bounds)
    self.logger.info(f'{progress} --- Saved OCR Overlay Image')
|
673 |
+
|
674 |
+
##################################################################################################################################
|
675 |
+
####################################################### LLM Switchboard ########################################################
|
676 |
+
##################################################################################################################################
|
677 |
+
def send_to_LLM(self, is_azure, progress_report, json_report, model_name):
    """Run OCR + LLM transcription for every image in ``self.img_paths`` using a thread pool.

    Each image is handed to ``self.send_to_LLM_worker``; workers push their
    result tuples onto a queue, which is drained on the calling thread once
    the pool has finished so that counters and output files are only updated
    from one thread.

    Args:
        is_azure: Forwarded unchanged to the worker / handler construction.
        progress_report: Optional progress widget; may be ``None``.
        json_report: Optional status widget; skipped when falsy.
        model_name: Underscore-separated model identifier; it is split on
            ``"_"`` to pick the handler.

    Returns:
        ``(final_JSON_response, final_WFO_record, final_GEO_record,
        total_tokens_in, total_tokens_out)``. The first three describe the
        last result taken from the queue (``None`` if no result was produced).
    """
    self.n_failed_LLM_calls = 0
    self.n_failed_OCR = 0

    final_JSON_response = None
    final_WFO_record = None
    final_GEO_record = None

    self.initialize_token_counters()
    self.update_progress_report_initial(progress_report)

    MODEL_NAME_FORMATTED = ModelMaps.get_API_name(model_name)
    name_parts = model_name.split("_")

    self.setup_JSON_dict_structure()

    Copy_Prompt = PromptCatalog()
    Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)

    if json_report:
        json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
        json_report.set_JSON({}, {}, {})

    results_queue = Queue()

    if json_report:
        json_report.set_text(text_main='Sending batch to OCR and LLM')

    num_files = len(self.img_paths)
    # Upper bound only: the executor starts threads on demand, so small
    # batches do not spawn 128 threads.
    num_threads = 128
    counter = AtomicCounter()

    # Setup for parallel execution
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(self.send_to_LLM_worker,
                                   path_to_crop,
                                   results_queue,
                                   model_name,
                                   MODEL_NAME_FORMATTED,
                                   name_parts,
                                   is_azure,
                                   i
                                   ) for i, path_to_crop in enumerate(self.img_paths)]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing", unit="task"):
            try:
                # Forces a wait on the future and re-raises any worker exception
                future.result()
            except Exception as e:
                # A failed task must not abort the rest of the batch; record it and move on.
                self.logger.error(f"A task failed with exception: {e}")
                continue

            # Use the value returned by inc() instead of a second, unlocked read.
            current_value = counter.inc()
            try:
                if json_report:
                    json_report.set_text(text_main=f'Completed {current_value} of {num_files}')
            except Exception as e:
                # The status widget is best-effort; never let a UI error fail the batch.
                self.logger.debug(f"Could not update status text: {e}")

    # Process results from the queue (single-threaded from here on)
    while not results_queue.empty():
        response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report, path_to_crop, paths = results_queue.get()

        self.n_failed_LLM_calls += 1 if response_candidate is None else 0

        ### Estimate n tokens returned
        self.logger.info(f'Prompt tokens IN --- {nt_in}')
        self.logger.info(f'Prompt tokens OUT --- {nt_out}')

        self.update_token_counters(nt_in, nt_out)

        final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)

        self.logger.info(f'Finished LLM call')

        if json_report:
            json_report.set_JSON(final_JSON_response, final_WFO_record, final_GEO_record)

    if json_report:
        json_report.set_text(text_main='Finished!')

    self.update_progress_report_final(progress_report)
    final_JSON_response = self.parse_final_json_response(final_JSON_response)
    return final_JSON_response, final_WFO_record, final_GEO_record, self.total_tokens_in, self.total_tokens_out
|
762 |
+
|
763 |
+
def send_to_LLM_worker(self, path_to_crop, queue, model_name, MODEL_NAME_FORMATTED, name_parts, is_azure, i):
    """Process one image: run OCR, call the selected LLM, and enqueue the result.

    Args:
        path_to_crop: Path of the image to transcribe.
        queue: Object with a ``put`` method that receives the result tuple.
        model_name: Raw model identifier (used for logging and error text).
        MODEL_NAME_FORMATTED: Model name passed to the handler constructor.
        name_parts: ``model_name`` split on ``"_"``; selects the handler/API.
        is_azure: Forwarded to the handler construction.
        i: Zero-based index of the image within ``self.img_paths``.

    Returns:
        ``None``. Unless the specimen is skipped, a tuple
        ``(response_candidate, nt_in, nt_out, WFO_record, GEO_record,
        usage_report, path_to_crop, paths)`` is put on ``queue``.

    Raises:
        ValueError: If ``name_parts`` names a LOCAL model that is neither
            MISTRAL nor MIXTRAL (previously this surfaced as an
            ``UnboundLocalError``).
    """
    # Check for already-processed specimens first so no handler is built for them.
    if self.should_skip_specimen(path_to_crop):
        self.log_skipping_specimen(path_to_crop)
        return

    llm_model = self.initialize_llm_model(self.cfg, self.logger, MODEL_NAME_FORMATTED, self.JSON_dict_structure, name_parts, is_azure, self.llm, self.config_vals_for_permutation)

    paths = self.generate_paths(path_to_crop, i)
    # NOTE(review): self.path_to_crop, self.OCR and self.n_failed_OCR are shared
    # by all worker threads; concurrent workers can overwrite each other's
    # values between the OCR step and the prompt build. Confirm whether this
    # state should be made per-task.
    self.path_to_crop = path_to_crop

    _, _, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, _, _ = paths
    self.perform_OCR_and_save_results(i, None, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)

    # Defaults describe a "null" result so every branch below enqueues a
    # fully-populated tuple.
    response_candidate = None
    nt_in = 0
    nt_out = 0
    WFO_record = None
    GEO_record = None
    usage_report = None

    if not self.OCR:
        self.n_failed_OCR += 1
    else:
        ### Format prompt
        prompt = self.setup_prompt()

        ### Send prompt to chosen LLM
        self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')

        if 'PALM2' in name_parts:
            response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_GooglePalm2(prompt, None, paths)

        elif 'GEMINI' in name_parts:
            response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_GoogleGemini(prompt, None, paths)

        elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
            response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_MistralAI(prompt, None, paths)

        elif 'LOCAL' in name_parts:
            if 'MISTRAL' in name_parts or 'MIXTRAL' in name_parts:
                if 'CPU' in name_parts:
                    response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_cpu_MistralAI(prompt, None, paths)
                else:
                    response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_MistralAI(prompt, None, paths)
            else:
                # No handler exists for this combination; fail with a clear message.
                raise ValueError(f'Unsupported LOCAL model: {model_name}')
        else:
            response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_OpenAI(prompt, None, paths)

    # Instead of directly updating shared resources, put the structured result in the queue
    queue.put((response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report, path_to_crop, paths))
|
815 |
+
|
816 |
+
##################################################################################################################################
|
817 |
+
################################################## LLM Helper Funcs ##############################################################
|
818 |
+
##################################################################################################################################
|
819 |
+
def initialize_llm_model(self, cfg, logger, model_name, JSON_dict_structure, name_parts, is_azure=None, llm_object=None, config_vals_for_permutation=None):
    """Construct the handler object that matches the tokens in ``name_parts``.

    Hosted models are chosen in the order PALM2, GEMINI, MISTRAL, and
    otherwise OpenAI. LOCAL models are only supported for MIXTRAL/MISTRAL
    (with a separate CPU handler); any other LOCAL model yields ``None``.
    """
    if 'LOCAL' not in name_parts:
        # Hosted APIs
        if 'PALM2' in name_parts:
            return GooglePalm2Handler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
        if 'GEMINI' in name_parts:
            return GoogleGeminiHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
        if 'MISTRAL' in name_parts:
            return MistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
        return OpenAIHandler(cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object, config_vals_for_permutation)

    # Locally hosted models
    is_mistral_family = ('MIXTRAL' in name_parts) or ('MISTRAL' in name_parts)
    if not is_mistral_family:
        return None
    if 'CPU' in name_parts:
        return LocalCPUMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
    return LocalMistralHandler(cfg, logger, model_name, JSON_dict_structure, config_vals_for_permutation)
|
835 |
+
|
836 |
+
def setup_prompt(self):
    """Return the prompt text built from ``self.path_custom_prompts`` and the current ``self.OCR``."""
    built_prompt, _ = PromptCatalog().prompt_SLTP(self.path_custom_prompts, OCR=self.OCR)
    return built_prompt
|
840 |
+
|
841 |
+
def setup_JSON_dict_structure(self):
    """Store the second value returned by ``prompt_SLTP`` as ``self.JSON_dict_structure``.

    The literal ``'Text'`` is passed as the OCR argument; the prompt text
    that comes back with it is discarded.
    """
    prompt_and_structure = PromptCatalog().prompt_SLTP(self.path_custom_prompts, OCR='Text')
    _, self.JSON_dict_structure = prompt_and_structure
|
844 |
+
|
845 |
+
|
846 |
+
def initialize_token_counters(self):
    """Reset the running input and output token totals to zero."""
    self.total_tokens_in = self.total_tokens_out = 0
|
849 |
+
|
850 |
+
|
851 |
+
def update_progress_report_initial(self, progress_report):
    """Tell the progress widget, if one was supplied, how many images will be processed."""
    if progress_report is None:
        return
    n_images = len(self.img_paths)
    progress_report.set_n_batches(n_images)
|
854 |
+
|
855 |
+
|
856 |
+
def update_progress_report_batch(self, progress_report, batch_index):
    """Show "image N of M" on the progress widget; does nothing when no widget was supplied.

    ``batch_index`` is zero-based and is displayed one-based.
    """
    if progress_report is None:
        return
    status = f"Working on image {batch_index + 1} of {len(self.img_paths)}"
    progress_report.update_batch(status)
|
859 |
+
|
860 |
+
|
861 |
+
def should_skip_specimen(self, path_to_crop):
    """Return True when the image's file name is already listed in ``self.completed_specimens``."""
    image_name = os.path.basename(path_to_crop)
    return image_name in self.completed_specimens
|
863 |
+
|
864 |
+
|
865 |
+
def log_skipping_specimen(self, path_to_crop):
    """Log that the image at ``path_to_crop`` is being skipped because it was already processed."""
    image_name = os.path.basename(path_to_crop)
    self.logger.info(f'[Skipping] specimen {image_name} already processed')
|
867 |
+
|
868 |
+
|
869 |
+
def update_token_counters(self, nt_in, nt_out):
    """Add one call's token counts to the running totals.

    Args:
        nt_in: Tokens sent in the prompt for this call.
        nt_out: Tokens returned by this call.
    """
    self.total_tokens_in += nt_in
    self.total_tokens_out += nt_out
|
872 |
+
|
873 |
+
|
874 |
+
def update_final_response(self, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out):
    """Persist one result via ``save_json_and_xlsx`` and return it with its records.

    ``response_candidate`` is passed through whether or not it is ``None``.

    Returns:
        ``(saved_response, WFO_record, GEO_record)`` where ``saved_response``
        is whatever ``save_json_and_xlsx`` returned.
    """
    # Only three of the seven path entries are needed here.
    filename_without_extension, txt_file_path, _, _, jpg_file_path_OCR_helper, _, _ = paths

    saved_response = self.save_json_and_xlsx(
        self.Dirs, response_candidate, WFO_record, GEO_record, usage_report,
        MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop,
        txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
    return saved_response, WFO_record, GEO_record
|
883 |
+
|
884 |
+
|
885 |
+
def update_progress_report_final(self, progress_report):
    """Mark the batch as complete on the progress widget, when one was supplied."""
    if progress_report is None:
        return
    progress_report.reset_batch("Batch Complete")
|
888 |
+
|
889 |
+
|
890 |
+
def parse_final_json_response(self, final_JSON_response):
    """Decode a JSON string that may be wrapped in a Markdown code fence.

    Surrounding backticks and a leading ``json`` language tag are removed
    before decoding. Only a *leading* tag is removed, so the word "json"
    inside keys or values is left intact.

    Args:
        final_JSON_response: Usually a ``str``; any other type (for example a
            ``dict`` or ``None``) is returned unchanged.

    Returns:
        The decoded object, or the original argument when it is not a string
        or is not valid JSON.
    """
    if not isinstance(final_JSON_response, str):
        return final_JSON_response

    text = final_JSON_response.strip().strip('`').strip()
    if text.startswith('json'):
        text = text[len('json'):]

    try:
        return json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return final_JSON_response
|
895 |
+
|
896 |
+
|
897 |
+
|
898 |
+
def generate_paths(self, path_to_crop, i):
    """Derive every per-image output path from the image's base name and log progress.

    Args:
        path_to_crop: Path of the image being processed.
        i: Zero-based index of the image; logged one-based.

    Returns:
        A 7-tuple: ``(filename_without_extension, txt_file_path,
        txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper,
        json_file_path_wiki, txt_file_path_ind_prompt)``.
    """
    stem, _ext = os.path.splitext(os.path.basename(path_to_crop))
    json_name = stem + '.json'

    out_transcription = os.path.join(self.Dirs.transcription_ind, json_name)
    out_ocr = os.path.join(self.Dirs.transcription_ind_OCR, json_name)
    out_ocr_bounds = os.path.join(self.Dirs.transcription_ind_OCR_bounds, json_name)
    out_ocr_helper = os.path.join(self.Dirs.transcription_ind_OCR_helper, stem + '.jpg')
    out_wiki = os.path.join(self.Dirs.transcription_ind_wiki, json_name)
    out_prompt = os.path.join(self.Dirs.transcription_ind_prompt, stem + '.txt')

    self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- {stem}')

    return (stem, out_transcription, out_ocr, out_ocr_bounds,
            out_ocr_helper, out_wiki, out_prompt)
|
910 |
+
|
911 |
+
|
912 |
+
def save_json_and_xlsx(self, Dirs, response, WFO_record, GEO_record, usage_report,
                       MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
    """Write one transcription to its JSON file and append a row to the spreadsheet.

    When ``response`` is ``None`` the JSON file receives a copy of
    ``self.JSON_dict_structure`` with ``'filename'`` as its first key, and the
    spreadsheet receives the row from ``self.create_null_row`` with both token
    counts forced to 0. Otherwise the response is passed through
    ``self.clean_catalog_number`` and written with the supplied token counts.

    Returns:
        The dictionary that was written to ``txt_file_path``.
    """
    if response is not None:
        ### Set completed JSON
        response = self.clean_catalog_number(response, filename_without_extension)
        self.write_json_to_file(txt_file_path, response)
        # add to the xlsx file
        self.add_data_to_excel_from_response(Dirs, self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
        return response

    # No usable response: fall back to the empty template, 'filename' first.
    template = self.JSON_dict_structure
    response = {'filename': filename_without_extension}
    for key, value in template.items():
        if key != 'filename':
            response[key] = value
    self.write_json_to_file(txt_file_path, response)

    # Then add the null info to the spreadsheet
    response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
    self.add_data_to_excel_from_response(Dirs, self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
    return response
|
931 |
+
|
932 |
+
|
933 |
+
def process_specimen_batch(self, progress_report, json_report, is_real_run=False):
    """Transcribe the whole batch and return the results of ``send_to_LLM``.

    Args:
        progress_report: Optional progress widget; may be ``None``.
        json_report: Optional status widget, forwarded to ``send_to_LLM``.
        is_real_run: When true (and a progress widget exists) the overall
            status is set to "Transcribing Labels" before the batch starts.

    Returns:
        ``(final_json_response, final_WFO_record, final_GEO_record,
        total_tokens_in, total_tokens_out)``.

    Raises:
        Exception: If ``self.has_key`` is false. Any exception raised while
            processing is logged, the logger handlers are closed, and the
            exception is re-raised.
    """
    if not self.has_key:
        self.logger.error(f'No API key found for {self.version_name}')
        raise Exception(f"No API key found for {self.version_name}")

    try:
        # Guard against a missing widget, as the other progress helpers do;
        # otherwise a headless real run fails before any image is processed.
        if is_real_run and progress_report is not None:
            progress_report.update_overall(f"Transcribing Labels")

        final_json_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out = self.send_to_LLM(self.is_azure, progress_report, json_report, self.model_name)

        return final_json_response, final_WFO_record, final_GEO_record, total_tokens_in, total_tokens_out

    except Exception as e:
        self.logger.error(f"LLM call failed in process_specimen_batch: {e}")
        if progress_report is not None:
            progress_report.reset_batch(f"Batch Failed")
        self.close_logger_handlers()
        raise
|
952 |
+
|
953 |
+
|
954 |
+
def close_logger_handlers(self):
    """Close every handler attached to ``self.logger`` and detach it."""
    # Work on a snapshot: removeHandler mutates the live handler list.
    attached = list(self.logger.handlers)
    for attached_handler in attached:
        attached_handler.close()
        self.logger.removeHandler(attached_handler)
|
958 |
+
|
959 |
+
|
960 |
+
# def process_specimen_batch_OCR_test(self, path_to_crop):
|
961 |
+
# for img_filename in os.listdir(path_to_crop):
|
962 |
+
# img_path = os.path.join(path_to_crop, img_filename)
|
963 |
+
# self.OCR, self.bounds, self.text_to_box_mapping = detect_text(img_path)
|
964 |
+
|
965 |
+
# https://gist.github.com/benhoyt/8c8a8d62debe8e5aa5340373f9c509c7
|
966 |
+
class AtomicCounter(object):
    """Counter whose increments and decrements are serialized by a lock."""

    def __init__(self, initial=0):
        """Start counting from ``initial``."""
        self._lock = threading.Lock()
        self._value = initial

    def inc(self, num=1):
        """Add ``num`` to the counter under the lock and return the updated value."""
        with self._lock:
            self._value += num
            updated = self._value
        return updated

    def dec(self, num=1):
        """Subtract ``num`` from the counter under the lock and return the updated value."""
        with self._lock:
            self._value -= num
            updated = self._value
        return updated

    @property
    def value(self):
        """Current count (read without taking the lock)."""
        return self._value
|
989 |
+
|
990 |
+
|
991 |
+
def space_saver(cfg, Dirs, logger):
    """Delete temporary output folders of a run according to the project config.

    * ``delete_temps_keep_VVE``: removes ``Archival_Components`` and
      ``Config_File`` only.
    * ``delete_all_temps`` (checked only when the first flag is falsy): also
      removes ``Original_Images`` and ``Cropped_Images``, then removes every
      sub-directory of ``Transcription`` while leaving its files in place.

    The run folder is ``<dir_output>/<Dirs.run_name>``. ``logger.name`` is
    changed to ``'[DELETE TEMP FILES]'`` whenever something is deleted.
    """
    dir_out = cfg['leafmachine']['project']['dir_output']
    path_project = os.path.join(dir_out, Dirs.run_name)

    if cfg['leafmachine']['project']['delete_temps_keep_VVE']:
        notice = "Deleting temporary files. Keeping files required for VoucherVisionEditor."
        doomed = ['Archival_Components', 'Config_File']
        purge_transcription_dirs = False
    elif cfg['leafmachine']['project']['delete_all_temps']:
        notice = "Deleting ALL temporary files!"
        doomed = ['Archival_Components', 'Config_File', 'Original_Images', 'Cropped_Images']
        purge_transcription_dirs = True
    else:
        return

    logger.name = '[DELETE TEMP FILES]'
    logger.info(notice)

    for folder in doomed:
        target = os.path.join(path_project, folder)
        if os.path.exists(target):
            shutil.rmtree(target)

    if not purge_transcription_dirs:
        return

    # Delete the transcription sub-folders, but keep the xlsx (files are left untouched)
    transcription_path = os.path.join(path_project, 'Transcription')
    if os.path.exists(transcription_path):
        for entry in os.listdir(transcription_path):
            entry_path = os.path.join(transcription_path, entry)
            if os.path.isdir(entry_path) and os.path.exists(entry_path):
                shutil.rmtree(entry_path)
|
vouchervision/vouchervision_main.py
CHANGED
@@ -14,6 +14,7 @@ from vouchervision.data_project import Project_Info
|
|
14 |
from vouchervision.LM2_logger import start_logging
|
15 |
from vouchervision.fetch_data import fetch_data
|
16 |
from vouchervision.utils_VoucherVision import VoucherVision, space_saver
|
|
|
17 |
from vouchervision.utils_hf import upload_to_drive
|
18 |
|
19 |
def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progress_report, json_report, path_api_cost=None, test_ind = None, is_hf = True, is_real_run=False):
|
|
|
14 |
from vouchervision.LM2_logger import start_logging
|
15 |
from vouchervision.fetch_data import fetch_data
|
16 |
from vouchervision.utils_VoucherVision import VoucherVision, space_saver
|
17 |
+
# from vouchervision.utils_VoucherVision_parallel import VoucherVision, space_saver
|
18 |
from vouchervision.utils_hf import upload_to_drive
|
19 |
|
20 |
def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progress_report, json_report, path_api_cost=None, test_ind = None, is_hf = True, is_real_run=False):
|
vouchervision/vouchervision_test_all_options_analysis.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import seaborn as sns
|
4 |
+
|
5 |
+
def SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT(
        file_path='D:/Dropbox/VoucherVision/demo/validation_output/summary/SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT.xlsx',
        save_path='D:/Dropbox/VoucherVision/demo/validation_output/figures/avg_L_score_analysis_SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT.png'):
    """Summarise the GPT-4 temperature / top_p sweep and save a two-panel figure.

    Prints the parameter combination(s) with the highest mean ``avg_L_score``,
    then plots ``avg_L_score`` against temperature and against top_p for the
    rows with prompt ``SLTPvB_long.yaml`` and double OCR enabled.

    Args:
        file_path: Excel workbook holding the sweep summary.
        save_path: Destination of the PNG figure (written at 600 dpi).
    """
    df = pd.read_excel(file_path)

    # Grouping by the parameters and calculating the mean of avg_L_score for each group
    grouped = df.groupby(['v_prompt_version', 'v_double_ocr', 'temperature', 'top_p'])['avg_L_score'].mean().reset_index()

    # Finding the group with the highest average L score
    max_avg_L_score = grouped['avg_L_score'].max()
    best_group = grouped[grouped['avg_L_score'] == max_avg_L_score]

    print(best_group)

    ### Viz
    # Combine both conditions into one mask; indexing twice would apply a mask
    # built on the full frame to an already-filtered frame.
    is_long_prompt = df['v_prompt_version'] == 'SLTPvB_long.yaml'
    is_double_ocr = df['v_double_ocr'] == True  # noqa: E712 - element-wise comparison
    filtered_df = df[is_long_prompt & is_double_ocr]

    # Setting up the plotting
    plt.figure(figsize=(14, 6))

    # Plot 1: avg_L_score as a function of temperature for each top_p value
    plt.subplot(1, 2, 1)
    sns.lineplot(data=filtered_df, x='temperature', y='avg_L_score', hue='top_p', marker='o')
    plt.title('Average L Score by Temperature for each Top P')
    plt.xlabel('Temperature')
    plt.ylabel('Average L Score')
    plt.legend(title='Top P', bbox_to_anchor=(1.05, 1), loc='upper left')

    # Plot 2: avg_L_score as a function of top_p for each temperature value
    plt.subplot(1, 2, 2)
    sns.lineplot(data=filtered_df, x='top_p', y='avg_L_score', hue='temperature', marker='o')
    plt.title('Average L Score by Top P for each Temperature')
    plt.xlabel('Top P')
    plt.ylabel('Average L Score')
    plt.legend(title='Temperature', bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
|
51 |
+
|
52 |
+
def SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT(
        file_path='D:/Dropbox/VoucherVision/demo/validation_output/summary/SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT.xlsx',
        save_path='D:/Dropbox/VoucherVision/demo/validation_output/figures/avg_L_score_analysis_SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT.png'):
    """Summarise the Google-model temperature / top_p sweep and save a two-panel figure.

    Prints the parameter combination(s) with the highest mean ``avg_L_score``,
    then plots ``avg_L_score`` against temperature and against top_p for the
    rows with prompt ``SLTPvB_long.yaml`` and double OCR enabled.

    Args:
        file_path: Excel workbook holding the sweep summary.
        save_path: Destination of the PNG figure (written at 600 dpi).
    """
    df = pd.read_excel(file_path)

    # Grouping by the parameters and calculating the mean of avg_L_score for each group
    grouped = df.groupby(['v_prompt_version', 'v_double_ocr', 'temperature', 'top_p'])['avg_L_score'].mean().reset_index()

    # Finding the group with the highest average L score
    max_avg_L_score = grouped['avg_L_score'].max()
    best_group = grouped[grouped['avg_L_score'] == max_avg_L_score]

    print(best_group)

    ### Viz
    # Combine both conditions into one mask; indexing twice would apply a mask
    # built on the full frame to an already-filtered frame.
    is_long_prompt = df['v_prompt_version'] == 'SLTPvB_long.yaml'
    is_double_ocr = df['v_double_ocr'] == True  # noqa: E712 - element-wise comparison
    filtered_df = df[is_long_prompt & is_double_ocr]

    # Setting up the plotting
    plt.figure(figsize=(14, 6))

    # Plot 1: avg_L_score as a function of temperature for each top_p value
    plt.subplot(1, 2, 1)
    sns.lineplot(data=filtered_df, x='temperature', y='avg_L_score', hue='top_p', marker='o')
    plt.title('Average L Score by Temperature for each Top P')
    plt.xlabel('Temperature')
    plt.ylabel('Average L Score')
    plt.legend(title='Top P', bbox_to_anchor=(1.05, 1), loc='upper left')

    # Plot 2: avg_L_score as a function of top_p for each temperature value
    plt.subplot(1, 2, 2)
    sns.lineplot(data=filtered_df, x='top_p', y='avg_L_score', hue='temperature', marker='o')
    plt.title('Average L Score by Top P for each Temperature')
    plt.xlabel('Top P')
    plt.ylabel('Average L Score')
    plt.legend(title='Temperature', bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
|
98 |
+
|
99 |
+
if __name__ == "__main__":
    # Only the Google sweep is summarised when run as a script; the GPT-4
    # summary is kept here, disabled.
    # SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_GPT4_SHORT()
    SUMMARY_permute_llms_to_sweep_temperature_and_topP_for_google_SHORT()
|
102 |
+
|
103 |
+
|
vouchervision/vouchervision_test_all_options_recipes.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, inspect, sys, shutil
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
class AllOptions:
    """Lists of values for each ``a_*`` option (LLM, prompt version, OCR, LLaVA settings)."""

    a_llm = [
        # OpenAI
        "GPT 4 Turbo 1106-preview",
        "GPT 4 Turbo 0125-preview",
        "GPT 4",
        "GPT 4 32k",
        "GPT 3.5",
        "GPT 3.5 Instruct",
        # Azure OpenAI
        "Azure GPT 3.5",
        "Azure GPT 3.5 Instruct",
        "Azure GPT 4",
        "Azure GPT 4 Turbo 1106-preview",
        "Azure GPT 4 Turbo 0125-preview",
        "Azure GPT 4 32k",
        # Google
        "PaLM 2 text-bison@001",
        "PaLM 2 text-bison@002",
        "PaLM 2 text-unicorn@001",
        "Gemini Pro",
        # Mistral API
        "Mistral Small",
        "Mistral Medium",
        "Mistral Large",
        "Open Mixtral 8x7B",
        "Open Mistral 7B",
        # Local models
        "LOCAL Mixtral 8x7B Instruct v0.1",
        "LOCAL Mistral 7B Instruct v0.2",
        "LOCAL CPU Mistral 7B Instruct v0.2 GGUF",
    ]

    a_prompt_version = [
        "SLTPvA_long.yaml",
        "SLTPvA_medium.yaml",
        "SLTPvA_short.yaml",
        "SLTPvB_long.yaml",
        "SLTPvB_medium.yaml",
        "SLTPvB_short.yaml",
        "SLTPvB_minimal.yaml",
    ]

    a_LM2 = [False]  # [True, False]
    a_do_use_trOCR = [False]  # [True, False]
    a_trocr_path = ["microsoft/trocr-large-handwritten"]
    a_ocr_option = [
        "hand",
        "normal",
        "CRAFT",
        "LLaVA",
        ["hand", "CRAFT"],
        ["hand", "LLaVA"],
    ]
    a_llava_option = [
        "llava-v1.6-mistral-7b",
        "llava-v1.6-34b",
        "llava-v1.6-vicuna-13b",
        "llava-v1.6-vicuna-7b",
    ]
    a_llava_bit = ["full", "4bit"]
    a_double_ocr = [True, False]
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
class Options_permute_llms_to_investigate_determinism_at_restrictive_settings:
    """Option lists for the determinism recipe: several hosted LLMs, three prompts, double OCR on and off."""

    a_llm = [
        # Disabled OpenAI entries:
        # "GPT 4 Turbo 1106-preview",
        # "GPT 4 Turbo 0125-preview",
        # 'GPT 4',
        # # 'GPT 4 32k',
        # 'GPT 3.5 Turbo',
        # 'GPT 3.5 Instruct',

        "Azure GPT 3.5 Turbo",
        "Azure GPT 3.5 Instruct",
        "Azure GPT 4",
        "Azure GPT 4 Turbo 1106-preview",
        "Azure GPT 4 Turbo 0125-preview",
        # 'Azure GPT 4 32k',

        "PaLM 2 text-bison@001",
        "PaLM 2 text-bison@002",
        "PaLM 2 text-unicorn@001",
        "Gemini Pro",

        "Mistral Small",
        "Mistral Medium",
        "Mistral Large",
        # 'Open Mixtral 8x7B',
        "Open Mistral 7B",

        # Disabled local entries:
        # 'LOCAL Mixtral 8x7B Instruct v0.1',
        # 'LOCAL Mistral 7B Instruct v0.2',
        # 'LOCAL CPU Mistral 7B Instruct v0.2 GGUF',
    ]

    a_prompt_version = [
        # 'SLTPvA_long.yaml',
        # 'SLTPvA_short.yaml',
        "SLTPvB_long.yaml",
        "SLTPvB_short.yaml",
        "SLTPvB_minimal.yaml",
    ]
    a_double_ocr = [True, False]

    ### BELOW ARE STATIC
    a_LM2 = [False]
    # a_do_use_trOCR = [True, False]
    a_do_use_trOCR = [False]
    # a_trocr_path = ["microsoft/trocr-large-handwritten","microsoft/trocr-base-handwritten",]
    a_trocr_path = ["microsoft/trocr-large-handwritten"]
    a_ocr_option = ["hand"]
    a_llava_option = ["llava-v1.6-mistral-7b"]
    a_llava_bit = ["full"]
|
120 |
+
|
121 |
+
|
122 |
+
class Options_permute_llms_to_sweep_temperature_and_topP_for_GPT4_0125:
    """Option lists for the GPT-4 temperature / top_p sweep recipe.

    NOTE(review): the class name mentions the 0125 model, but the active
    entry is 'Azure GPT 4' (the 0125-preview line is commented as "test 1").
    """

    a_llm = [
        # 'Azure GPT 4 Turbo 0125-preview', #test 1
        "Azure GPT 4",
    ]

    a_prompt_version = [
        # 'SLTPvA_long.yaml',
        # 'SLTPvA_short.yaml',
        "SLTPvB_long.yaml",
        "SLTPvB_short.yaml",
        # 'SLTPvB_minimal.yaml',
    ]
    a_double_ocr = [True, False]

    ### BELOW ARE STATIC
    a_LM2 = [False]
    # a_do_use_trOCR = [True, False]
    a_do_use_trOCR = [False]
    # a_trocr_path = ["microsoft/trocr-large-handwritten","microsoft/trocr-base-handwritten",]
    a_trocr_path = ["microsoft/trocr-large-handwritten"]
    a_ocr_option = ["hand"]
    a_llava_option = ["llava-v1.6-mistral-7b"]
    a_llava_bit = ["full"]
|
146 |
+
|
147 |
+
|
148 |
+
class Options_permute_llms_to_sweep_temperature_and_topP_for_google:
    """Option lists for the Google-model (PaLM 2 / Gemini) temperature / top_p sweep recipe."""

    a_llm = [
        "PaLM 2 text-bison@001",
        "PaLM 2 text-bison@002",
        "Gemini Pro",
    ]

    a_prompt_version = [
        "SLTPvB_long.yaml",
        "SLTPvB_short.yaml",
    ]
    a_double_ocr = [True, False]

    ### BELOW ARE STATIC
    a_LM2 = [False]
    a_do_use_trOCR = [False]  # [True, False]
    a_trocr_path = ["microsoft/trocr-large-handwritten"]
    a_ocr_option = ["hand"]
    a_llava_option = ["llava-v1.6-mistral-7b"]
    a_llava_bit = ["full"]
|
168 |
+
|
169 |
+
if __name__ == '__main__':
|
170 |
+
pass
|