jschwaller commited on
Commit
9f1b256
·
verified ·
1 Parent(s): 53134ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -26
app.py CHANGED
@@ -5,7 +5,6 @@ import numpy as np
5
  import scipy as sp
6
  import torch
7
  import transformers
8
- from transformers import pipeline
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
10
 
11
  import matplotlib.pyplot as plt
@@ -19,45 +18,33 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
19
  tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
20
  model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")
21
 
22
- # Build a pipeline object for predictions
23
- pred = transformers.pipeline("text-classification", model=model,
24
- tokenizer=tokenizer, return_all_scores=True)
25
 
26
  explainer = shap.Explainer(pred)
27
 
28
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
29
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
30
 
31
- ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
32
-
33
- # fixing the colors
34
- # Create a custom color map
35
- cmap = {'0': '#457B9D', # Non-severe reactions in blue
36
- '1': '#E63946'} # Severe reactions in red
37
-
38
 
39
  def adr_predict(x):
40
  encoded_input = tokenizer(x, return_tensors='pt')
41
  output = model(**encoded_input)
42
  scores = output[0][0].detach()
43
  scores = torch.nn.functional.softmax(scores)
44
-
45
- # Generate SHAP values and use the custom color map
46
  shap_values = explainer([str(x).lower()])
47
-
48
- # Ensure the color depends on the output class; customize as needed
49
- base_colors = {label: cmap[str(label)] for label in range(len(scores))}
50
- shap.plots.text(shap_values[0], color=base_colors, display=False)
51
 
52
  res = ner_pipe(x)
53
  entity_colors = {
54
- 'Severity': '#E63946', # a vivid red
55
- 'Sign_symptom': '#2A9D8F', # a deep teal
56
- 'Medication': '#457B9D', # a dusky blue
57
- 'Age': '#F4A261', # a sandy orange
58
- 'Sex': '#F4A261', # same sandy orange for consistency with 'Age'
59
- 'Diagnostic_procedure': '#9C6644', # a brown
60
- 'Biological_structure': '#BDB2FF', # a light pastel purple
61
  }
62
 
63
  htext = ""
@@ -84,7 +71,7 @@ description1 = "This app takes text (up to a few sentences) and predicts to what
84
 
85
  css = """
86
  body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
87
- h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; } /* Ensuring that all text elements are consistently light blue */
88
  .textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
89
  .button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
90
  """
@@ -102,7 +89,14 @@ with gr.Blocks(css=css) as demo:
102
  with gr.Column(visible=True):
103
  local_plot = gr.HTML(label='Shap:')
104
  htext = gr.HTML(label="NER")
105
-
 
 
 
 
 
 
 
106
  submit_btn.click(
107
  main,
108
  [prob1],
@@ -110,6 +104,8 @@ with gr.Blocks(css=css) as demo:
110
  api_name="adr"
111
  )
112
 
 
 
113
  with gr.Row():
114
  gr.Markdown("### Click on any of the examples below to see how it works:")
115
  gr.Examples([["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
 
5
  import scipy as sp
6
  import torch
7
  import transformers
 
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
9
 
10
  import matplotlib.pyplot as plt
 
18
  tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
19
  model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")
20
 
21
+ pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
 
 
22
 
23
  explainer = shap.Explainer(pred)
24
 
25
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
26
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
27
 
28
+ ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
 
 
 
 
 
 
29
 
30
  def adr_predict(x):
31
  encoded_input = tokenizer(x, return_tensors='pt')
32
  output = model(**encoded_input)
33
  scores = output[0][0].detach()
34
  scores = torch.nn.functional.softmax(scores)
35
+
 
36
  shap_values = explainer([str(x).lower()])
37
+ local_plot = shap.plots.text(shap_values[0], display=False)
 
 
 
38
 
39
  res = ner_pipe(x)
40
  entity_colors = {
41
+ 'Severity': '#E63946',
42
+ 'Sign_symptom': '#2A9D8F',
43
+ 'Medication': '#457B9D',
44
+ 'Age': '#F4A261',
45
+ 'Sex': '#F4A261',
46
+ 'Diagnostic_procedure': '#9C6644',
47
+ 'Biological_structure': '#BDB2FF',
48
  }
49
 
50
  htext = ""
 
71
 
72
  css = """
73
  body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
74
+ h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; }
75
  .textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
76
  .button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
77
  """
 
89
  with gr.Column(visible=True):
90
  local_plot = gr.HTML(label='Shap:')
91
  htext = gr.HTML(label="NER")
92
+
93
+ legend = gr.HTML(value="<div style='margin-top: 20px;'><strong>Legend:</strong><br>" +
94
+ "<mark style='background-color:#E63946;'>Severity</mark> " +
95
+ "<mark style='background-color:#2A9D8F;'>Sign/Symptom</mark> " +
96
+ "<mark style='background-color:#457B9D;'>Medication</mark> " +
97
+ "<mark style='background-color:#F4A261;'>Age/Sex</mark> " +
98
+ "<mark style='background-color:#9C6644;'>Diagnostic Procedure</mark> " +
99
+ "<mark style='background-color:#BDB2FF;'>Biological Structure</mark></div>")
100
  submit_btn.click(
101
  main,
102
  [prob1],
 
104
  api_name="adr"
105
  )
106
 
107
+ gr.Row([legend])
108
+
109
  with gr.Row():
110
  gr.Markdown("### Click on any of the examples below to see how it works:")
111
  gr.Examples([["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],