Task-359 Correct code to read new model repository structure
Browse files
- app.py +22 -16
- scripts/predict.py +7 -28
app.py
CHANGED
@@ -85,10 +85,7 @@ def load_inference_handler(api_token: str) -> InferenceHandler | None:
|
|
85 |
Returns an instance of the InferenceHandler class if a valid token is entered, otherwise returns None.
|
86 |
"""
|
87 |
|
88 |
-
|
89 |
-
return InferenceHandler(api_token)
|
90 |
-
except:
|
91 |
-
return None
|
92 |
|
93 |
def build_result_tree(parent_elem, results: dict):
|
94 |
"""Loads the history of results from inference for previous inputs made by the user.
|
@@ -195,11 +192,10 @@ def analyze_text(input: str):
|
|
195 |
input : str
|
196 |
The text to analyze.
|
197 |
"""
|
198 |
-
if ih:
|
199 |
res = None
|
200 |
with rc:
|
201 |
with st.spinner("Processing...", show_time=True) as spnr:
|
202 |
-
# time.sleep(5)
|
203 |
res = ih.classify_text(input)
|
204 |
del spnr
|
205 |
|
@@ -209,8 +205,8 @@ def analyze_text(input: str):
|
|
209 |
|
210 |
@st.cache_data
|
211 |
def load_datasets(_parent_elem, api_token: str):
|
212 |
-
if api_token is None or len(api_token) == 0:
|
213 |
-
|
214 |
|
215 |
cache_path = snapshot_download(repo_id=DATASET_REPO, repo_type='dataset', token=api_token)
|
216 |
ds_record = pd.read_csv(os.path.join(cache_path, 'dataset_record.csv'))
|
@@ -263,13 +259,23 @@ def load_datasets(_parent_elem, api_token: str):
|
|
263 |
|
264 |
st.title('NLPinitiative Text Classifier')
|
265 |
|
266 |
-
st.sidebar.write("")
|
267 |
-
API_KEY = st.sidebar.text_input(
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
)
|
272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
|
274 |
tab1 = st.empty()
|
275 |
tab2 = st.empty()
|
@@ -354,7 +360,7 @@ with tab3:
|
|
354 |
with tab4:
|
355 |
ds_container = st.container(border=True)
|
356 |
try:
|
357 |
-
load_datasets(ds_container,
|
358 |
except Exception as e:
|
359 |
logger.error(f'{e}')
|
360 |
ds_container.markdown(
|
|
|
85 |
Returns an instance of the InferenceHandler class if a valid token is entered, otherwise returns None.
|
86 |
"""
|
87 |
|
88 |
+
return InferenceHandler(api_token)
|
|
|
|
|
|
|
89 |
|
90 |
def build_result_tree(parent_elem, results: dict):
|
91 |
"""Loads the history of results from inference for previous inputs made by the user.
|
|
|
192 |
input : str
|
193 |
The text to analyze.
|
194 |
"""
|
195 |
+
if ih is not None:
|
196 |
res = None
|
197 |
with rc:
|
198 |
with st.spinner("Processing...", show_time=True) as spnr:
|
|
|
199 |
res = ih.classify_text(input)
|
200 |
del spnr
|
201 |
|
|
|
205 |
|
206 |
@st.cache_data
|
207 |
def load_datasets(_parent_elem, api_token: str):
|
208 |
+
# if api_token is None or len(api_token) == 0:
|
209 |
+
# raise Exception()
|
210 |
|
211 |
cache_path = snapshot_download(repo_id=DATASET_REPO, repo_type='dataset', token=api_token)
|
212 |
ds_record = pd.read_csv(os.path.join(cache_path, 'dataset_record.csv'))
|
|
|
259 |
|
260 |
st.title('NLPinitiative Text Classifier')
|
261 |
|
262 |
+
# st.sidebar.write("")
|
263 |
+
# API_KEY = st.sidebar.text_input(
|
264 |
+
# "Enter your HuggingFace API Token",
|
265 |
+
# help="You can get your free API token in your settings page: https://huggingface.co/settings/tokens",
|
266 |
+
# type="password",
|
267 |
+
# )
|
268 |
+
|
269 |
+
# if API_KEY is not None and len(API_KEY) > 0:
|
270 |
+
# try:
|
271 |
+
# ih = load_inference_handler(API_KEY)
|
272 |
+
# except Exception as e:
|
273 |
+
# ih = None
|
274 |
+
# st.sidebar.write(f'Failed to load inference handler: {e}')
|
275 |
+
# else:
|
276 |
+
# ih = None
|
277 |
+
|
278 |
+
ih = InferenceHandler(None)
|
279 |
|
280 |
tab1 = st.empty()
|
281 |
tab2 = st.empty()
|
|
|
360 |
with tab4:
|
361 |
ds_container = st.container(border=True)
|
362 |
try:
|
363 |
+
load_datasets(ds_container, None)
|
364 |
except Exception as e:
|
365 |
logger.error(f'{e}')
|
366 |
ds_container.markdown(
|
scripts/predict.py
CHANGED
@@ -3,18 +3,18 @@ Script file used for performing inference with an existing model.
|
|
3 |
"""
|
4 |
|
5 |
import torch
|
6 |
-
import json
|
7 |
import nltk
|
8 |
from nltk.tokenize import sent_tokenize
|
9 |
-
import huggingface_hub
|
10 |
|
11 |
from transformers import (
|
12 |
AutoTokenizer,
|
13 |
AutoModelForSequenceClassification
|
14 |
)
|
15 |
|
16 |
-
|
17 |
-
|
|
|
|
|
18 |
|
19 |
class InferenceHandler:
|
20 |
"""A class that handles performing inference using the trained binary classification and multilabel regression models."""
|
@@ -33,28 +33,13 @@ class InferenceHandler:
|
|
33 |
self.ml_regr_tokenizer, self.ml_regr_model = self._init_model_and_tokenizer(ML_REPO)
|
34 |
nltk.download('punkt_tab')
|
35 |
|
36 |
-
def _get_config(self, repo_id: str) -> str:
|
37 |
-
"""Retrieves the config.json file from the specified model repository.
|
38 |
-
|
39 |
-
Parameters
|
40 |
-
----------
|
41 |
-
repo_id : str
|
42 |
-
The repository id (i.e., <owner username>/<repository name>).
|
43 |
-
|
44 |
-
"""
|
45 |
-
|
46 |
-
config = None
|
47 |
-
if repo_id and self.api_token:
|
48 |
-
config = huggingface_hub.hf_hub_download(repo_id, filename='config.json', token=self.api_token)
|
49 |
-
return config
|
50 |
-
|
51 |
def _init_model_and_tokenizer(self, repo_id: str):
|
52 |
"""Initializes a model and tokenizer for use in inference using the models path.
|
53 |
|
54 |
Parameters
|
55 |
----------
|
56 |
-
|
57 |
-
|
58 |
|
59 |
Returns
|
60 |
-------
|
@@ -62,14 +47,8 @@ class InferenceHandler:
|
|
62 |
A tuple containing the tokenizer and model objects.
|
63 |
"""
|
64 |
|
65 |
-
|
66 |
-
with open(config) as config_file:
|
67 |
-
config_json = json.load(config_file)
|
68 |
-
model_name = config_json['_name_or_path']
|
69 |
-
|
70 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
71 |
model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=self.api_token)
|
72 |
-
|
73 |
model.eval()
|
74 |
return tokenizer, model
|
75 |
|
|
|
3 |
"""
|
4 |
|
5 |
import torch
|
|
|
6 |
import nltk
|
7 |
from nltk.tokenize import sent_tokenize
|
|
|
8 |
|
9 |
from transformers import (
|
10 |
AutoTokenizer,
|
11 |
AutoModelForSequenceClassification
|
12 |
)
|
13 |
|
14 |
+
from scripts.config import (
|
15 |
+
BIN_REPO,
|
16 |
+
ML_REPO
|
17 |
+
)
|
18 |
|
19 |
class InferenceHandler:
|
20 |
"""A class that handles performing inference using the trained binary classification and multilabel regression models."""
|
|
|
33 |
self.ml_regr_tokenizer, self.ml_regr_model = self._init_model_and_tokenizer(ML_REPO)
|
34 |
nltk.download('punkt_tab')
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def _init_model_and_tokenizer(self, repo_id: str):
|
37 |
"""Initializes a model and tokenizer for use in inference using the models path.
|
38 |
|
39 |
Parameters
|
40 |
----------
|
41 |
+
repo_id : str
|
42 |
+
The repository id (i.e., <owner username>/<repository name>).
|
43 |
|
44 |
Returns
|
45 |
-------
|
|
|
47 |
A tuple containing the tokenizer and model objects.
|
48 |
"""
|
49 |
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(repo_id, token=self.api_token)
|
|
|
|
|
|
|
|
|
|
|
51 |
model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=self.api_token)
|
|
|
52 |
model.eval()
|
53 |
return tokenizer, model
|
54 |
|