dlsmallw committed on
Commit
23428ec
·
1 Parent(s): 0637402

Task-359 Correct code to read new model repository structure

Browse files
Files changed (2) hide show
  1. app.py +22 -16
  2. scripts/predict.py +7 -28
app.py CHANGED
@@ -85,10 +85,7 @@ def load_inference_handler(api_token: str) -> InferenceHandler | None:
85
  Returns an instance of the InferenceHandler class if a valid token is entered, otherwise returns None.
86
  """
87
 
88
- try:
89
- return InferenceHandler(api_token)
90
- except:
91
- return None
92
 
93
  def build_result_tree(parent_elem, results: dict):
94
  """Loads the history of results from inference for previous inputs made by the user.
@@ -195,11 +192,10 @@ def analyze_text(input: str):
195
  input : str
196
  The text to analyze.
197
  """
198
- if ih:
199
  res = None
200
  with rc:
201
  with st.spinner("Processing...", show_time=True) as spnr:
202
- # time.sleep(5)
203
  res = ih.classify_text(input)
204
  del spnr
205
 
@@ -209,8 +205,8 @@ def analyze_text(input: str):
209
 
210
  @st.cache_data
211
  def load_datasets(_parent_elem, api_token: str):
212
- if api_token is None or len(api_token) == 0:
213
- raise Exception()
214
 
215
  cache_path = snapshot_download(repo_id=DATASET_REPO, repo_type='dataset', token=api_token)
216
  ds_record = pd.read_csv(os.path.join(cache_path, 'dataset_record.csv'))
@@ -263,13 +259,23 @@ def load_datasets(_parent_elem, api_token: str):
263
 
264
  st.title('NLPinitiative Text Classifier')
265
 
266
- st.sidebar.write("")
267
- API_KEY = st.sidebar.text_input(
268
- "Enter your HuggingFace API Token",
269
- help="You can get your free API token in your settings page: https://huggingface.co/settings/tokens",
270
- type="password",
271
- )
272
- ih = load_inference_handler(API_KEY)
 
 
 
 
 
 
 
 
 
 
273
 
274
  tab1 = st.empty()
275
  tab2 = st.empty()
@@ -354,7 +360,7 @@ with tab3:
354
  with tab4:
355
  ds_container = st.container(border=True)
356
  try:
357
- load_datasets(ds_container, API_KEY)
358
  except Exception as e:
359
  logger.error(f'{e}')
360
  ds_container.markdown(
 
85
  Returns an instance of the InferenceHandler class if a valid token is entered, otherwise returns None.
86
  """
87
 
88
+ return InferenceHandler(api_token)
 
 
 
89
 
90
  def build_result_tree(parent_elem, results: dict):
91
  """Loads the history of results from inference for previous inputs made by the user.
 
192
  input : str
193
  The text to analyze.
194
  """
195
+ if ih is not None:
196
  res = None
197
  with rc:
198
  with st.spinner("Processing...", show_time=True) as spnr:
 
199
  res = ih.classify_text(input)
200
  del spnr
201
 
 
205
 
206
  @st.cache_data
207
  def load_datasets(_parent_elem, api_token: str):
208
+ # if api_token is None or len(api_token) == 0:
209
+ # raise Exception()
210
 
211
  cache_path = snapshot_download(repo_id=DATASET_REPO, repo_type='dataset', token=api_token)
212
  ds_record = pd.read_csv(os.path.join(cache_path, 'dataset_record.csv'))
 
259
 
260
  st.title('NLPinitiative Text Classifier')
261
 
262
+ # st.sidebar.write("")
263
+ # API_KEY = st.sidebar.text_input(
264
+ # "Enter your HuggingFace API Token",
265
+ # help="You can get your free API token in your settings page: https://huggingface.co/settings/tokens",
266
+ # type="password",
267
+ # )
268
+
269
+ # if API_KEY is not None and len(API_KEY) > 0:
270
+ # try:
271
+ # ih = load_inference_handler(API_KEY)
272
+ # except Exception as e:
273
+ # ih = None
274
+ # st.sidebar.write(f'Failed to load inference handler: {e}')
275
+ # else:
276
+ # ih = None
277
+
278
+ ih = InferenceHandler(None)
279
 
280
  tab1 = st.empty()
281
  tab2 = st.empty()
 
360
  with tab4:
361
  ds_container = st.container(border=True)
362
  try:
363
+ load_datasets(ds_container, None)
364
  except Exception as e:
365
  logger.error(f'{e}')
366
  ds_container.markdown(
scripts/predict.py CHANGED
@@ -3,18 +3,18 @@ Script file used for performing inference with an existing model.
3
  """
4
 
5
  import torch
6
- import json
7
  import nltk
8
  from nltk.tokenize import sent_tokenize
9
- import huggingface_hub
10
 
11
  from transformers import (
12
  AutoTokenizer,
13
  AutoModelForSequenceClassification
14
  )
15
 
16
- BIN_REPO = 'dlsmallw/NLPinitiative-Binary-Classification'
17
- ML_REPO = 'dlsmallw/NLPinitiative-Multilabel-Regression'
 
 
18
 
19
  class InferenceHandler:
20
  """A class that handles performing inference using the trained binary classification and multilabel regression models."""
@@ -33,28 +33,13 @@ class InferenceHandler:
33
  self.ml_regr_tokenizer, self.ml_regr_model = self._init_model_and_tokenizer(ML_REPO)
34
  nltk.download('punkt_tab')
35
 
36
- def _get_config(self, repo_id: str) -> str:
37
- """Retrieves the config.json file from the specified model repository.
38
-
39
- Parameters
40
- ----------
41
- repo_id : str
42
- The repository id (i.e., <owner username>/<repository name>).
43
-
44
- """
45
-
46
- config = None
47
- if repo_id and self.api_token:
48
- config = huggingface_hub.hf_hub_download(repo_id, filename='config.json', token=self.api_token)
49
- return config
50
-
51
  def _init_model_and_tokenizer(self, repo_id: str):
52
  """Initializes a model and tokenizer for use in inference using the models path.
53
 
54
  Parameters
55
  ----------
56
- model_path : Path
57
- Directory path to the models tensor file.
58
 
59
  Returns
60
  -------
@@ -62,14 +47,8 @@ class InferenceHandler:
62
  A tuple containing the tokenizer and model objects.
63
  """
64
 
65
- config = self._get_config(repo_id)
66
- with open(config) as config_file:
67
- config_json = json.load(config_file)
68
- model_name = config_json['_name_or_path']
69
-
70
- tokenizer = AutoTokenizer.from_pretrained(model_name)
71
  model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=self.api_token)
72
-
73
  model.eval()
74
  return tokenizer, model
75
 
 
3
  """
4
 
5
  import torch
 
6
  import nltk
7
  from nltk.tokenize import sent_tokenize
 
8
 
9
  from transformers import (
10
  AutoTokenizer,
11
  AutoModelForSequenceClassification
12
  )
13
 
14
+ from scripts.config import (
15
+ BIN_REPO,
16
+ ML_REPO
17
+ )
18
 
19
  class InferenceHandler:
20
  """A class that handles performing inference using the trained binary classification and multilabel regression models."""
 
33
  self.ml_regr_tokenizer, self.ml_regr_model = self._init_model_and_tokenizer(ML_REPO)
34
  nltk.download('punkt_tab')
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def _init_model_and_tokenizer(self, repo_id: str):
37
  """Initializes a model and tokenizer for use in inference using the models path.
38
 
39
  Parameters
40
  ----------
41
+ repo_id : str
42
+ The repository id (i.e., <owner username>/<repository name>).
43
 
44
  Returns
45
  -------
 
47
  A tuple containing the tokenizer and model objects.
48
  """
49
 
50
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, token=self.api_token)
 
 
 
 
 
51
  model = AutoModelForSequenceClassification.from_pretrained(repo_id, token=self.api_token)
 
52
  model.eval()
53
  return tokenizer, model
54