Spaces:
Running
Running
phyloforfun
committed on
Commit
·
9d06861
1
Parent(s):
4d14f52
Major update. Support for 15 LLMs, World Flora Online taxonomy validation, geolocation, 2 OCR methods, significant UI changes, stability improvements, consistent JSON parsing
Browse files
- app.py +6 -6
- custom_prompts/SLTPvA_long.yaml +22 -22
- custom_prompts/SLTPvA_medium.yaml +22 -22
- custom_prompts/SLTPvA_short.yaml +22 -22
- requirements.txt +0 -0
- vouchervision/LLM_GoogleGemini.py +18 -9
- vouchervision/LLM_GooglePalm2.py +18 -9
- vouchervision/LLM_MistralAI.py +20 -10
- vouchervision/LLM_OpenAI.py +19 -10
- vouchervision/LLM_local_MistralAI.py +21 -11
- vouchervision/LLM_local_cpu_MistralAI.py +20 -10
- vouchervision/directory_structure_VV.py +10 -0
- vouchervision/prompt_catalog.py +18 -1
- vouchervision/tool_wikipedia.py +581 -0
- vouchervision/utils_LLM.py +47 -3
- vouchervision/utils_LLM_JSON_validation.py +7 -5
- vouchervision/utils_VoucherVision.py +35 -19
- vouchervision/utils_hf.py +55 -37
- vouchervision/vouchervision_main.py +18 -10
app.py
CHANGED
@@ -27,7 +27,7 @@ st.set_page_config(layout="wide", page_icon='img/icon.ico', page_title='VoucherV
|
|
27 |
|
28 |
# Parse the 'is_hf' argument and set it in session state
|
29 |
if 'is_hf' not in st.session_state:
|
30 |
-
st.session_state['is_hf'] =
|
31 |
|
32 |
|
33 |
########################################################################################################
|
@@ -141,7 +141,8 @@ def content_input_images(col_left, col_right):
|
|
141 |
pass
|
142 |
elif not st.session_state['view_local_gallery'] and not st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] == st.session_state.config['leafmachine']['project']['dir_images_local']):
|
143 |
pass
|
144 |
-
elif st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] != st.session_state.config['leafmachine']['project']['dir_images_local']):
|
|
|
145 |
dir_images_local = st.session_state.config['leafmachine']['project']['dir_images_local']
|
146 |
count_n_imgs = list_jpg_files(dir_images_local)
|
147 |
st.session_state['processing_add_on'] = count_n_imgs
|
@@ -1012,8 +1013,6 @@ def save_prompt_yaml(filename, col):
|
|
1012 |
|
1013 |
st.success(f"Prompt saved as '{filename}.yaml'.")
|
1014 |
|
1015 |
-
upload_to_drive(filepath, filename) # added
|
1016 |
-
|
1017 |
with col: # added
|
1018 |
create_download_button_yaml(filepath, filename,key_val=2456237465) # added
|
1019 |
|
@@ -1363,7 +1362,7 @@ def build_LLM_prompt_config():
|
|
1363 |
# This assumes that the column names are the keys in the dictionary under 'rules'
|
1364 |
all_column_names = list(st.session_state['rules'].keys())
|
1365 |
|
1366 |
-
categories = ['TAXONOMY', 'GEOGRAPHY', 'LOCALITY', 'COLLECTING', '
|
1367 |
if ('mapping' not in st.session_state) or (st.session_state['mapping'] == {}):
|
1368 |
st.session_state['mapping'] = {category: [] for category in categories}
|
1369 |
for category in categories:
|
@@ -1751,6 +1750,7 @@ def content_header():
|
|
1751 |
path_api_cost=os.path.join(st.session_state.dir_home,'api_cost','api_cost.yaml'),
|
1752 |
is_hf = st.session_state['is_hf'],
|
1753 |
is_real_run=True)
|
|
|
1754 |
st.balloons()
|
1755 |
except Exception as e:
|
1756 |
with col_run_4:
|
@@ -2020,7 +2020,7 @@ def content_collage_overlay():
|
|
2020 |
with col_collage:
|
2021 |
st.header('LeafMachine2 Label Collage')
|
2022 |
default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
|
2023 |
-
st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image.
|
2024 |
st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox("Use LeafMachine2 label collage for transcriptions", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))
|
2025 |
|
2026 |
|
|
|
27 |
|
28 |
# Parse the 'is_hf' argument and set it in session state
|
29 |
if 'is_hf' not in st.session_state:
|
30 |
+
st.session_state['is_hf'] = False
|
31 |
|
32 |
|
33 |
########################################################################################################
|
|
|
141 |
pass
|
142 |
elif not st.session_state['view_local_gallery'] and not st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] == st.session_state.config['leafmachine']['project']['dir_images_local']):
|
143 |
pass
|
144 |
+
# elif st.session_state['input_list_small'] and (st.session_state['dir_images_local_TEMP'] != st.session_state.config['leafmachine']['project']['dir_images_local']):
|
145 |
+
elif (st.session_state['dir_images_local_TEMP'] != st.session_state.config['leafmachine']['project']['dir_images_local']):
|
146 |
dir_images_local = st.session_state.config['leafmachine']['project']['dir_images_local']
|
147 |
count_n_imgs = list_jpg_files(dir_images_local)
|
148 |
st.session_state['processing_add_on'] = count_n_imgs
|
|
|
1013 |
|
1014 |
st.success(f"Prompt saved as '{filename}.yaml'.")
|
1015 |
|
|
|
|
|
1016 |
with col: # added
|
1017 |
create_download_button_yaml(filepath, filename,key_val=2456237465) # added
|
1018 |
|
|
|
1362 |
# This assumes that the column names are the keys in the dictionary under 'rules'
|
1363 |
all_column_names = list(st.session_state['rules'].keys())
|
1364 |
|
1365 |
+
categories = ['TAXONOMY', 'GEOGRAPHY', 'LOCALITY', 'COLLECTING', 'MISC']
|
1366 |
if ('mapping' not in st.session_state) or (st.session_state['mapping'] == {}):
|
1367 |
st.session_state['mapping'] = {category: [] for category in categories}
|
1368 |
for category in categories:
|
|
|
1750 |
path_api_cost=os.path.join(st.session_state.dir_home,'api_cost','api_cost.yaml'),
|
1751 |
is_hf = st.session_state['is_hf'],
|
1752 |
is_real_run=True)
|
1753 |
+
|
1754 |
st.balloons()
|
1755 |
except Exception as e:
|
1756 |
with col_run_4:
|
|
|
2020 |
with col_collage:
|
2021 |
st.header('LeafMachine2 Label Collage')
|
2022 |
default_crops = st.session_state.config['leafmachine']['cropped_components']['save_cropped_annotations']
|
2023 |
+
st.write("Prior to transcription, use LeafMachine2 to crop all labels from input images to create label collages for each specimen image. Showing just the text labels to the OCR algorithms significantly improves performance. This runs slowly on the free Hugging Face Space, but runs quickly with a fast CPU or any GPU.")
|
2024 |
st.session_state.config['leafmachine']['use_RGB_label_images'] = st.checkbox("Use LeafMachine2 label collage for transcriptions", st.session_state.config['leafmachine'].get('use_RGB_label_images', False))
|
2025 |
|
2026 |
|
custom_prompts/SLTPvA_long.yaml
CHANGED
@@ -78,35 +78,35 @@ rules:
|
|
78 |
are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
|
79 |
or "m." or "meters"). Round to integer.
|
80 |
mapping:
|
81 |
-
|
82 |
-
-
|
83 |
-
-
|
84 |
-
-
|
85 |
-
-
|
86 |
-
-
|
|
|
|
|
|
|
|
|
87 |
GEOGRAPHY:
|
88 |
- country
|
89 |
- stateProvince
|
90 |
- county
|
91 |
- municipality
|
92 |
-
- minimumElevationInMeters
|
93 |
-
- maximumElevationInMeters
|
94 |
-
LOCALITY:
|
95 |
-
- locality
|
96 |
-
- habitat
|
97 |
- decimalLatitude
|
98 |
- decimalLongitude
|
99 |
- verbatimCoordinates
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
- degreeOfEstablishment
|
102 |
- occurrenceRemarks
|
103 |
-
|
104 |
-
- catalogNumber
|
105 |
-
- order
|
106 |
-
- family
|
107 |
-
- scientificName
|
108 |
-
- scientificNameAuthorship
|
109 |
-
- genus
|
110 |
-
- subgenus
|
111 |
-
- specificEpithet
|
112 |
-
- infraspecificEpithet
|
|
|
78 |
are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m"
|
79 |
or "m." or "meters"). Round to integer.
|
80 |
mapping:
|
81 |
+
TAXONOMY:
|
82 |
+
- catalogNumber
|
83 |
+
- order
|
84 |
+
- family
|
85 |
+
- scientificName
|
86 |
+
- scientificNameAuthorship
|
87 |
+
- genus
|
88 |
+
- subgenus
|
89 |
+
- specificEpithet
|
90 |
+
- infraspecificEpithet
|
91 |
GEOGRAPHY:
|
92 |
- country
|
93 |
- stateProvince
|
94 |
- county
|
95 |
- municipality
|
|
|
|
|
|
|
|
|
|
|
96 |
- decimalLatitude
|
97 |
- decimalLongitude
|
98 |
- verbatimCoordinates
|
99 |
+
LOCALITY:
|
100 |
+
- locality
|
101 |
+
- habitat
|
102 |
+
- minimumElevationInMeters
|
103 |
+
- maximumElevationInMeters
|
104 |
+
COLLECTING:
|
105 |
+
- identifiedBy
|
106 |
+
- recordedBy
|
107 |
+
- recordNumber
|
108 |
+
- verbatimEventDate
|
109 |
+
- eventDate
|
110 |
- degreeOfEstablishment
|
111 |
- occurrenceRemarks
|
112 |
+
MISC:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
custom_prompts/SLTPvA_medium.yaml
CHANGED
@@ -53,35 +53,35 @@ rules:
|
|
53 |
minimumElevationInMeters: Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ("ft" or "ft."" or "feet") to meters ("m" or "m." or "meters"). Round to integer.
|
54 |
maximumElevationInMeters: Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m" or "m." or "meters"). Round to integer.
|
55 |
mapping:
|
56 |
-
|
57 |
-
-
|
58 |
-
-
|
59 |
-
-
|
60 |
-
-
|
61 |
-
-
|
|
|
|
|
|
|
|
|
62 |
GEOGRAPHY:
|
63 |
- country
|
64 |
- stateProvince
|
65 |
- county
|
66 |
- municipality
|
67 |
-
- minimumElevationInMeters
|
68 |
-
- maximumElevationInMeters
|
69 |
-
LOCALITY:
|
70 |
-
- locality
|
71 |
-
- habitat
|
72 |
- decimalLatitude
|
73 |
- decimalLongitude
|
74 |
- verbatimCoordinates
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
- degreeOfEstablishment
|
77 |
- occurrenceRemarks
|
78 |
-
|
79 |
-
- catalogNumber
|
80 |
-
- order
|
81 |
-
- family
|
82 |
-
- scientificName
|
83 |
-
- scientificNameAuthorship
|
84 |
-
- genus
|
85 |
-
- subgenus
|
86 |
-
- specificEpithet
|
87 |
-
- infraspecificEpithet
|
|
|
53 |
minimumElevationInMeters: Minimum elevation or altitude in meters. Only if units are explicit then convert from feet ("ft" or "ft."" or "feet") to meters ("m" or "m." or "meters"). Round to integer.
|
54 |
maximumElevationInMeters: Maximum elevation or altitude in meters. If only one elevation is present, then max_elevation should be set to the null_value. Only if units are explicit then convert from feet ("ft" or "ft." or "feet") to meters ("m" or "m." or "meters"). Round to integer.
|
55 |
mapping:
|
56 |
+
TAXONOMY:
|
57 |
+
- catalogNumber
|
58 |
+
- order
|
59 |
+
- family
|
60 |
+
- scientificName
|
61 |
+
- scientificNameAuthorship
|
62 |
+
- genus
|
63 |
+
- subgenus
|
64 |
+
- specificEpithet
|
65 |
+
- infraspecificEpithet
|
66 |
GEOGRAPHY:
|
67 |
- country
|
68 |
- stateProvince
|
69 |
- county
|
70 |
- municipality
|
|
|
|
|
|
|
|
|
|
|
71 |
- decimalLatitude
|
72 |
- decimalLongitude
|
73 |
- verbatimCoordinates
|
74 |
+
LOCALITY:
|
75 |
+
- locality
|
76 |
+
- habitat
|
77 |
+
- minimumElevationInMeters
|
78 |
+
- maximumElevationInMeters
|
79 |
+
COLLECTING:
|
80 |
+
- identifiedBy
|
81 |
+
- recordedBy
|
82 |
+
- recordNumber
|
83 |
+
- verbatimEventDate
|
84 |
+
- eventDate
|
85 |
- degreeOfEstablishment
|
86 |
- occurrenceRemarks
|
87 |
+
MISC:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
custom_prompts/SLTPvA_short.yaml
CHANGED
@@ -48,35 +48,35 @@ rules:
|
|
48 |
minimumElevationInMeters: minimum elevation or altitude in meters.
|
49 |
maximumElevationInMeters: maximum elevation or altitude in meters.
|
50 |
mapping:
|
51 |
-
|
52 |
-
-
|
53 |
-
-
|
54 |
-
-
|
55 |
-
-
|
56 |
-
-
|
|
|
|
|
|
|
|
|
57 |
GEOGRAPHY:
|
58 |
- country
|
59 |
- stateProvince
|
60 |
- county
|
61 |
- municipality
|
62 |
-
- minimumElevationInMeters
|
63 |
-
- maximumElevationInMeters
|
64 |
-
LOCALITY:
|
65 |
-
- locality
|
66 |
-
- habitat
|
67 |
- decimalLatitude
|
68 |
- decimalLongitude
|
69 |
- verbatimCoordinates
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
- degreeOfEstablishment
|
72 |
- occurrenceRemarks
|
73 |
-
|
74 |
-
- catalogNumber
|
75 |
-
- order
|
76 |
-
- family
|
77 |
-
- scientificName
|
78 |
-
- scientificNameAuthorship
|
79 |
-
- genus
|
80 |
-
- subgenus
|
81 |
-
- specificEpithet
|
82 |
-
- infraspecificEpithet
|
|
|
48 |
minimumElevationInMeters: minimum elevation or altitude in meters.
|
49 |
maximumElevationInMeters: maximum elevation or altitude in meters.
|
50 |
mapping:
|
51 |
+
TAXONOMY:
|
52 |
+
- catalogNumber
|
53 |
+
- order
|
54 |
+
- family
|
55 |
+
- scientificName
|
56 |
+
- scientificNameAuthorship
|
57 |
+
- genus
|
58 |
+
- subgenus
|
59 |
+
- specificEpithet
|
60 |
+
- infraspecificEpithet
|
61 |
GEOGRAPHY:
|
62 |
- country
|
63 |
- stateProvince
|
64 |
- county
|
65 |
- municipality
|
|
|
|
|
|
|
|
|
|
|
66 |
- decimalLatitude
|
67 |
- decimalLongitude
|
68 |
- verbatimCoordinates
|
69 |
+
LOCALITY:
|
70 |
+
- locality
|
71 |
+
- habitat
|
72 |
+
- minimumElevationInMeters
|
73 |
+
- maximumElevationInMeters
|
74 |
+
COLLECTING:
|
75 |
+
- identifiedBy
|
76 |
+
- recordedBy
|
77 |
+
- recordNumber
|
78 |
+
- verbatimEventDate
|
79 |
+
- eventDate
|
80 |
- degreeOfEstablishment
|
81 |
- occurrenceRemarks
|
82 |
+
MISC:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
vouchervision/LLM_GoogleGemini.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os, time
|
2 |
import vertexai
|
3 |
from vertexai.preview.generative_models import GenerativeModel
|
4 |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
@@ -9,10 +9,11 @@ from langchain_core.output_parsers import JsonOutputParser
|
|
9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
from langchain_google_vertexai import VertexAI
|
11 |
|
12 |
-
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
13 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
14 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
15 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
|
|
16 |
|
17 |
class GoogleGeminiHandler:
|
18 |
|
@@ -95,7 +96,8 @@ class GoogleGeminiHandler:
|
|
95 |
safety_settings=self.safety_settings)
|
96 |
return response.text
|
97 |
|
98 |
-
def call_llm_api_GoogleGemini(self, prompt_template, json_report):
|
|
|
99 |
self.json_report = json_report
|
100 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
101 |
self.monitor.start_monitoring_usage()
|
@@ -125,19 +127,26 @@ class GoogleGeminiHandler:
|
|
125 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
126 |
self._adjust_config()
|
127 |
else:
|
128 |
-
|
|
|
|
|
129 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
130 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
131 |
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
-
self.monitor.stop_monitoring_report_usage()
|
135 |
|
136 |
if self.adjust_temp != self.starting_temp:
|
137 |
self._reset_config()
|
138 |
|
139 |
json_report.set_text(text_main=f'LLM call successful')
|
140 |
-
return output, nt_in, nt_out, WFO_record, GEO_record
|
141 |
|
142 |
except Exception as e:
|
143 |
self.logger.error(f'{e}')
|
@@ -148,10 +157,10 @@ class GoogleGeminiHandler:
|
|
148 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
149 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
150 |
|
151 |
-
self.monitor.stop_monitoring_report_usage()
|
152 |
self._reset_config()
|
153 |
|
154 |
json_report.set_text(text_main=f'LLM call failed')
|
155 |
-
return None, nt_in, nt_out, None, None
|
156 |
|
157 |
|
|
|
1 |
+
import os, time, json
|
2 |
import vertexai
|
3 |
from vertexai.preview.generative_models import GenerativeModel
|
4 |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
|
|
9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
from langchain_google_vertexai import VertexAI
|
11 |
|
12 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
|
13 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
14 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
15 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
16 |
+
from vouchervision.tool_wikipedia import WikipediaLinks
|
17 |
|
18 |
class GoogleGeminiHandler:
|
19 |
|
|
|
96 |
safety_settings=self.safety_settings)
|
97 |
return response.text
|
98 |
|
99 |
+
def call_llm_api_GoogleGemini(self, prompt_template, json_report, paths):
|
100 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
101 |
self.json_report = json_report
|
102 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
103 |
self.monitor.start_monitoring_usage()
|
|
|
127 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
128 |
self._adjust_config()
|
129 |
else:
|
130 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
131 |
+
|
132 |
+
json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
133 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
134 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
135 |
|
136 |
+
Wiki = WikipediaLinks(json_file_path_wiki)
|
137 |
+
Wiki.gather_wikipedia_results(output)
|
138 |
+
|
139 |
+
save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
|
140 |
+
|
141 |
+
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
142 |
|
143 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
144 |
|
145 |
if self.adjust_temp != self.starting_temp:
|
146 |
self._reset_config()
|
147 |
|
148 |
json_report.set_text(text_main=f'LLM call successful')
|
149 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
150 |
|
151 |
except Exception as e:
|
152 |
self.logger.error(f'{e}')
|
|
|
157 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
158 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
159 |
|
160 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
161 |
self._reset_config()
|
162 |
|
163 |
json_report.set_text(text_main=f'LLM call failed')
|
164 |
+
return None, nt_in, nt_out, None, None, usage_report
|
165 |
|
166 |
|
vouchervision/LLM_GooglePalm2.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os, time
|
2 |
import vertexai
|
3 |
from vertexai.language_models import TextGenerationModel
|
4 |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
@@ -11,10 +11,11 @@ from langchain_core.output_parsers import JsonOutputParser
|
|
11 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
12 |
from langchain_google_vertexai import VertexAI
|
13 |
|
14 |
-
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
15 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
16 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
17 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
|
|
18 |
|
19 |
#https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk
|
20 |
#pip install --upgrade google-cloud-aiplatform
|
@@ -109,7 +110,8 @@ class GooglePalm2Handler:
|
|
109 |
return response.text
|
110 |
|
111 |
|
112 |
-
def call_llm_api_GooglePalm2(self, prompt_template, json_report):
|
|
|
113 |
self.json_report = json_report
|
114 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
115 |
self.monitor.start_monitoring_usage()
|
@@ -139,19 +141,26 @@ class GooglePalm2Handler:
|
|
139 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
140 |
self._adjust_config()
|
141 |
else:
|
142 |
-
|
|
|
|
|
143 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
144 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
145 |
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
-
self.monitor.stop_monitoring_report_usage()
|
149 |
|
150 |
if self.adjust_temp != self.starting_temp:
|
151 |
self._reset_config()
|
152 |
|
153 |
json_report.set_text(text_main=f'LLM call successful')
|
154 |
-
return output, nt_in, nt_out, WFO_record, GEO_record
|
155 |
|
156 |
except Exception as e:
|
157 |
self.logger.error(f'{e}')
|
@@ -162,8 +171,8 @@ class GooglePalm2Handler:
|
|
162 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
163 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
164 |
|
165 |
-
self.monitor.stop_monitoring_report_usage()
|
166 |
self._reset_config()
|
167 |
|
168 |
json_report.set_text(text_main=f'LLM call failed')
|
169 |
-
return None, nt_in, nt_out, None, None
|
|
|
1 |
+
import os, time, json
|
2 |
import vertexai
|
3 |
from vertexai.language_models import TextGenerationModel
|
4 |
from vertexai.generative_models._generative_models import HarmCategory, HarmBlockThreshold
|
|
|
11 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
12 |
from langchain_google_vertexai import VertexAI
|
13 |
|
14 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
|
15 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
16 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
17 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
18 |
+
from vouchervision.tool_wikipedia import WikipediaLinks
|
19 |
|
20 |
#https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk
|
21 |
#pip install --upgrade google-cloud-aiplatform
|
|
|
110 |
return response.text
|
111 |
|
112 |
|
113 |
+
def call_llm_api_GooglePalm2(self, prompt_template, json_report, paths):
|
114 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
115 |
self.json_report = json_report
|
116 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
117 |
self.monitor.start_monitoring_usage()
|
|
|
141 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
142 |
self._adjust_config()
|
143 |
else:
|
144 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
145 |
+
|
146 |
+
json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
147 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
148 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
149 |
|
150 |
+
Wiki = WikipediaLinks(json_file_path_wiki)
|
151 |
+
Wiki.gather_wikipedia_results(output)
|
152 |
+
|
153 |
+
save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
|
154 |
+
|
155 |
+
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
156 |
|
157 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
158 |
|
159 |
if self.adjust_temp != self.starting_temp:
|
160 |
self._reset_config()
|
161 |
|
162 |
json_report.set_text(text_main=f'LLM call successful')
|
163 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
164 |
|
165 |
except Exception as e:
|
166 |
self.logger.error(f'{e}')
|
|
|
171 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
172 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
173 |
|
174 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
175 |
self._reset_config()
|
176 |
|
177 |
json_report.set_text(text_main=f'LLM call failed')
|
178 |
+
return None, nt_in, nt_out, None, None, usage_report
|
vouchervision/LLM_MistralAI.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
-
import os, time, random, torch
|
2 |
from langchain_mistralai.chat_models import ChatMistralAI
|
3 |
from langchain.output_parsers import RetryWithErrorOutputParser
|
4 |
from langchain.prompts import PromptTemplate
|
5 |
from langchain_core.output_parsers import JsonOutputParser
|
6 |
|
7 |
-
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
8 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
9 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
10 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
|
|
11 |
|
12 |
|
13 |
class MistralHandler:
|
@@ -78,7 +79,9 @@ class MistralHandler:
|
|
78 |
|
79 |
self.chain = self.prompt | self.llm_model
|
80 |
|
81 |
-
def call_llm_api_MistralAI(self, prompt_template, json_report):
|
|
|
|
|
82 |
self.json_report = json_report
|
83 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
84 |
self.monitor.start_monitoring_usage()
|
@@ -109,22 +112,29 @@ class MistralHandler:
|
|
109 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
110 |
self._adjust_config()
|
111 |
else:
|
112 |
-
|
|
|
|
|
113 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
114 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
115 |
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
self.monitor.stop_monitoring_report_usage()
|
119 |
|
120 |
if self.adjust_temp != self.starting_temp:
|
121 |
self._reset_config()
|
122 |
|
123 |
json_report.set_text(text_main=f'LLM call successful')
|
124 |
-
return output, nt_in, nt_out, WFO_record, GEO_record
|
125 |
|
126 |
except Exception as e:
|
127 |
-
self.logger.error(f'{e}')
|
128 |
|
129 |
self._adjust_config()
|
130 |
time.sleep(self.RETRY_DELAY)
|
@@ -132,8 +142,8 @@ class MistralHandler:
|
|
132 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
133 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
134 |
|
135 |
-
self.monitor.stop_monitoring_report_usage()
|
136 |
self._reset_config()
|
137 |
json_report.set_text(text_main=f'LLM call failed')
|
138 |
|
139 |
-
return None, nt_in, nt_out, None, None
|
|
|
1 |
+
import os, time, random, torch, json
|
2 |
from langchain_mistralai.chat_models import ChatMistralAI
|
3 |
from langchain.output_parsers import RetryWithErrorOutputParser
|
4 |
from langchain.prompts import PromptTemplate
|
5 |
from langchain_core.output_parsers import JsonOutputParser
|
6 |
|
7 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
|
8 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
9 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
10 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
11 |
+
from vouchervision.tool_wikipedia import WikipediaLinks
|
12 |
|
13 |
|
14 |
class MistralHandler:
|
|
|
79 |
|
80 |
self.chain = self.prompt | self.llm_model
|
81 |
|
82 |
+
def call_llm_api_MistralAI(self, prompt_template, json_report, paths):
|
83 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
84 |
+
|
85 |
self.json_report = json_report
|
86 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
87 |
self.monitor.start_monitoring_usage()
|
|
|
112 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response}')
|
113 |
self._adjust_config()
|
114 |
else:
|
115 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
116 |
+
|
117 |
+
json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
118 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
119 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
120 |
|
121 |
+
Wiki = WikipediaLinks(json_file_path_wiki)
|
122 |
+
Wiki.gather_wikipedia_results(output)
|
123 |
+
|
124 |
+
save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
|
125 |
+
|
126 |
+
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
127 |
|
128 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
129 |
|
130 |
if self.adjust_temp != self.starting_temp:
|
131 |
self._reset_config()
|
132 |
|
133 |
json_report.set_text(text_main=f'LLM call successful')
|
134 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
135 |
|
136 |
except Exception as e:
|
137 |
+
self.logger.error(f'JSON Parsing Error (LangChain): {e}')
|
138 |
|
139 |
self._adjust_config()
|
140 |
time.sleep(self.RETRY_DELAY)
|
|
|
142 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
143 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
144 |
|
145 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
146 |
self._reset_config()
|
147 |
json_report.set_text(text_main=f'LLM call failed')
|
148 |
|
149 |
+
return None, nt_in, nt_out, None, None, usage_report
|
vouchervision/LLM_OpenAI.py
CHANGED
@@ -1,14 +1,15 @@
|
|
1 |
-
import time, torch
|
2 |
from langchain.prompts import PromptTemplate
|
3 |
from langchain_openai import ChatOpenAI, OpenAI
|
4 |
from langchain.schema import HumanMessage
|
5 |
from langchain_core.output_parsers import JsonOutputParser
|
6 |
from langchain.output_parsers import RetryWithErrorOutputParser
|
7 |
|
8 |
-
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
9 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
10 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
11 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
|
|
12 |
|
13 |
class OpenAIHandler:
|
14 |
RETRY_DELAY = 10 # Wait 10 seconds before retrying
|
@@ -98,7 +99,8 @@ class OpenAIHandler:
|
|
98 |
self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else ChatOpenAI(model=self.model_name))
|
99 |
|
100 |
|
101 |
-
def call_llm_api_OpenAI(self, prompt_template, json_report):
|
|
|
102 |
self.json_report = json_report
|
103 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
104 |
self.monitor.start_monitoring_usage()
|
@@ -130,19 +132,26 @@ class OpenAIHandler:
|
|
130 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
|
131 |
self._adjust_config()
|
132 |
else:
|
133 |
-
|
134 |
|
|
|
135 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
136 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
137 |
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
-
self.monitor.stop_monitoring_report_usage()
|
141 |
-
|
142 |
if self.adjust_temp != self.starting_temp:
|
143 |
self._reset_config()
|
|
|
144 |
json_report.set_text(text_main=f'LLM call successful')
|
145 |
-
return output, nt_in, nt_out, WFO_record, GEO_record
|
146 |
|
147 |
except Exception as e:
|
148 |
self.logger.error(f'{e}')
|
@@ -153,8 +162,8 @@ class OpenAIHandler:
|
|
153 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
154 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
155 |
|
156 |
-
self.monitor.stop_monitoring_report_usage()
|
157 |
self._reset_config()
|
158 |
|
159 |
json_report.set_text(text_main=f'LLM call failed')
|
160 |
-
return None, nt_in, nt_out, None, None
|
|
|
1 |
+
import time, torch, json
|
2 |
from langchain.prompts import PromptTemplate
|
3 |
from langchain_openai import ChatOpenAI, OpenAI
|
4 |
from langchain.schema import HumanMessage
|
5 |
from langchain_core.output_parsers import JsonOutputParser
|
6 |
from langchain.output_parsers import RetryWithErrorOutputParser
|
7 |
|
8 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
|
9 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
10 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
11 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
12 |
+
from vouchervision.tool_wikipedia import WikipediaLinks
|
13 |
|
14 |
class OpenAIHandler:
|
15 |
RETRY_DELAY = 10 # Wait 10 seconds before retrying
|
|
|
99 |
self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else ChatOpenAI(model=self.model_name))
|
100 |
|
101 |
|
102 |
+
def call_llm_api_OpenAI(self, prompt_template, json_report, paths):
|
103 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
104 |
self.json_report = json_report
|
105 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
106 |
self.monitor.start_monitoring_usage()
|
|
|
132 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
|
133 |
self._adjust_config()
|
134 |
else:
|
135 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
136 |
|
137 |
+
json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
138 |
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
139 |
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
140 |
|
141 |
+
Wiki = WikipediaLinks(json_file_path_wiki)
|
142 |
+
Wiki.gather_wikipedia_results(output)
|
143 |
+
|
144 |
+
save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
|
145 |
+
|
146 |
+
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
147 |
+
|
148 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
149 |
|
|
|
|
|
150 |
if self.adjust_temp != self.starting_temp:
|
151 |
self._reset_config()
|
152 |
+
|
153 |
json_report.set_text(text_main=f'LLM call successful')
|
154 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
155 |
|
156 |
except Exception as e:
|
157 |
self.logger.error(f'{e}')
|
|
|
162 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
163 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
164 |
|
165 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
166 |
self._reset_config()
|
167 |
|
168 |
json_report.set_text(text_main=f'LLM call failed')
|
169 |
+
return None, nt_in, nt_out, None, None, usage_report
|
vouchervision/LLM_local_MistralAI.py
CHANGED
@@ -6,10 +6,11 @@ from langchain_core.output_parsers import JsonOutputParser
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
8 |
|
9 |
-
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
10 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
11 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
12 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
|
|
13 |
|
14 |
'''
|
15 |
Local Pipelines:
|
@@ -147,7 +148,8 @@ class LocalMistralHandler:
|
|
147 |
self.chain = self.prompt | self.local_model # LCEL
|
148 |
|
149 |
|
150 |
-
def call_llm_local_MistralAI(self, prompt_template, json_report):
|
|
|
151 |
self.json_report = json_report
|
152 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
153 |
self.monitor.start_monitoring_usage()
|
@@ -183,20 +185,28 @@ class LocalMistralHandler:
|
|
183 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
|
184 |
self._adjust_config()
|
185 |
else:
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
191 |
|
192 |
-
self.monitor.stop_monitoring_report_usage()
|
193 |
|
194 |
-
if self.adjust_temp != self.starting_temp:
|
195 |
self._reset_config()
|
196 |
-
|
197 |
json_report.set_text(text_main=f'LLM call successful')
|
198 |
del results
|
199 |
-
return output, nt_in, nt_out, WFO_record, GEO_record
|
|
|
200 |
except Exception as e:
|
201 |
self.logger.error(f'{e}')
|
202 |
self._adjust_config()
|
@@ -204,9 +214,9 @@ class LocalMistralHandler:
|
|
204 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
205 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
206 |
|
207 |
-
self.monitor.stop_monitoring_report_usage()
|
208 |
json_report.set_text(text_main=f'LLM call failed')
|
209 |
|
210 |
self._reset_config()
|
211 |
-
return None, nt_in, nt_out, None, None
|
212 |
|
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
|
8 |
|
9 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
|
10 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
11 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
12 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
13 |
+
from vouchervision.tool_wikipedia import WikipediaLinks
|
14 |
|
15 |
'''
|
16 |
Local Pipelines:
|
|
|
148 |
self.chain = self.prompt | self.local_model # LCEL
|
149 |
|
150 |
|
151 |
+
def call_llm_local_MistralAI(self, prompt_template, json_report, paths):
|
152 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
153 |
self.json_report = json_report
|
154 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
155 |
self.monitor.start_monitoring_usage()
|
|
|
185 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
|
186 |
self._adjust_config()
|
187 |
else:
|
188 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
189 |
+
|
190 |
+
json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
191 |
+
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
192 |
+
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
193 |
+
|
194 |
+
Wiki = WikipediaLinks(json_file_path_wiki)
|
195 |
+
Wiki.gather_wikipedia_results(output)
|
196 |
+
|
197 |
+
save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
|
198 |
|
199 |
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
200 |
|
201 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
202 |
|
203 |
+
if self.adjust_temp != self.starting_temp:
|
204 |
self._reset_config()
|
205 |
+
|
206 |
json_report.set_text(text_main=f'LLM call successful')
|
207 |
del results
|
208 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
209 |
+
|
210 |
except Exception as e:
|
211 |
self.logger.error(f'{e}')
|
212 |
self._adjust_config()
|
|
|
214 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
215 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
216 |
|
217 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
218 |
json_report.set_text(text_main=f'LLM call failed')
|
219 |
|
220 |
self._reset_config()
|
221 |
+
return None, nt_in, nt_out, None, None, usage_report
|
222 |
|
vouchervision/LLM_local_cpu_MistralAI.py
CHANGED
@@ -18,10 +18,11 @@ from langchain.callbacks.base import BaseCallbackHandler
|
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
|
20 |
|
21 |
-
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens
|
22 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
23 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
24 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
|
|
25 |
|
26 |
class LocalCPUMistralHandler:
|
27 |
RETRY_DELAY = 2 # Wait 2 seconds before retrying
|
@@ -136,7 +137,8 @@ class LocalCPUMistralHandler:
|
|
136 |
self.chain = self.prompt | self.local_model
|
137 |
|
138 |
|
139 |
-
def call_llm_local_cpu_MistralAI(self, prompt_template, json_report):
|
|
|
140 |
self.json_report = json_report
|
141 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
142 |
self.monitor.start_monitoring_usage()
|
@@ -176,18 +178,26 @@ class LocalCPUMistralHandler:
|
|
176 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
|
177 |
self._adjust_config()
|
178 |
else:
|
179 |
-
|
180 |
-
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
184 |
|
185 |
-
self.monitor.stop_monitoring_report_usage()
|
186 |
|
187 |
-
if self.adjust_temp != self.starting_temp:
|
188 |
self._reset_config()
|
|
|
189 |
json_report.set_text(text_main=f'LLM call successful')
|
190 |
-
return output, nt_in, nt_out, WFO_record, GEO_record
|
191 |
|
192 |
except Exception as e:
|
193 |
self.logger.error(f'{e}')
|
@@ -196,10 +206,10 @@ class LocalCPUMistralHandler:
|
|
196 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
197 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
198 |
|
199 |
-
self.monitor.stop_monitoring_report_usage()
|
200 |
self._reset_config()
|
201 |
|
202 |
json_report.set_text(text_main=f'LLM call failed')
|
203 |
-
return None, nt_in, nt_out, None, None
|
204 |
|
205 |
|
|
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
|
20 |
|
21 |
+
from vouchervision.utils_LLM import SystemLoadMonitor, count_tokens, save_individual_prompt
|
22 |
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template
|
23 |
from vouchervision.utils_taxonomy_WFO import validate_taxonomy_WFO
|
24 |
from vouchervision.utils_geolocate_HERE import validate_coordinates_here
|
25 |
+
from vouchervision.tool_wikipedia import WikipediaLinks
|
26 |
|
27 |
class LocalCPUMistralHandler:
|
28 |
RETRY_DELAY = 2 # Wait 2 seconds before retrying
|
|
|
137 |
self.chain = self.prompt | self.local_model
|
138 |
|
139 |
|
140 |
+
def call_llm_local_cpu_MistralAI(self, prompt_template, json_report, paths):
|
141 |
+
_____, ____, _, __, ___, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
142 |
self.json_report = json_report
|
143 |
self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
|
144 |
self.monitor.start_monitoring_usage()
|
|
|
178 |
self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{results}')
|
179 |
self._adjust_config()
|
180 |
else:
|
181 |
+
self.monitor.stop_inference_timer() # Starts tool timer too
|
182 |
+
|
183 |
+
json_report.set_text(text_main=f'Working on WFO, Geolocation, Links')
|
184 |
+
output, WFO_record = validate_taxonomy_WFO(output, replace_if_success_wfo=False) ###################################### make this configurable
|
185 |
+
output, GEO_record = validate_coordinates_here(output, replace_if_success_geo=False) ###################################### make this configurable
|
186 |
+
|
187 |
+
Wiki = WikipediaLinks(json_file_path_wiki)
|
188 |
+
Wiki.gather_wikipedia_results(output)
|
189 |
+
|
190 |
+
save_individual_prompt(Wiki.sanitize(prompt_template), txt_file_path_ind_prompt)
|
191 |
|
192 |
self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")
|
193 |
|
194 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
195 |
|
196 |
+
if self.adjust_temp != self.starting_temp:
|
197 |
self._reset_config()
|
198 |
+
|
199 |
json_report.set_text(text_main=f'LLM call successful')
|
200 |
+
return output, nt_in, nt_out, WFO_record, GEO_record, usage_report
|
201 |
|
202 |
except Exception as e:
|
203 |
self.logger.error(f'{e}')
|
|
|
206 |
self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
|
207 |
self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')
|
208 |
|
209 |
+
usage_report = self.monitor.stop_monitoring_report_usage()
|
210 |
self._reset_config()
|
211 |
|
212 |
json_report.set_text(text_main=f'LLM call failed')
|
213 |
+
return None, nt_in, nt_out, None, None, usage_report
|
214 |
|
215 |
|
vouchervision/directory_structure_VV.py
CHANGED
@@ -92,6 +92,16 @@ class Dir_Structure():
|
|
92 |
self.transcription_ind_OCR_helper = os.path.join(self.dir_project,'Transcription','Individual_OCR_Helper')
|
93 |
validate_dir(self.transcription_ind_OCR_helper)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
self.save_original = os.path.join(self.dir_project,'Original_Images')
|
96 |
validate_dir(self.save_original)
|
97 |
|
|
|
92 |
self.transcription_ind_OCR_helper = os.path.join(self.dir_project,'Transcription','Individual_OCR_Helper')
|
93 |
validate_dir(self.transcription_ind_OCR_helper)
|
94 |
|
95 |
+
self.transcription_ind_wiki = os.path.join(self.dir_project,'Transcription','Individual_Wikipedia')
|
96 |
+
validate_dir(self.transcription_ind_wiki)
|
97 |
+
|
98 |
+
self.transcription_ind_prompt = os.path.join(self.dir_project,'Transcription','Individual_Prompt')
|
99 |
+
validate_dir(self.transcription_ind_prompt)
|
100 |
+
self.transcription_prompt = os.path.join(self.dir_project,'Transcription','Prompt_Template')
|
101 |
+
validate_dir(self.transcription_prompt)
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
self.save_original = os.path.join(self.dir_project,'Original_Images')
|
106 |
validate_dir(self.save_original)
|
107 |
|
vouchervision/prompt_catalog.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
from langchain_core.pydantic_v1 import Field, create_model
|
3 |
-
import yaml, json
|
4 |
|
5 |
@dataclass
|
6 |
class PromptCatalog:
|
@@ -69,6 +69,23 @@ class PromptCatalog:
|
|
69 |
# return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
|
70 |
return prompt, self.dictionary_structure
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def load_rules_config(self):
|
73 |
with open(self.rules_config_path, 'r') as stream:
|
74 |
try:
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from langchain_core.pydantic_v1 import Field, create_model
|
3 |
+
import yaml, json, os, shutil
|
4 |
|
5 |
@dataclass
|
6 |
class PromptCatalog:
|
|
|
69 |
# return prompt, self.PromptJSONModel, self.n_fields, xlsx_headers
|
70 |
return prompt, self.dictionary_structure
|
71 |
|
72 |
+
|
73 |
+
def copy_prompt_template_to_new_dir(self, new_directory_path, rules_config_path):
    """Copy the prompt-template file into ``new_directory_path``.

    The destination directory is created when it does not exist and the
    copy keeps the base name of ``rules_config_path``.  The operation is
    best-effort: a failure is reported on stdout and is not raised.

    Args:
        new_directory_path: Directory that receives the copy.
        rules_config_path: Path of the prompt-template file to copy.
    """
    template_name = os.path.basename(rules_config_path)

    # Define the path for the new file location
    new_file_path = os.path.join(new_directory_path, template_name)

    try:
        # exist_ok=True replaces the separate os.path.exists() test, which
        # could race with another process creating the same directory.
        os.makedirs(new_directory_path, exist_ok=True)
        shutil.copy(rules_config_path, new_file_path)
    except OSError as exc:
        print(f"Error copying [{template_name}] file: {exc}")
    else:
        print(f"Prompt [{template_name}] copied successfully to {new_file_path}")
|
87 |
+
|
88 |
+
|
89 |
def load_rules_config(self):
|
90 |
with open(self.rules_config_path, 'r') as stream:
|
91 |
try:
|
vouchervision/tool_wikipedia.py
ADDED
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools, yaml,wikipediaapi, requests, re, json
|
2 |
+
from langchain_community.tools import WikipediaQueryRun
|
3 |
+
from langchain_community.utilities import WikipediaAPIWrapper
|
4 |
+
# from langchain_community.tools.wikidata.tool import WikidataAPIWrapper, WikidataQueryRun
|
5 |
+
|
6 |
+
|
7 |
+
class WikipediaLinks():
|
8 |
+
|
9 |
+
|
10 |
+
def __init__(self, json_file_path_wiki) -> None:
    """Set up the Wikipedia client and the Wikidata property lookup table.

    Args:
        json_file_path_wiki: Path kept on the instance as
            ``self.json_file_path_wiki``.
    """
    self.json_file_path_wiki = json_file_path_wiki
    self.wiki_wiki = wikipediaapi.Wikipedia(
        user_agent='VoucherVision ([email protected])',
        language='en'
    )
    # Wikidata property ID -> label used as the key in the classification
    # dict built by get_taxonbar_data().
    #
    # 'P75' and 'P76' were each listed twice in this literal. Python keeps
    # the LAST value of a repeated key, so the effective mapping was
    # 'P75' -> 'Clade' and 'P76' -> 'Subclass' (the earlier 'Subgenus'
    # entry was silently discarded). The duplicates are removed here and
    # the effective values are kept.
    # NOTE(review): confirm whether 'P76' was meant to be 'Subgenus' or
    # 'Subclass', and verify these IDs against Wikidata.
    self.property_to_rank = {
        'P225': 'Species',
        'P171': 'Family',
        'P105': 'Taxon rank',
        'P70': 'Genus',
        'P75': 'Clade',
        'P76': 'Subclass',
        'P67': 'Subfamily',
        'P66': 'Tribe',
        'P71': 'Subtribe',
        'P61': 'Order',
        'P72': 'Suborder',
        'P73': 'Infraorder',
        'P74': 'Superfamily',
        'P142': 'Phylum',
        'P77': 'Infraclass',
        'P78': 'Superorder',
        'P81': 'Class',
        'P82': 'Superclass',
        'P84': 'Kingdom',
        'P85': 'Superkingdom',
        'P86': 'Subkingdom',
        'P87': 'Infrakingdom',
        'P88': 'Parvkingdom',
        'P89': 'Domain',
        'P1421': 'GRIN',
        'P1070': 'KEW',
        'P5037': 'POWOID',
    }
|
47 |
+
|
48 |
+
|
49 |
+
def get_label_for_entity_id(self, entity_id):
    """Return the English Wikidata label for ``entity_id``, or ``None``.

    Args:
        entity_id: Wikidata entity ID passed as the ``ids`` parameter of a
            ``wbgetentities`` request.

    Returns:
        The English label string, or ``None`` when the request fails, the
        response is not valid JSON, the entity is absent from the response,
        or it has no English label.
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": entity_id,
        "props": "labels",
        "languages": "en"  # Only the English label is requested
    }
    try:
        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a body that is not valid JSON.
        return None

    # The entity entry or its 'labels' may be missing, so avoid direct
    # indexing, which raised KeyError before.
    entity = data.get('entities', {}).get(entity_id, {})
    label = entity.get('labels', {}).get('en')
    return label['value'] if label else None
|
61 |
+
|
62 |
+
|
63 |
+
def is_valid_url(self, url):
    """Report whether ``url`` answers a HEAD request with HTTP status 200.

    Redirects are followed and the request is given 5 seconds.  Any
    ``requests.RequestException`` is treated as "not reachable".

    Returns:
        ``True`` when the final response status is 200, otherwise ``False``.
    """
    try:
        head_reply = requests.head(url, allow_redirects=True, timeout=5)
    except requests.RequestException:
        # The request itself failed, e.g. the domain does not exist.
        return False
    return head_reply.status_code == 200
|
72 |
+
|
73 |
+
# def get_infobar_data(self, wiki_page_title):
|
74 |
+
# # Step 1: Extract the Wikidata Item ID from the Wikipedia page
|
75 |
+
# wiki_api_url = "https://en.wikipedia.org/w/api.php"
|
76 |
+
# wiki_params = {
|
77 |
+
# "action": "query",
|
78 |
+
# "format": "json",
|
79 |
+
# "titles": wiki_page_title,
|
80 |
+
# "prop": "revisions",
|
81 |
+
# "rvprop": "content",
|
82 |
+
# "rvslots": "*"
|
83 |
+
# }
|
84 |
+
|
85 |
+
# wiki_response = requests.get(wiki_api_url, params=wiki_params)
|
86 |
+
# wiki_data = wiki_response.json()
|
87 |
+
|
88 |
+
# page_key = next(iter(wiki_data['query']['pages']))
|
89 |
+
# content = wiki_data['query']['pages'][page_key]['revisions'][0]['slots']['main']['*']
|
90 |
+
|
91 |
+
# infobox_pattern = re.compile(r'\{\{Infobox.*?\|title\}\}', re.DOTALL)
|
92 |
+
# match = infobox_pattern.search(content)
|
93 |
+
# if match:
|
94 |
+
# wikidata_id = match.group(1) # Returns the full match including the 'Infobox' braces
|
95 |
+
# else:
|
96 |
+
# return "Infobox not found"
|
97 |
+
|
98 |
+
# # Step 2: Fetch Data from Wikidata Using the Extracted ID
|
99 |
+
# wikidata_api_url = "https://www.wikidata.org/w/api.php"
|
100 |
+
# wikidata_params = {
|
101 |
+
# "action": "wbgetentities",
|
102 |
+
# "format": "json",
|
103 |
+
# "ids": wikidata_id,
|
104 |
+
# "props": "claims" # Adjust as needed to fetch the desired data
|
105 |
+
# }
|
106 |
+
|
107 |
+
# wikidata_response = requests.get(wikidata_api_url, params=wikidata_params)
|
108 |
+
# wikidata_content = wikidata_response.json()
|
109 |
+
|
110 |
+
|
111 |
+
# classification_full = {}
|
112 |
+
# classification = {}
|
113 |
+
# label_cache = {} # Cache for labels
|
114 |
+
|
115 |
+
|
116 |
+
# # Turn this on to see the available properties to decode
|
117 |
+
# for prop_id, claims in wikidata_content['entities'][wikidata_id]['claims'].items():
|
118 |
+
# # Assuming the main snak value is what we want
|
119 |
+
# value = claims[0]['mainsnak']['datavalue']['value']
|
120 |
+
# if isinstance(value, dict): # If the value is an entity ID
|
121 |
+
# # entity_id = value['id']
|
122 |
+
# # entity_id = value['id']
|
123 |
+
# if prop_id not in label_cache:
|
124 |
+
# label_cache[prop_id] = self.get_label_for_entity_id(prop_id)
|
125 |
+
# classification_full[prop_id] = label_cache[prop_id]
|
126 |
+
# else:
|
127 |
+
# classification_full[prop_id] = value
|
128 |
+
# print(classification_full)
|
129 |
+
# Map Wikidata properties to the corresponding taxonomic ranks
|
130 |
+
|
131 |
+
def convert_to_decimal(self, coord_parts):
    """Convert tokenised sexagesimal coordinates to a ``"lat,lon"`` string.

    The tokens are scanned for a latitude hemisphere letter (N/S) followed
    later by a longitude hemisphere letter (E/W), case-insensitively.  The
    one to three numeric tokens preceding each letter are read as degrees,
    minutes and seconds.  Tokens after the longitude letter are ignored.

    When no such pattern is found, the first six tokens are read
    positionally as ``deg, min, dir, deg, min, dir``.

    Args:
        coord_parts: Sequence of tokens, e.g.
            ``['42', '16', 'N', '83', '44', 'W']``.

    Returns:
        ``"<lat>,<lon>"`` in decimal degrees; southern latitudes and
        western longitudes are negative.

    Raises:
        ValueError: If a numeric token cannot be parsed, or the positional
            fallback does not receive six tokens.
    """
    def _to_degrees(components):
        # degrees + minutes / 60 + seconds / 3600
        return sum(float(c) / 60 ** i for i, c in enumerate(components))

    tokens = [str(part).strip() for part in coord_parts]
    marks = [tok.upper() for tok in tokens]

    lat_at = next((i for i, m in enumerate(marks) if m in ('N', 'S')), None)
    lon_at = None
    if lat_at is not None:
        lon_at = next(
            (i for i in range(lat_at + 1, len(marks)) if marks[i] in ('E', 'W')),
            None,
        )

    if (lat_at is not None and lon_at is not None
            and 1 <= lat_at <= 3 and 1 <= lon_at - lat_at - 1 <= 3):
        lat = _to_degrees(tokens[:lat_at])
        lon = _to_degrees(tokens[lat_at + 1:lon_at])
        lat_dir, lon_dir = marks[lat_at], marks[lon_at]
    else:
        # Positional layout: deg, min, dir, deg, min, dir.
        lat_deg, lat_min, lat_dir, lon_deg, lon_min, lon_dir = tokens[:6]
        lat = float(lat_deg) + float(lat_min) / 60
        lon = float(lon_deg) + float(lon_min) / 60
        lat_dir, lon_dir = lat_dir.upper(), lon_dir.upper()

    if lat_dir == 'S':
        lat = -lat
    if lon_dir == 'W':
        lon = -lon

    return f"{lat},{lon}"
|
143 |
+
|
144 |
+
|
145 |
+
def extract_coordinates_and_region(self, coord_string):
    """Pull decimal coordinates and the region code out of a coord string.

    Args:
        coord_string: Text of a ``coord`` template body, e.g.
            ``"|42|16|N|83|44|W|region:US-MI|display=inline,title"``.

    Returns:
        ``(decimal_coords, region)``.  ``decimal_coords`` is the string
        returned by ``convert_to_decimal`` or ``"Invalid coordinates
        format"``.  ``region`` is the text between ``region:`` and
        ``|display``, or ``"Region not found"``.
    """
    # Extract the coordinate parts and region info
    coord_parts = re.findall(r'(\d+|\w+)', coord_string)
    region_info = re.search(r'region:([^|]+)\|display', coord_string)

    decimal_coords = "Invalid coordinates format"
    if len(coord_parts) >= 6:
        try:
            # Convert to decimal coordinates
            decimal_coords = self.convert_to_decimal(coord_parts)
        except ValueError:
            # Six or more tokens do not guarantee a parseable layout.
            # Decimal-degree templates, for example, put non-numeric
            # tokens in numeric slots. Report that with the existing
            # sentinel instead of aborting the whole infobox parse.
            pass

    region = region_info.group(1) if region_info else "Region not found"
    return decimal_coords, region
|
158 |
+
|
159 |
+
|
160 |
+
def parse_infobox(self, infobox_string):
    """Parse infobox wikitext into a flat ``{key: value}`` dict.

    Each line is split on the first ``=``.  Lines are expected to look like
    ``| key = value`` and the key is the token after the leading pipe.
    Lines written without a space after the pipe (``|key = value``) are
    also accepted; they previously raised ``IndexError``.  Lines with no
    ``=`` or no usable key are skipped.

    Value handling:
      * ``[[Target|label]]`` is reduced to ``Target``;
      * ``{{coord...}}`` has its braces and the ``coord`` name removed;
      * ``[http... label]`` has its brackets removed and the first
        ``http`` token stored under ``'url_location'``;
      * for the key ``coordinates`` the extra entries ``'region'`` and
        ``'decimal_coordinates'`` are added.

    Keys and values are passed through ``self.sanitize``; values are also
    passed through ``self.remove_html_and_wiki_markup``.

    Args:
        infobox_string: Raw infobox wikitext.

    Returns:
        Dict of the extracted entries.
    """
    # Dictionary to store the extracted data
    infobox_data = {}

    for line in infobox_string.split('\n'):
        # Split the line into key and value
        parts = line.split('=', 1)
        if len(parts) != 2:
            continue

        # Whitespace-split so repeated spaces cannot yield an empty key.
        key_tokens = parts[0].split()
        if len(key_tokens) >= 2:
            # '| key' form: the key follows the pipe token.
            key = key_tokens[1]
        elif key_tokens:
            # '|key' form (no space after the pipe) or a bare 'key'.
            key = key_tokens[0].lstrip('|')
        else:
            key = ''
        if not key:
            continue

        value = parts[1].strip()

        # Handling special cases like links or coordinates
        if value.startswith('[[') and value.endswith(']]'):
            # Extracting linked article titles
            value = value[2:-2].split('|')[0]
        elif value.startswith('{{coord') and value.endswith('}}'):
            # Extracting coordinates
            value = value[7:-2]
        elif value.startswith('[') and value.endswith(']') and ('http' in value):
            value = value[1:-1]
            url_parts = value.split(" ")
            infobox_data['url_location'] = next((part for part in url_parts if 'http' in part), None)

        if key == 'coordinates':
            decimal_coordinates, region = self.extract_coordinates_and_region(value)
            infobox_data['region'] = region
            infobox_data['decimal_coordinates'] = decimal_coordinates

        key = self.sanitize(key)
        value = self.sanitize(value)
        value = self.remove_html_and_wiki_markup(value)
        # Add to dictionary
        infobox_data[key] = value

    return infobox_data
|
202 |
+
|
203 |
+
def get_infobox_data(self, wiki_page_title, opt=None):
    """Fetch a Wikipedia page's wikitext and parse its infobox.

    On success the parsed dict is stored on ``self.infobox_data`` (when
    ``opt`` is ``None``) or ``self.infobox_data_locality`` (otherwise) and
    ``None`` is returned.

    Args:
        wiki_page_title: Title sent as ``titles`` to the MediaWiki API.
        opt: Selects the attribute that receives the parsed infobox; only
            tested against ``None``.

    Returns:
        ``None`` on success, otherwise one of the strings
        ``"Error fetching data: ..."``, ``"Page not found"`` or
        ``"Infobox not found"``.  In the last case both
        ``self.infobox_data`` and ``self.infobox_data_locality`` are reset
        to ``{}``.
    """
    wiki_api_url = "https://en.wikipedia.org/w/api.php"
    wiki_params = {
        "action": "query",
        "format": "json",
        "titles": wiki_page_title,
        "prop": "revisions",
        "rvprop": "content",
        "rvslots": "*"
    }

    try:
        # Without a timeout, requests can block indefinitely.
        wiki_response = requests.get(wiki_api_url, params=wiki_params, timeout=10)
        wiki_response.raise_for_status()  # Check for HTTP errors
        wiki_data = wiki_response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        return f"Error fetching data: {e}"

    pages = wiki_data.get('query', {}).get('pages', {})
    page_key = next(iter(pages), None)
    if page_key is None or "missing" in pages[page_key]:
        return "Page not found"

    try:
        content = pages[page_key]['revisions'][0]['slots']['main']['*']
    except (KeyError, IndexError):
        # The page entry carries no revision content.
        return "Page not found"

    # NOTE(review): the non-greedy pattern stops at the FIRST '}}', so an
    # infobox containing a nested template is captured only up to the end
    # of that nested template. Widening the match also requires
    # parse_infobox to cope with the extra lines.
    infobox_pattern = re.compile(r'\{\{Infobox.*?\}\}', re.DOTALL)
    match = infobox_pattern.search(content)

    if match:
        infobox_content = match.group()
    else:
        self.infobox_data = {}
        self.infobox_data_locality = {}
        return "Infobox not found"

    if opt is None:
        self.infobox_data = self.parse_infobox(infobox_content)
    else:
        self.infobox_data_locality = self.parse_infobox(infobox_content)
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
# Example usage
|
246 |
+
|
247 |
+
# for prop_id, claims in wikidata_content['entities'][wikidata_id]['claims'].items():
|
248 |
+
# # Get the taxonomic rank from the mapping
|
249 |
+
# rank = self.property_to_rank.get(prop_id)
|
250 |
+
# if rank:
|
251 |
+
# value = claims[0]['mainsnak']['datavalue']['value']
|
252 |
+
# if isinstance(value, dict): # If the value is an entity ID
|
253 |
+
# entity_id = value['id']
|
254 |
+
# if entity_id not in label_cache:
|
255 |
+
# label_cache[entity_id] = self.get_label_for_entity_id(entity_id)
|
256 |
+
# classification[rank] = label_cache[entity_id]
|
257 |
+
# else:
|
258 |
+
# classification[rank] = value
|
259 |
+
|
260 |
+
# try:
|
261 |
+
# unknown_link = "https://powo.science.kew.org/taxon/" + classification['POWOID']
|
262 |
+
# if self.is_valid_url(unknown_link):
|
263 |
+
# classification['POWOID'] = unknown_link
|
264 |
+
# classification['POWOID_syn'] = unknown_link + '#synonyms'
|
265 |
+
# except:
|
266 |
+
# pass
|
267 |
+
# return classification
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
def get_taxonbar_data(self, wiki_page_title):
    """Build a classification dict from the page's Taxonbar Wikidata item.

    Step 1 reads the page wikitext and extracts the Wikidata ID from its
    ``{{Taxonbar|from=Q...}}`` template.  Step 2 fetches that item's claims
    and keeps those whose property ID appears in ``self.property_to_rank``.
    Values that reference another entity are resolved to labels with
    ``self.get_label_for_entity_id``.  When a ``'POWOID'`` value is
    present and the resulting POWO URL passes ``self.is_valid_url``, it is
    replaced by the URL and ``'POWOID_syn'`` is added.

    Args:
        wiki_page_title: Title sent as ``titles`` to the MediaWiki API.

    Returns:
        The classification dict, or the string ``"Taxonbar not found"``
        when the page has no readable content or no Taxonbar ID.

    Raises:
        requests.RequestException: If either HTTP request fails or times
            out.
    """
    # Step 1: Extract the Wikidata Item ID from the Wikipedia page
    wiki_api_url = "https://en.wikipedia.org/w/api.php"
    wiki_params = {
        "action": "query",
        "format": "json",
        "titles": wiki_page_title,
        "prop": "revisions",
        "rvprop": "content",
        "rvslots": "*"
    }

    # Without a timeout, requests can block indefinitely.
    wiki_response = requests.get(wiki_api_url, params=wiki_params, timeout=10)
    wiki_data = wiki_response.json()

    pages = wiki_data.get('query', {}).get('pages', {})
    page_key = next(iter(pages), None)
    try:
        content = pages[page_key]['revisions'][0]['slots']['main']['*']
    except (KeyError, IndexError):
        # Missing page or a page entry without revision content.
        return "Taxonbar not found"

    # Accepts the exact '{{Taxonbar|from=Q1}}' form matched before, and
    # also optional whitespace, numbered 'fromN=' parameters and trailing
    # parameters after the ID.
    taxonbar_match = re.search(r'\{\{Taxonbar\s*\|\s*from\d*\s*=\s*(Q\d+)', content)
    if not taxonbar_match:
        return "Taxonbar not found"

    wikidata_id = taxonbar_match.group(1)

    # Step 2: Fetch Data from Wikidata Using the Extracted ID
    wikidata_api_url = "https://www.wikidata.org/w/api.php"
    wikidata_params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": wikidata_id,
        "props": "claims"  # Adjust as needed to fetch the desired data
    }

    wikidata_response = requests.get(wikidata_api_url, params=wikidata_params, timeout=10)
    wikidata_content = wikidata_response.json()

    classification = {}
    label_cache = {}  # Cache for labels

    claims_by_property = (
        wikidata_content.get('entities', {}).get(wikidata_id, {}).get('claims', {})
    )

    # Map Wikidata properties to the corresponding taxonomic ranks
    for prop_id, claims in claims_by_property.items():
        # Get the taxonomic rank from the mapping
        rank = self.property_to_rank.get(prop_id)
        if not rank or not claims:
            continue

        # A claim's main snak may carry no 'datavalue'; skip it rather
        # than raise KeyError.
        datavalue = claims[0].get('mainsnak', {}).get('datavalue')
        if not datavalue or 'value' not in datavalue:
            continue
        value = datavalue['value']

        if isinstance(value, dict):  # Possibly an entity reference
            entity_id = value.get('id')
            if entity_id is None:
                # Structured value that is not an entity reference.
                continue
            if entity_id not in label_cache:
                label_cache[entity_id] = self.get_label_for_entity_id(entity_id)
            classification[rank] = label_cache[entity_id]
        else:
            classification[rank] = value

    # Explicit checks replace the former bare 'except:', which hid every
    # error including KeyboardInterrupt.
    powo_id = classification.get('POWOID')
    if isinstance(powo_id, str):
        unknown_link = "https://powo.science.kew.org/taxon/" + powo_id
        if self.is_valid_url(unknown_link):
            classification['POWOID'] = unknown_link
            classification['POWOID_syn'] = unknown_link + '#synonyms'
    return classification
|
350 |
+
|
351 |
+
|
352 |
+
def extract_page_title(self, result_string):
    """Return the page title from the first line of a Wikipedia tool result.

    The result text is expected to start with a line of the form
    ``Page: <title>``. Only that leading marker is removed, so a title that
    itself contains the text ``Page: `` is returned intact.

    Args:
        result_string: Raw multi-line text returned by the Wikipedia query tool.

    Returns:
        The first line with the leading ``Page: `` marker and surrounding
        whitespace removed. If the marker is absent the stripped first line
        is returned unchanged.
    """
    marker = 'Page: '
    # Only the first line carries the title; ignore everything after it.
    first_line = result_string.split('\n', 1)[0].strip()
    if first_line.startswith(marker):
        # Slice instead of str.replace so later occurrences of the marker
        # inside the title itself are preserved.
        first_line = first_line[len(marker):]
    return first_line.strip()
|
356 |
+
|
357 |
+
|
358 |
+
def get_wikipedia_url(self, page_title):
    """Return the full URL of the Wikipedia page called *page_title*.

    The page is looked up through ``self.wiki_wiki``.

    Returns:
        The page's ``fullurl`` when the page exists, otherwise ``None``.
    """
    wiki_page = self.wiki_wiki.page(page_title)
    return wiki_page.fullurl if wiki_page.exists() else None
|
364 |
+
|
365 |
+
|
366 |
+
def extract_info_taxa(self, page):
    """Fill ``self.info_packet['WIKI_TAXA']`` with taxonbar data and backlinks.

    ``DATA`` receives whatever ``self.get_taxonbar_data(page.title)`` returns.
    ``LINKS`` maps each backlink title (excluding titles that contain ``:``)
    to its sanitized Wikipedia URL. A URL is recorded only the first time it
    is seen, so later backlinks resolving to the same URL are skipped.

    Args:
        page: Page object exposing ``title`` and an iterable ``backlinks``.
    """
    taxa_packet = self.info_packet['WIKI_TAXA']
    taxa_packet['LINKS'] = {}
    taxa_packet['DATA'] = {}

    taxa_packet['DATA'].update(self.get_taxonbar_data(page.title))

    # A set keeps the duplicate check O(1); the backlink collection is not
    # capped here, so a list membership test would grow quadratically.
    seen_links = set()
    for back in page.backlinks:
        back = self.sanitize(back)
        if ':' in back:
            # Titles containing ':' are skipped
            # (presumably namespaced pages such as "Talk:..." - verify).
            continue
        link = self.sanitize(self.get_wikipedia_url(back))
        if link not in seen_links:
            seen_links.add(link)
            taxa_packet['LINKS'][back] = link
|
380 |
+
|
381 |
+
|
382 |
+
def extract_info_geo(self, page, opt=None):
    """Collect infobox data and up to ten backlinks for a geography page.

    Calls ``self.get_infobox_data`` for ``page.title`` (forwarding ``opt``
    only when it is not ``None``) and rebuilds
    ``self.info_packet['WIKI_GEO']['LINKS']`` from the first ten backlinks,
    skipping titles that contain ``:`` and URLs already recorded.

    Args:
        page: Page object exposing ``title`` and an iterable ``backlinks``.
        opt: Optional value forwarded to ``get_infobox_data`` as ``opt``.
    """
    recorded_urls = []
    self.info_packet['WIKI_GEO']['LINKS'] = {}

    # Forward ``opt`` only when the caller supplied one.
    infobox_kwargs = {} if opt is None else {'opt': opt}
    self.get_infobox_data(page.title, **infobox_kwargs)

    for title in itertools.islice(page.backlinks, 10):
        title = self.sanitize(title)
        if ':' in title:
            continue
        url = self.sanitize(self.get_wikipedia_url(title))
        if url in recorded_urls:
            continue
        recorded_urls.append(url)
        self.info_packet['WIKI_GEO']['LINKS'][title] = url
|
397 |
+
|
398 |
+
|
399 |
+
def gather_geo(self, query, opt=None):
    """Run a Wikipedia query and store title, link and summary for a place.

    Results go into ``self.info_packet['WIKI_GEO']`` when ``opt`` is ``None``
    and into ``self.info_packet['WIKI_LOCALITY']`` otherwise. ``DATA`` is
    reset and then updated from ``self.infobox_data`` (or
    ``self.infobox_data_locality`` for the locality case).

    Raises:
        IndexError: if the query result contains no ``Summary:`` section.
    """
    is_geo = opt is None
    section = 'WIKI_GEO' if is_geo else 'WIKI_LOCALITY'
    packet = self.info_packet[section]
    packet['DATA'] = {}

    query_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    result = query_tool.run(query)
    summary = self.sanitize(result.split('Summary:')[1])
    page_title = self.extract_page_title(result)

    # Only needed by the disabled extract_info_geo step below; the lookup is
    # kept so the sequence of calls on self.wiki_wiki is unchanged.
    page = self.wiki_wiki.page(page_title)

    # Do these first, they are less likely to fail. The two branches differ
    # only in key order, which is preserved because it shows up in the
    # saved JSON.
    if is_geo:
        packet['PAGE_LINK'] = self.get_wikipedia_url(page_title)
        packet['PAGE_TITLE'] = page_title
    else:
        packet['PAGE_TITLE'] = page_title
        packet['PAGE_LINK'] = self.get_wikipedia_url(page_title)
    packet['SUMMARY'] = summary

    # Check if the page exists, get the more complex data. Do it last in case of failure ########################## This might not be useful enough to justify the time
    # if page.exists():
    #     if opt is None:
    #         self.extract_info_geo(page)
    #     else:
    #         self.extract_info_geo(page, opt=opt)

    infobox = self.infobox_data if is_geo else self.infobox_data_locality
    packet['DATA'].update(infobox)
|
438 |
+
|
439 |
+
|
440 |
+
def gather_taxonomy(self, query):
    """Run a Wikipedia query for a taxon and record the result.

    When the resolved page exists, ``self.extract_info_taxa`` is called for
    it. ``PAGE_TITLE``, ``PAGE_LINK`` and ``SUMMARY`` are then written into
    ``self.info_packet['WIKI_TAXA']``.

    Returns:
        ``self.info_packet``.

    Raises:
        IndexError: if the query result contains no ``Summary:`` section.
    """
    query_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())

    # query = "Tracaulon sagittatum Tracaulon sagittatum"
    result = query_tool.run(query)
    summary = self.sanitize(result.split('Summary:')[1])
    page_title = self.extract_page_title(result)

    page = self.wiki_wiki.page(page_title)
    if page.exists():
        self.extract_info_taxa(page)

    taxa_packet = self.info_packet['WIKI_TAXA']
    taxa_packet['PAGE_TITLE'] = page_title
    taxa_packet['PAGE_LINK'] = self.get_wikipedia_url(page_title)
    taxa_packet['SUMMARY'] = summary
    return self.info_packet
|
460 |
+
|
461 |
+
|
462 |
+
def gather_wikipedia_results(self, output):
    """Look up geography, locality and taxonomy on Wikipedia and save the result.

    Builds search queries from the transcription fields in *output*, runs the
    lookups on a best-effort basis (a failed lookup leaves its section of
    ``self.info_packet`` partially filled or empty) and writes
    ``self.info_packet`` as JSON to ``self.json_file_path_wiki``.

    Args:
        output: Mapping of transcription field names to values. Missing,
            ``None`` and non-string values are tolerated.
    """
    def _text(key):
        # Field values are not guaranteed to be strings; None or a number
        # would otherwise break the join/strip calls below.
        value = output.get(key, '')
        if value is None:
            return ''
        return value if isinstance(value, str) else str(value)

    self.info_packet = {}
    self.info_packet['WIKI_TAXA'] = {}
    self.info_packet['WIKI_GEO'] = {}
    self.info_packet['WIKI_LOCALITY'] = {}

    municipality = _text('municipality')
    county = _text('county')
    stateProvince = _text('stateProvince')
    country = _text('country')

    locality = _text('locality')

    order = _text('order')
    family = _text('family')
    scientificName = _text('scientificName')
    genus = _text('genus')
    specificEpithet = _text('specificEpithet')

    query_geo = ' '.join([municipality, county, stateProvince, country]).strip()
    query_locality = locality.strip()
    query_taxa_primary = scientificName.strip()
    query_taxa_secondary = ' '.join([genus, specificEpithet]).strip()
    query_taxa_tertiary = ' '.join([order, family, genus, specificEpithet]).strip()

    # query_taxa = "Tracaulon sagittatum Tracaulon sagittatum"
    # query_geo = "Indiana Porter Co."
    # query_locality = "Mical Springs edge"

    # Lookups are best effort: a failure must not abort the transcription.
    # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still propagate.
    if query_geo:
        try:
            self.gather_geo(query_geo)
        except Exception:
            pass

    if query_locality:
        try:
            self.gather_geo(query_locality, 'locality')
        except Exception:
            pass

    # Try progressively different taxon queries; stop at the first that works.
    for q in (query_taxa_primary, query_taxa_secondary, query_taxa_tertiary):
        if not q:
            continue
        try:
            self.gather_taxonomy(q)
            break
        except Exception:
            pass

    try:
        with open(self.json_file_path_wiki, 'w', encoding='utf-8') as file:
            json.dump(self.info_packet, file, indent=4)
    except Exception:
        # Retry once with sanitized strings; reopening with 'w' truncates
        # whatever the failed attempt wrote.
        sanitized_data = self.sanitize(self.info_packet)
        with open(self.json_file_path_wiki, 'w', encoding='utf-8') as file:
            json.dump(sanitized_data, file, indent=4)
|
523 |
+
|
524 |
+
|
525 |
+
def sanitize(self, data):
    """Recursively drop characters that cannot be encoded as UTF-8.

    Strings are round-tripped through UTF-8 with errors ignored; lists and
    dicts (keys and values) are processed element by element; any other
    value is returned unchanged.
    """
    if isinstance(data, str):
        return data.encode('utf-8', 'ignore').decode('utf-8')
    if isinstance(data, list):
        return [self.sanitize(item) for item in data]
    if isinstance(data, dict):
        return {self.sanitize(k): self.sanitize(v) for k, v in data.items()}
    return data
|
534 |
+
|
535 |
+
def remove_html_and_wiki_markup(self, text):
    """Strip HTML tags and unwrap simple wiki links/templates from *text*.

    Substitutions are applied in order:
      1. HTML tags are removed.
      2. Wiki links keep their visible text, e.g. '[[Greg Abbott]]'
         becomes 'Greg Abbott'.
      3. Wiki templates keep their last field, e.g. '{{nowrap|text}}'
         becomes 'text'.
    """
    substitutions = (
        (r'<.*?>', ''),
        (r'\[\[(?:[^\]|]*\|)?([^\]|]*)\]\]', r'\1'),
        (r'\{\{(?:[^\}|]*\|)?([^\}|]*)\}\}', r'\1'),
    )
    clean_text = text
    for pattern, replacement in substitutions:
        clean_text = re.sub(pattern, replacement, clean_text)
    return clean_text
|
547 |
+
|
548 |
+
|
549 |
+
if __name__ == '__main__':
    # Manual smoke test: run the Wikipedia lookups against one sample
    # transcription record.
    field_names = [
        "filename", "catalogNumber", "order", "family", "scientificName",
        "scientificNameAuthorship", "genus", "subgenus", "specificEpithet",
        "infraspecificEpithet", "identifiedBy", "recordedBy", "recordNumber",
        "verbatimEventDate", "eventDate", "habitat", "occurrenceRemarks",
        "country", "stateProvince", "county", "municipality", "locality",
        "degreeOfEstablishment", "decimalLatitude", "decimalLongitude",
        "verbatimCoordinates", "minimumElevationInMeters",
        "maximumElevationInMeters",
    ]
    # Start with every field blank, then fill in the ones the sample has.
    test_output = dict.fromkeys(field_names, "")
    test_output.update({
        "filename": "MICH_7375774_Polygonaceae_Persicaria_",
        "catalogNumber": "1439649",
        "scientificName": "Tracaulon sagittatum",
        "genus": "Tracaulon",
        "specificEpithet": "sagittatum",
        "recordedBy": "Marcus W. Lyon, Jr.",
        "recordNumber": "TX 11",
        "verbatimEventDate": "1927",
        "eventDate": "1927-00-00",
        "habitat": "wet subdunal woods",
        "occurrenceRemarks": "Flowers pink",
        "country": "Indiana",
        "stateProvince": "Porter Co.",
        "locality": "Mical Springs edge",
    })
    Wiki = WikipediaLinks()
    info_packet = Wiki.gather_wikipedia_results(test_output)
|
vouchervision/utils_LLM.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
# Helper funcs for LLM_XXXXX.py
|
2 |
-
import tiktoken, json, os
|
3 |
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
4 |
from transformers import AutoTokenizer
|
5 |
import GPUtil
|
@@ -7,6 +7,12 @@ import time
|
|
7 |
import psutil
|
8 |
import threading
|
9 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
def remove_colons_and_double_apostrophes(text):
|
@@ -45,6 +51,7 @@ class SystemLoadMonitor():
|
|
45 |
self.logger = logger
|
46 |
self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'monitoring': True}
|
47 |
self.start_time = None
|
|
|
48 |
self.has_GPU = torch.cuda.is_available()
|
49 |
self.monitor_interval = 2
|
50 |
|
@@ -53,6 +60,12 @@ class SystemLoadMonitor():
|
|
53 |
self.monitoring_thread = threading.Thread(target=self.monitor_usage, args=(self.monitor_interval,))
|
54 |
self.monitoring_thread.start()
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def monitor_usage(self, interval):
|
57 |
while self.gpu_usage['monitoring']:
|
58 |
# GPU monitoring
|
@@ -73,18 +86,49 @@ class SystemLoadMonitor():
|
|
73 |
self.gpu_usage['max_cpu_usage'] = max(self.gpu_usage.get('max_cpu_usage', 0), cpu_usage)
|
74 |
time.sleep(interval)
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
def stop_monitoring_report_usage(self):
|
|
|
|
|
77 |
self.gpu_usage['monitoring'] = False
|
78 |
self.monitoring_thread.join()
|
79 |
-
|
80 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
self.logger.info(f"Max CPU Usage: {round(self.gpu_usage['max_cpu_usage'],2)}%")
|
83 |
self.logger.info(f"Max RAM Usage: {round(self.gpu_usage['max_ram_usage'],2)}GB")
|
84 |
|
85 |
if self.has_GPU:
|
|
|
|
|
|
|
86 |
self.logger.info(f"Max GPU Load: {round(self.gpu_usage['max_load']*100,2)}%")
|
87 |
self.logger.info(f"Max GPU Memory Usage: {round(self.gpu_usage['max_vram_usage'],2)}GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
|
90 |
|
|
|
1 |
# Helper funcs for LLM_XXXXX.py
|
2 |
+
import tiktoken, json, os, yaml
|
3 |
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
|
4 |
from transformers import AutoTokenizer
|
5 |
import GPUtil
|
|
|
7 |
import psutil
|
8 |
import threading
|
9 |
import torch
|
10 |
+
from datetime import datetime
|
11 |
+
|
12 |
+
def save_individual_prompt(prompt_template, txt_file_path_ind_prompt):
    """Write *prompt_template* to *txt_file_path_ind_prompt* as UTF-8 text.

    Any existing file at that path is overwritten.
    """
    with open(txt_file_path_ind_prompt, mode='w', encoding='utf-8') as prompt_file:
        prompt_file.write(prompt_template)
|
15 |
+
|
16 |
|
17 |
|
18 |
def remove_colons_and_double_apostrophes(text):
|
|
|
51 |
self.logger = logger
|
52 |
self.gpu_usage = {'max_cpu_usage': 0, 'max_load': 0, 'max_vram_usage': 0, "max_ram_usage": 0, 'monitoring': True}
|
53 |
self.start_time = None
|
54 |
+
self.tool_start_time = None
|
55 |
self.has_GPU = torch.cuda.is_available()
|
56 |
self.monitor_interval = 2
|
57 |
|
|
|
60 |
self.monitoring_thread = threading.Thread(target=self.monitor_usage, args=(self.monitor_interval,))
|
61 |
self.monitoring_thread.start()
|
62 |
|
63 |
+
def stop_inference_timer(self):
    """Record the elapsed inference time and start timing the tool phase.

    Stores ``time.time() - self.start_time`` in ``self.inference_time`` and
    then sets ``self.tool_start_time`` to the current time.
    NOTE(review): ``self.start_time`` must already hold a timestamp; it is
    set outside this method - confirm against the caller.
    """
    elapsed = time.time() - self.start_time
    self.inference_time = elapsed
    # The tool timer starts the moment inference stops.
    self.tool_start_time = time.time()
|
68 |
+
|
69 |
def monitor_usage(self, interval):
|
70 |
while self.gpu_usage['monitoring']:
|
71 |
# GPU monitoring
|
|
|
86 |
self.gpu_usage['max_cpu_usage'] = max(self.gpu_usage.get('max_cpu_usage', 0), cpu_usage)
|
87 |
time.sleep(interval)
|
88 |
|
89 |
+
def get_current_datetime(self):
    """Return the current local date and time as ``YYYY_MM_DDTHH_MM_SS``.

    Underscores are used in place of dashes and colons.
    """
    return datetime.now().strftime('%Y_%m_%dT%H_%M_%S')
|
95 |
+
|
96 |
def stop_monitoring_report_usage(self):
    """Stop the monitoring thread, log peak usage and return a usage report.

    Returns:
        dict of strings with the keys ``inference_time_s``, ``tool_time_s``,
        ``max_cpu``, ``max_ram_gb``, ``current_time``, ``max_gpu_load`` and
        ``max_gpu_vram_gb``. The GPU values are ``'0'`` when no GPU is
        available.
    """
    # Signal the monitoring loop to exit, then wait for the thread to finish.
    self.gpu_usage['monitoring'] = False
    self.monitoring_thread.join()

    now = time.time()

    # Tool time is only meaningful once stop_inference_timer() has run.
    tool_time = now - self.tool_start_time if self.tool_start_time else 0

    # inference_time is assigned only by stop_inference_timer(); if that was
    # never called (e.g. the LLM call failed early), fall back to the time
    # since monitoring started instead of raising AttributeError.
    inference_time = getattr(self, 'inference_time', None)
    if inference_time is None:
        inference_time = now - self.start_time if self.start_time else 0

    max_cpu = round(self.gpu_usage['max_cpu_usage'], 2)
    max_ram = round(self.gpu_usage['max_ram_usage'], 2)

    report = {'inference_time_s': str(round(inference_time, 2)),
              'tool_time_s': str(round(tool_time, 2)),
              'max_cpu': str(max_cpu),
              'max_ram_gb': str(max_ram),
              'current_time': self.get_current_datetime(),
              }
    self.logger.info(f"Inference Time: {round(inference_time,2)} seconds")
    self.logger.info(f"Tool Time: {round(tool_time,2)} seconds")

    self.logger.info(f"Max CPU Usage: {max_cpu}%")
    self.logger.info(f"Max RAM Usage: {max_ram}GB")

    if self.has_GPU:
        max_gpu_load = round(self.gpu_usage['max_load']*100, 2)
        max_gpu_vram = round(self.gpu_usage['max_vram_usage'], 2)
        report.update({'max_gpu_load': str(max_gpu_load)})
        report.update({'max_gpu_vram_gb': str(max_gpu_vram)})

        self.logger.info(f"Max GPU Load: {max_gpu_load}%")
        self.logger.info(f"Max GPU Memory Usage: {max_gpu_vram}GB")
    else:
        report.update({'max_gpu_load': str(0)})
        report.update({'max_gpu_vram_gb': str(0)})

    return report
|
130 |
+
|
131 |
+
|
132 |
|
133 |
|
134 |
|
vouchervision/utils_LLM_JSON_validation.py
CHANGED
@@ -14,18 +14,20 @@ def validate_and_align_JSON_keys_with_template(data, JSON_dict_structure):
|
|
14 |
if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null',
|
15 |
'not provided in the text', 'not found in the text',
|
16 |
'not in the text', 'not provided', 'not found',
|
17 |
-
'not provided in the
|
18 |
-
'not in the
|
19 |
-
'not provided in the
|
20 |
"not specified in the given text.",
|
21 |
"not specified in the given text",
|
22 |
"not specified in the text.",
|
23 |
"not specified in the text",
|
24 |
"not specified in text.",
|
25 |
"not specified in text",
|
26 |
-
"not specified in
|
27 |
"not specified",
|
28 |
-
'not in the
|
|
|
|
|
29 |
'n/a n/a','n/a, n/a',
|
30 |
'n/a, n/a, n/a','n/a n/a, n/a','n/a, n/a n/a','n/a n/a n/a',
|
31 |
'n/a, n/a, n/a, n/a','n/a n/a n/a n/a','n/a n/a, n/a, n/a','n/a, n/a n/a, n/a','n/a, n/a, n/a n/a',
|
|
|
14 |
if value.lower() in ['unknown', 'not provided', 'missing', 'na', 'none', 'n/a', 'null',
|
15 |
'not provided in the text', 'not found in the text',
|
16 |
'not in the text', 'not provided', 'not found',
|
17 |
+
'not provided in the ocr', 'not found in the ocr',
|
18 |
+
'not in the ocr',
|
19 |
+
'not provided in the ocr text', 'not found in the ocr text',
|
20 |
"not specified in the given text.",
|
21 |
"not specified in the given text",
|
22 |
"not specified in the text.",
|
23 |
"not specified in the text",
|
24 |
"not specified in text.",
|
25 |
"not specified in text",
|
26 |
+
"not specified in ocr",
|
27 |
"not specified",
|
28 |
+
'not in the ocr text',
|
29 |
+
'Not provided in ocr text',
|
30 |
+
'not provided in ocr text',
|
31 |
'n/a n/a','n/a, n/a',
|
32 |
'n/a, n/a, n/a','n/a n/a, n/a','n/a, n/a n/a','n/a n/a n/a',
|
33 |
'n/a, n/a, n/a, n/a','n/a n/a n/a n/a','n/a n/a, n/a, n/a','n/a, n/a n/a, n/a','n/a, n/a, n/a n/a',
|
vouchervision/utils_VoucherVision.py
CHANGED
@@ -46,6 +46,7 @@ class VoucherVision():
|
|
46 |
|
47 |
# self.trOCR_model_version = "microsoft/trocr-large-handwritten"
|
48 |
self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
|
|
49 |
self.trOCR_processor = None
|
50 |
self.trOCR_model = None
|
51 |
|
@@ -76,10 +77,12 @@ class VoucherVision():
|
|
76 |
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
77 |
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
78 |
|
|
|
|
|
79 |
self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
|
80 |
self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
81 |
|
82 |
-
self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + ["tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
83 |
# "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
84 |
|
85 |
# "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
@@ -294,7 +297,7 @@ class VoucherVision():
|
|
294 |
break
|
295 |
|
296 |
|
297 |
-
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
298 |
|
299 |
|
300 |
wb = openpyxl.load_workbook(path_transcription)
|
@@ -359,6 +362,8 @@ class VoucherVision():
|
|
359 |
sheet.cell(row=next_row, column=i, value=nt_out)
|
360 |
elif header.value == "filename":
|
361 |
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
|
|
|
|
362 |
|
363 |
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
364 |
elif header.value in self.wfo_headers_no_lists:
|
@@ -385,6 +390,12 @@ class VoucherVision():
|
|
385 |
elif header.value in self.geo_headers:
|
386 |
sheet.cell(row=next_row, column=i, value=GEO_record.get(header.value, ''))
|
387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
# save the workbook
|
389 |
wb.save(path_transcription)
|
390 |
|
@@ -396,7 +407,7 @@ class VoucherVision():
|
|
396 |
return False
|
397 |
|
398 |
|
399 |
-
def get_google_credentials(self):
|
400 |
if self.is_hf:
|
401 |
creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
|
402 |
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
@@ -651,6 +662,9 @@ class VoucherVision():
|
|
651 |
name_parts = model_name.split("_")
|
652 |
|
653 |
self.setup_JSON_dict_structure()
|
|
|
|
|
|
|
654 |
|
655 |
json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
|
656 |
json_report.set_JSON({}, {}, {})
|
@@ -666,7 +680,7 @@ class VoucherVision():
|
|
666 |
paths = self.generate_paths(path_to_crop, i)
|
667 |
self.path_to_crop = path_to_crop
|
668 |
|
669 |
-
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper = paths
|
670 |
json_report.set_text(text_main='Starting OCR')
|
671 |
self.perform_OCR_and_save_results(i, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
|
672 |
json_report.set_text(text_main='Finished OCR')
|
@@ -685,22 +699,22 @@ class VoucherVision():
|
|
685 |
self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
|
686 |
|
687 |
if 'PALM2' in name_parts:
|
688 |
-
response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_GooglePalm2(prompt, json_report)
|
689 |
|
690 |
elif 'GEMINI' in name_parts:
|
691 |
-
response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_GoogleGemini(prompt, json_report)
|
692 |
|
693 |
elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
|
694 |
-
response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_MistralAI(prompt, json_report)
|
695 |
|
696 |
elif 'LOCAL' in name_parts:
|
697 |
if 'MISTRAL' in name_parts or 'MIXTRAL' in name_parts:
|
698 |
if 'CPU' in name_parts:
|
699 |
-
response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_local_cpu_MistralAI(prompt, json_report)
|
700 |
else:
|
701 |
-
response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_local_MistralAI(prompt, json_report)
|
702 |
else:
|
703 |
-
response_candidate, nt_in, nt_out, WFO_record, GEO_record = llm_model.call_llm_api_OpenAI(prompt, json_report)
|
704 |
|
705 |
self.n_failed_LLM_calls += 1 if response_candidate is None else 0
|
706 |
|
@@ -710,7 +724,7 @@ class VoucherVision():
|
|
710 |
|
711 |
self.update_token_counters(nt_in, nt_out)
|
712 |
|
713 |
-
final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, paths, path_to_crop, nt_in, nt_out)
|
714 |
|
715 |
self.log_completion_info(final_JSON_response)
|
716 |
|
@@ -779,14 +793,14 @@ class VoucherVision():
|
|
779 |
self.total_tokens_out += nt_out
|
780 |
|
781 |
|
782 |
-
def update_final_response(self, response_candidate, WFO_record, GEO_record, paths, path_to_crop, nt_in, nt_out):
|
783 |
-
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper = paths
|
784 |
# Saving the JSON and XLSX files with the response and updating the final JSON response
|
785 |
if response_candidate is not None:
|
786 |
-
final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
787 |
return final_JSON_response_updated, WFO_record, GEO_record
|
788 |
else:
|
789 |
-
final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
790 |
return final_JSON_response_updated, WFO_record, GEO_record
|
791 |
|
792 |
|
@@ -814,13 +828,15 @@ class VoucherVision():
|
|
814 |
txt_file_path_OCR = os.path.join(self.Dirs.transcription_ind_OCR, filename_without_extension + '.json')
|
815 |
txt_file_path_OCR_bounds = os.path.join(self.Dirs.transcription_ind_OCR_bounds, filename_without_extension + '.json')
|
816 |
jpg_file_path_OCR_helper = os.path.join(self.Dirs.transcription_ind_OCR_helper, filename_without_extension + '.jpg')
|
|
|
|
|
817 |
|
818 |
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- {filename_without_extension}')
|
819 |
|
820 |
-
return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper
|
821 |
|
822 |
|
823 |
-
def save_json_and_xlsx(self, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
|
824 |
if response is None:
|
825 |
response = self.JSON_dict_structure
|
826 |
# Insert 'filename' as the first key
|
@@ -829,14 +845,14 @@ class VoucherVision():
|
|
829 |
|
830 |
# Then add the null info to the spreadsheet
|
831 |
response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
|
832 |
-
self.add_data_to_excel_from_response(self.path_transcription, response_null, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
|
833 |
|
834 |
### Set completed JSON
|
835 |
else:
|
836 |
response = self.clean_catalog_number(response, filename_without_extension)
|
837 |
self.write_json_to_file(txt_file_path, response)
|
838 |
# add to the xlsx file
|
839 |
-
self.add_data_to_excel_from_response(self.path_transcription, response, WFO_record, GEO_record, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
840 |
return response
|
841 |
|
842 |
|
|
|
46 |
|
47 |
# self.trOCR_model_version = "microsoft/trocr-large-handwritten"
|
48 |
self.trOCR_model_version = "microsoft/trocr-base-handwritten"
|
49 |
+
# self.trOCR_model_version = "dh-unibe/trocr-medieval-escriptmask"
|
50 |
self.trOCR_processor = None
|
51 |
self.trOCR_model = None
|
52 |
|
|
|
77 |
"GEO_decimal_long","GEO_city", "GEO_county", "GEO_state",
|
78 |
"GEO_state_code", "GEO_country", "GEO_country_code", "GEO_continent",]
|
79 |
|
80 |
+
self.usage_headers = ["current_time", "inference_time_s", "tool_time_s","max_cpu", "max_ram_gb", "max_gpu_load", "max_gpu_vram_gb",]
|
81 |
+
|
82 |
self.wfo_headers = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"]
|
83 |
self.wfo_headers_no_lists = ["WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_placement"]
|
84 |
|
85 |
+
self.utility_headers = ["filename"] + self.wfo_headers + self.geo_headers + self.usage_headers + ["prompt", "LLM", "tokens_in", "tokens_out", "path_to_crop","path_to_original","path_to_content","path_to_helper",]
|
86 |
# "WFO_override_OCR", "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement",
|
87 |
|
88 |
# "GEO_override_OCR", "GEO_method", "GEO_formatted_full_string", "GEO_decimal_lat",
|
|
|
297 |
break
|
298 |
|
299 |
|
300 |
+
def add_data_to_excel_from_response(self, path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, path_to_content, path_to_helper, nt_in, nt_out):
|
301 |
|
302 |
|
303 |
wb = openpyxl.load_workbook(path_transcription)
|
|
|
362 |
sheet.cell(row=next_row, column=i, value=nt_out)
|
363 |
elif header.value == "filename":
|
364 |
sheet.cell(row=next_row, column=i, value=filename_without_extension)
|
365 |
+
elif header.value == "prompt":
|
366 |
+
sheet.cell(row=next_row, column=i, value=os.path.basename(self.path_custom_prompts))
|
367 |
|
368 |
# "WFO_exact_match","WFO_exact_match_name","WFO_best_match","WFO_candidate_names","WFO_placement"
|
369 |
elif header.value in self.wfo_headers_no_lists:
|
|
|
390 |
elif header.value in self.geo_headers:
|
391 |
sheet.cell(row=next_row, column=i, value=GEO_record.get(header.value, ''))
|
392 |
|
393 |
+
elif header.value in self.usage_headers:
|
394 |
+
sheet.cell(row=next_row, column=i, value=usage_report.get(header.value, ''))
|
395 |
+
|
396 |
+
elif header.value == "LLM":
|
397 |
+
sheet.cell(row=next_row, column=i, value=MODEL_NAME_FORMATTED)
|
398 |
+
|
399 |
# save the workbook
|
400 |
wb.save(path_transcription)
|
401 |
|
|
|
407 |
return False
|
408 |
|
409 |
|
410 |
+
def get_google_credentials(self): # Also used for google drive
|
411 |
if self.is_hf:
|
412 |
creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
|
413 |
credentials = service_account.Credentials.from_service_account_info(json.loads(creds_json_str))
|
|
|
662 |
name_parts = model_name.split("_")
|
663 |
|
664 |
self.setup_JSON_dict_structure()
|
665 |
+
|
666 |
+
Copy_Prompt = PromptCatalog()
|
667 |
+
Copy_Prompt.copy_prompt_template_to_new_dir(self.Dirs.transcription_prompt, self.path_custom_prompts)
|
668 |
|
669 |
json_report.set_text(text_main=f'Loading {MODEL_NAME_FORMATTED}')
|
670 |
json_report.set_JSON({}, {}, {})
|
|
|
680 |
paths = self.generate_paths(path_to_crop, i)
|
681 |
self.path_to_crop = path_to_crop
|
682 |
|
683 |
+
filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt = paths
|
684 |
json_report.set_text(text_main='Starting OCR')
|
685 |
self.perform_OCR_and_save_results(i, jpg_file_path_OCR_helper, txt_file_path_OCR, txt_file_path_OCR_bounds)
|
686 |
json_report.set_text(text_main='Finished OCR')
|
|
|
699 |
self.logger.info(f'Waiting for {model_name} API call --- Using {MODEL_NAME_FORMATTED}')
|
700 |
|
701 |
if 'PALM2' in name_parts:
|
702 |
+
response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_GooglePalm2(prompt, json_report, paths)
|
703 |
|
704 |
elif 'GEMINI' in name_parts:
|
705 |
+
response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_GoogleGemini(prompt, json_report, paths)
|
706 |
|
707 |
elif 'MISTRAL' in name_parts and ('LOCAL' not in name_parts):
|
708 |
+
response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_MistralAI(prompt, json_report, paths)
|
709 |
|
710 |
elif 'LOCAL' in name_parts:
|
711 |
if 'MISTRAL' in name_parts or 'MIXTRAL' in name_parts:
|
712 |
if 'CPU' in name_parts:
|
713 |
+
response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_cpu_MistralAI(prompt, json_report, paths)
|
714 |
else:
|
715 |
+
response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_local_MistralAI(prompt, json_report, paths)
|
716 |
else:
|
717 |
+
response_candidate, nt_in, nt_out, WFO_record, GEO_record, usage_report = llm_model.call_llm_api_OpenAI(prompt, json_report, paths)
|
718 |
|
719 |
self.n_failed_LLM_calls += 1 if response_candidate is None else 0
|
720 |
|
|
|
724 |
|
725 |
self.update_token_counters(nt_in, nt_out)
|
726 |
|
727 |
+
final_JSON_response, final_WFO_record, final_GEO_record = self.update_final_response(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out)
|
728 |
|
729 |
self.log_completion_info(final_JSON_response)
|
730 |
|
|
|
793 |
self.total_tokens_out += nt_out
|
794 |
|
795 |
|
796 |
+
def update_final_response(self, response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, paths, path_to_crop, nt_in, nt_out):
    """Save the LLM response to the JSON and XLSX outputs.

    ``save_json_and_xlsx`` is called the same way whether or not
    *response_candidate* is ``None``, so no branching is done here.

    Args:
        paths: 7-tuple in the order filename_without_extension,
            txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds,
            jpg_file_path_OCR_helper, json_file_path_wiki,
            txt_file_path_ind_prompt.

    Returns:
        Tuple ``(final_JSON_response_updated, WFO_record, GEO_record)``.
    """
    # Only three of the seven paths are needed here; unpacking still
    # validates the tuple length.
    filename_without_extension, txt_file_path, _, _, jpg_file_path_OCR_helper, _, _ = paths
    # Saving the JSON and XLSX files with the response and updating the final JSON response
    final_JSON_response_updated = self.save_json_and_xlsx(response_candidate, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
    return final_JSON_response_updated, WFO_record, GEO_record
|
805 |
|
806 |
|
|
|
828 |
txt_file_path_OCR = os.path.join(self.Dirs.transcription_ind_OCR, filename_without_extension + '.json')
|
829 |
txt_file_path_OCR_bounds = os.path.join(self.Dirs.transcription_ind_OCR_bounds, filename_without_extension + '.json')
|
830 |
jpg_file_path_OCR_helper = os.path.join(self.Dirs.transcription_ind_OCR_helper, filename_without_extension + '.jpg')
|
831 |
+
json_file_path_wiki = os.path.join(self.Dirs.transcription_ind_wiki, filename_without_extension + '.json')
|
832 |
+
txt_file_path_ind_prompt = os.path.join(self.Dirs.transcription_ind_prompt, filename_without_extension + '.txt')
|
833 |
|
834 |
self.logger.info(f'Working on {i+1}/{len(self.img_paths)} --- {filename_without_extension}')
|
835 |
|
836 |
+
return filename_without_extension, txt_file_path, txt_file_path_OCR, txt_file_path_OCR_bounds, jpg_file_path_OCR_helper, json_file_path_wiki, txt_file_path_ind_prompt
|
837 |
|
838 |
|
839 |
+
def save_json_and_xlsx(self, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out):
|
840 |
if response is None:
|
841 |
response = self.JSON_dict_structure
|
842 |
# Insert 'filename' as the first key
|
|
|
845 |
|
846 |
# Then add the null info to the spreadsheet
|
847 |
response_null = self.create_null_row(filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper)
|
848 |
+
self.add_data_to_excel_from_response(self.path_transcription, response_null, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in=0, nt_out=0)
|
849 |
|
850 |
### Set completed JSON
|
851 |
else:
|
852 |
response = self.clean_catalog_number(response, filename_without_extension)
|
853 |
self.write_json_to_file(txt_file_path, response)
|
854 |
# add to the xlsx file
|
855 |
+
self.add_data_to_excel_from_response(self.path_transcription, response, WFO_record, GEO_record, usage_report, MODEL_NAME_FORMATTED, filename_without_extension, path_to_crop, txt_file_path, jpg_file_path_OCR_helper, nt_in, nt_out)
|
856 |
return response
|
857 |
|
858 |
|
vouchervision/utils_hf.py
CHANGED
@@ -99,42 +99,60 @@ def check_prompt_yaml_filename(fname):
|
|
99 |
return False
|
100 |
|
101 |
# Function to upload files to Google Drive
|
102 |
-
def upload_to_drive(filepath, filename):
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
# Get the folder ID from the environment variable
|
113 |
-
folder_id = os.environ.get('GDRIVE_FOLDER_ID') # Renamed for clarity
|
114 |
-
|
115 |
-
if folder_id:
|
116 |
-
file_metadata = {
|
117 |
-
'name': filename,
|
118 |
-
'parents': [folder_id]
|
119 |
-
}
|
120 |
-
|
121 |
-
# Determine the mimetype based on the file extension
|
122 |
-
if filename.endswith('.yaml') or filename.endswith('.yml'):
|
123 |
-
mimetype = 'application/x-yaml'
|
124 |
-
elif filename.endswith('.zip'):
|
125 |
-
mimetype = 'application/zip'
|
126 |
else:
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
else:
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
99 |
return False
|
100 |
|
101 |
# Function to upload files to Google Drive
|
102 |
+
def _drive_mimetype(filepath, filename):
    """Return the upload mimetype for a YAML or ZIP file, or None if unsupported."""
    candidates = (filename, filepath)
    if any(name.endswith(('.yaml', '.yml')) for name in candidates):
        return 'application/x-yaml'
    if any(name.endswith('.zip') for name in candidates):
        return 'application/zip'
    return None


def upload_to_drive(filepath, filename, is_hf=True, cfg_private=None, do_upload = True):
    """Upload a YAML or ZIP file to a Google Drive folder.

    Args:
        filepath: Path of the local file to upload.
        filename: Name stored in the Drive file metadata.
        is_hf: When true the folder id is read from the ``GDRIVE_FOLDER_ID``
            environment variable; otherwise from
            ``cfg_private['google']['GDRIVE_FOLDER_ID']``.
        cfg_private: Private configuration mapping; passed on to
            ``get_google_credentials`` and used for the folder id when
            ``is_hf`` is false.
        do_upload: When false the function returns immediately.

    Returns:
        None in every case; progress and problems are reported with ``print``.

    Raises:
        Exception: If the Drive upload request fails. The message includes a
            hint about sharing the folder with the service account, and the
            original error is chained as ``__cause__``.
    """
    if not do_upload:
        return None

    creds = get_google_credentials(is_hf=is_hf, cfg_private=cfg_private)
    if not creds:
        # NOTE(review): message text kept as-is for anything matching on it,
        # although the credentials are no longer read from GDRIVE_API.
        print("GDRIVE_API environment variable not set.")
        return None

    service = build('drive', 'v3', credentials=creds)

    # Get the folder ID from the environment variable or the private config.
    if is_hf:
        folder_id = os.environ.get('GDRIVE_FOLDER_ID')
    else:
        # .get() mirrors the environment lookup: a missing key means "no
        # folder configured" rather than a KeyError.
        # Assumes cfg_private['google'] is a dict-like mapping - TODO confirm.
        folder_id = cfg_private['google'].get('GDRIVE_FOLDER_ID')

    if not folder_id:
        # Previously this case was skipped without any feedback.
        print("GDRIVE_FOLDER_ID is not set; skipping Google Drive upload.")
        return None

    file_metadata = {
        'name': filename,
        'parents': [folder_id]
    }

    # Determine the mimetype based on the file extension. Both names are
    # checked because callers may pass an extension-less display name together
    # with a full path (e.g. a run name plus the path of its .zip archive).
    mimetype = _drive_mimetype(filepath, filename)
    if mimetype is None:
        print("Unsupported file type")
        return None

    # Upload the file
    try:
        media = MediaFileUpload(filepath, mimetype=mimetype)
        file = service.files().create(
            body=file_metadata,
            media_body=media,
            fields='id'
        ).execute()
        print(f"Uploaded file with ID: {file.get('id')}")
    except Exception as e:
        msg = f"If the following error is '404 cannot find file...' then you need to share the GDRIVE folder with your Google API service account's email address. Open your Google API JSON file, find the email account that ends with '@developer.gserviceaccount.com', go to your Google Drive, share the folder with this email account. {e}"
        print(msg)
        raise Exception(msg) from e
    return None
|
146 |
+
|
147 |
+
def get_google_credentials(is_hf=True, cfg_private=None): # Also used for google drive
    """Build Google service-account credentials.

    Args:
        is_hf: When true, the service-account JSON is read as a string from the
            ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. When
            false, it is loaded from the file at
            ``cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS']``.
        cfg_private: Private configuration mapping; only read when ``is_hf``
            is false.

    Returns:
        The object returned by
        ``service_account.Credentials.from_service_account_info``, or ``None``
        when ``is_hf`` is true and the environment variable is unset or empty.
        Returning ``None`` lets callers that test ``if creds:`` report the
        missing configuration instead of hitting a ``TypeError`` from
        ``json.loads(None)``.

    Side effects:
        When ``is_hf`` is false, ``GOOGLE_APPLICATION_CREDENTIALS`` is set in
        ``os.environ`` to the JSON text of the loaded file (the content, not
        the path), matching what the ``is_hf`` branch reads.
    """
    if is_hf:
        creds_json_str = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
        if not creds_json_str:
            return None
        return service_account.Credentials.from_service_account_info(json.loads(creds_json_str))

    with open(cfg_private['google']['GOOGLE_APPLICATION_CREDENTIALS'], 'r', encoding='utf-8') as file:
        data = json.load(file)

    # The parsed dict is passed directly; re-serializing and re-parsing it
    # first would yield the same dict.
    credentials = service_account.Credentials.from_service_account_info(data)
    # Only exported once the credentials were built successfully.
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = json.dumps(data)
    return credentials
|
vouchervision/vouchervision_main.py
CHANGED
@@ -79,20 +79,28 @@ def voucher_vision(cfg_file_path, dir_home, path_custom_prompts, cfg_test, progr
|
|
79 |
Voucher_Vision.close_logger_handlers()
|
80 |
|
81 |
zip_filepath = None
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
zip_filename = Dirs.run_name
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
90 |
|
91 |
return last_JSON_response, final_WFO_record, final_GEO_record, total_cost, Voucher_Vision.n_failed_OCR, Voucher_Vision.n_failed_LLM_calls, zip_filepath
|
92 |
|
93 |
-
def make_zipfile(
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def voucher_vision_OCR_test(cfg_file_path, dir_home, cfg_test, path_to_crop):
|
98 |
# get_n_overall = progress_report.get_n_overall()
|
|
|
79 |
Voucher_Vision.close_logger_handlers()
|
80 |
|
81 |
zip_filepath = None
|
82 |
+
# Create Hugging Face zip file
|
83 |
+
dir_to_zip = os.path.join(Dirs.dir_home, Dirs.run_name)
|
84 |
+
zip_filename = Dirs.run_name
|
|
|
85 |
|
86 |
+
# Creating a zip file
|
87 |
+
zip_filepath = make_zipfile(dir_to_zip, zip_filename) ####################################################################################################### TODO Make this configurable
|
88 |
+
if is_hf:
|
89 |
+
upload_to_drive(zip_filepath, zip_filename, is_hf, cfg_private=Voucher_Vision.cfg_private, do_upload=True) ###################################### TODO Make this configurable
|
90 |
+
else:
|
91 |
+
upload_to_drive(zip_filepath, zip_filename, is_hf, cfg_private=Voucher_Vision.cfg_private, do_upload=False) ##################################### TODO Make this configurable
|
92 |
|
93 |
return last_JSON_response, final_WFO_record, final_GEO_record, total_cost, Voucher_Vision.n_failed_OCR, Voucher_Vision.n_failed_LLM_calls, zip_filepath
|
94 |
|
95 |
+
def make_zipfile(base_dir, output_filename):
    """Zip the contents of ``base_dir`` into ``base_dir/<output_filename>.zip``.

    The archive is built in a temporary directory and moved into ``base_dir``
    afterwards. Writing it directly inside the directory being archived makes
    the directory walk pick up the half-written zip and store it inside itself.

    Args:
        base_dir: Directory whose contents are archived; also where the
            finished zip file is placed.
        output_filename: Archive name without the ``.zip`` extension.

    Returns:
        Full path of the created zip file
        (``os.path.join(base_dir, output_filename + '.zip')``).
    """
    import tempfile  # local import: only this function needs it

    zip_path = os.path.join(base_dir, output_filename + '.zip')

    # Remove an archive left by an earlier run so it is not packed into the
    # new one.
    if os.path.isfile(zip_path):
        os.remove(zip_path)

    with tempfile.TemporaryDirectory() as staging_dir:
        # make_archive appends '.zip' itself and returns the final file name.
        staged_zip = shutil.make_archive(os.path.join(staging_dir, 'archive'), 'zip', base_dir)
        shutil.move(staged_zip, zip_path)

    # Return the full path of the created zip file
    return zip_path
|
103 |
+
|
104 |
|
105 |
def voucher_vision_OCR_test(cfg_file_path, dir_home, cfg_test, path_to_crop):
|
106 |
# get_n_overall = progress_report.get_n_overall()
|