jschwaller commited on
Commit
53134ec
·
verified ·
1 Parent(s): 9c9bf1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -29,16 +29,25 @@ ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
29
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
30
 
31
  ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
32
- #
 
 
 
 
 
33
 
34
  def adr_predict(x):
35
  encoded_input = tokenizer(x, return_tensors='pt')
36
  output = model(**encoded_input)
37
  scores = output[0][0].detach()
38
  scores = torch.nn.functional.softmax(scores)
39
-
 
40
  shap_values = explainer([str(x).lower()])
41
- local_plot = shap.plots.text(shap_values[0], display=False)
 
 
 
42
 
43
  res = ner_pipe(x)
44
  entity_colors = {
 
29
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
30
 
31
  ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
32
+
33
+ # fixing the colors
34
+ # Create a custom color map
35
+ cmap = {'0': '#457B9D', # Non-severe reactions in blue
36
+ '1': '#E63946'} # Severe reactions in red
37
+
38
 
39
  def adr_predict(x):
40
  encoded_input = tokenizer(x, return_tensors='pt')
41
  output = model(**encoded_input)
42
  scores = output[0][0].detach()
43
  scores = torch.nn.functional.softmax(scores)
44
+
45
+ # Generate SHAP values and use the custom color map
46
  shap_values = explainer([str(x).lower()])
47
+
48
+ # Ensure the color depends on the output class; customize as needed
49
+ base_colors = {label: cmap[str(label)] for label in range(len(scores))}
50
+ shap.plots.text(shap_values[0], color=base_colors, display=False)
51
 
52
  res = ner_pipe(x)
53
  entity_colors = {