gianma commited on
Commit
b514106
·
1 Parent(s): d44f834

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -5,14 +5,14 @@ from libs.eurovoc.EurovocTh import EurovocTh
5
 
6
  api_key = environ.get("api_key")
7
 
8
- def app(input, filter_strategy, relevence_threshold, k, length_management_strategy, token_overlap, results_merge_stategy, drop_last):
9
 
10
  th = EurovocTh()
11
 
12
- domain_classifier_pipeline = pipeline(model="gianma/classifierEUtopLevelLongerTrain", tokenizer="gianma/classifierEUtopLevelLongerTrain", use_auth_token=api_key)
13
  length_limit = domain_classifier_pipeline.tokenizer.model_max_length
14
 
15
- eu_th_classifier_pipeline = pipeline(model="gianma/classifierEUtopLevelLongerTrain", tokenizer='gianma/classifierEUtopLevelLongerTrain', use_auth_token=api_key)
16
  length_limit = eu_th_classifier_pipeline.tokenizer.model_max_length
17
 
18
  kwargs = {'padding':True,'truncation':True,'max_length':length_limit}
@@ -207,7 +207,9 @@ with gr.Blocks() as interface:
207
  with gr.Column():
208
  outputs = gr.Label()
209
 
 
 
210
  clear_button.click(fn=lambda: None, inputs=None, outputs=input_text)
211
- classify_button.click(app, inputs=[input_text,filter_strategy, s_confidence, s_k, document_reading_strategy, chunk_overlap, combine_strategy, exclude_last], outputs=outputs)
212
 
213
  interface.launch()
 
5
 
6
  api_key = environ.get("api_key")
7
 
8
+ def app(input, filter_strategy, relevence_threshold, k, length_management_strategy, token_overlap, results_merge_stategy, drop_last, model_name):
9
 
10
  th = EurovocTh()
11
 
12
+ domain_classifier_pipeline = pipeline(model=model_name, tokenizer=model_name, use_auth_token=api_key)
13
  length_limit = domain_classifier_pipeline.tokenizer.model_max_length
14
 
15
+ eu_th_classifier_pipeline = pipeline(model=model_name, tokenizer=model_name, use_auth_token=api_key)
16
  length_limit = eu_th_classifier_pipeline.tokenizer.model_max_length
17
 
18
  kwargs = {'padding':True,'truncation':True,'max_length':length_limit}
 
207
  with gr.Column():
208
  outputs = gr.Label()
209
 
210
+ model_name = "gianma/classifierEUtopLevelLongerTrain"
211
+
212
  clear_button.click(fn=lambda: None, inputs=None, outputs=input_text)
213
+ classify_button.click(app, inputs=[input_text,filter_strategy, s_confidence, s_k, document_reading_strategy, chunk_overlap, combine_strategy, exclude_last, model_name], outputs=outputs)
214
 
215
  interface.launch()