akdeniz27 committed on
Commit
4874aa0
1 Parent(s): 557a1ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -30,19 +30,22 @@ model_checkpoint = st.sidebar.radio("", model_list)
30
  st.sidebar.write("For details of models: 'https://huggingface.co/akdeniz27/")
31
  st.sidebar.write("")
32
 
33
- xlm_agg_strategy_info = "'aggregation_strategy' can be selected as 'simple' or 'none' for 'xlm-roberta' because of the RoBERTa model's tokenization approach."
34
 
35
- st.sidebar.header("Select Aggregation Strategy Type")
36
  if model_checkpoint == "akdeniz27/xlm-roberta-base-turkish-ner":
37
- aggregation = st.sidebar.radio("", ('simple', 'none'))
38
- st.sidebar.write(xlm_agg_strategy_info)
39
- elif model_checkpoint == "xlm-roberta-large-finetuned-conll03-english": # or model_checkpoint == "tner/tner-xlm-roberta-base-ontonotes5":
40
- aggregation = st.sidebar.radio("", ('simple', 'none'))
41
- st.sidebar.write(xlm_agg_strategy_info)
 
 
42
  st.sidebar.write("")
43
  st.sidebar.write("This English NER model is included just to show the zero-shot transfer learning capability of XLM-Roberta.")
44
  else:
45
- aggregation = st.sidebar.radio("", ('first', 'simple', 'average', 'max', 'none'))
 
46
 
47
  st.sidebar.write("Please refer 'https://huggingface.co/transformers/_modules/transformers/pipelines/token_classification.html' for entity grouping with aggregation_strategy parameter.")
48
 
@@ -73,8 +76,18 @@ if Run_Button == True:
73
 
74
  ner_pipeline = setModel(model_checkpoint, aggregation)
75
  output = ner_pipeline(input_text)
 
 
 
 
 
 
 
 
 
 
76
 
77
- df = pd.DataFrame.from_dict(output)
78
  if aggregation != "none":
79
  cols_to_keep = ['word','entity_group','score','start','end']
80
  else:
@@ -90,7 +103,7 @@ if Run_Button == True:
90
  spacy_display["text"] = input_text
91
  spacy_display["title"] = None
92
 
93
- for entity in output:
94
  if aggregation != "none":
95
  spacy_display["ents"].append({"start": entity["start"], "end": entity["end"], "label": entity["entity_group"]})
96
  else:
 
30
  st.sidebar.write("For details of models: 'https://huggingface.co/akdeniz27/")
31
  st.sidebar.write("")
32
 
33
+ # xlm_agg_strategy_info = "'aggregation_strategy' can be selected as 'simple' or 'none' for 'xlm-roberta' because of the RoBERTa model's tokenization approach."
34
 
35
+ # st.sidebar.header("Select Aggregation Strategy Type")
36
  if model_checkpoint == "akdeniz27/xlm-roberta-base-turkish-ner":
37
+ aggregation = "simple"
38
+ # aggregation = st.sidebar.radio("", ('simple', 'none'))
39
+ # st.sidebar.write(xlm_agg_strategy_info)
40
+ elif model_checkpoint == "xlm-roberta-large-finetuned-conll03-english" or model_checkpoint == "tner/tner-xlm-roberta-base-ontonotes5":
41
+ aggregation = "simple"
42
+ # aggregation = st.sidebar.radio("", ('simple', 'none'))
43
+ # st.sidebar.write(xlm_agg_strategy_info)
44
  st.sidebar.write("")
45
  st.sidebar.write("This English NER model is included just to show the zero-shot transfer learning capability of XLM-Roberta.")
46
  else:
47
+ aggregation = "first"
48
+ # aggregation = st.sidebar.radio("", ('first', 'simple', 'average', 'max', 'none'))
49
 
50
  st.sidebar.write("Please refer 'https://huggingface.co/transformers/_modules/transformers/pipelines/token_classification.html' for entity grouping with aggregation_strategy parameter.")
51
 
 
76
 
77
  ner_pipeline = setModel(model_checkpoint, aggregation)
78
  output = ner_pipeline(input_text)
79
+
80
# Merge entities that the pipeline returned as directly adjacent spans
# (next "start" == previous "end", i.e. no character in between) into a
# single entity, so one name split into pieces is displayed once.
#
# The label lives under "entity_group" when an aggregation strategy is
# active and under "entity" when aggregation is "none"; this mirrors the
# key selection used further down when the results are displayed.
label_key = "entity_group" if aggregation != "none" else "entity"

output_comb = []
for entity in output:
    previous = output_comb[-1] if output_comb else None
    if (
        previous is not None
        and entity["start"] == previous["end"]
        # Only join pieces that carry the same label; otherwise the second
        # entity's label would be silently dropped.
        and entity[label_key] == previous[label_key]
    ):
        # Extend the last merged entity (not index ind-1 of the raw output:
        # output_comb is shorter than output once anything was merged).
        # The score of the first piece is kept for the merged entity.
        previous["word"] = previous["word"] + entity["word"]
        previous["end"] = entity["end"]
    else:
        # Copy so that merging never mutates the pipeline's own result dicts.
        output_comb.append(dict(entity))

df = pd.DataFrame.from_dict(output_comb)
91
  if aggregation != "none":
92
  cols_to_keep = ['word','entity_group','score','start','end']
93
  else:
 
103
  spacy_display["text"] = input_text
104
  spacy_display["title"] = None
105
 
106
+ for entity in output_comb:
107
  if aggregation != "none":
108
  spacy_display["ents"].append({"start": entity["start"], "end": entity["end"], "label": entity["entity_group"]})
109
  else: