jschwaller commited on
Commit
6598a12
·
verified ·
1 Parent(s): 9f1b256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import scipy as sp
6
  import torch
7
  import transformers
 
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
9
 
10
  import matplotlib.pyplot as plt
@@ -18,14 +19,17 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
18
  tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
19
  model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")
20
 
21
- pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
 
 
22
 
23
  explainer = shap.Explainer(pred)
24
 
25
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
26
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
27
 
28
- ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
 
29
 
30
  def adr_predict(x):
31
  encoded_input = tokenizer(x, return_tensors='pt')
@@ -38,13 +42,13 @@ def adr_predict(x):
38
 
39
  res = ner_pipe(x)
40
  entity_colors = {
41
- 'Severity': '#E63946',
42
- 'Sign_symptom': '#2A9D8F',
43
- 'Medication': '#457B9D',
44
- 'Age': '#F4A261',
45
- 'Sex': '#F4A261',
46
- 'Diagnostic_procedure': '#9C6644',
47
- 'Biological_structure': '#BDB2FF',
48
  }
49
 
50
  htext = ""
@@ -71,7 +75,7 @@ description1 = "This app takes text (up to a few sentences) and predicts to what
71
 
72
  css = """
73
  body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
74
- h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; }
75
  .textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
76
  .button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
77
  """
@@ -89,7 +93,15 @@ with gr.Blocks(css=css) as demo:
89
  with gr.Column(visible=True):
90
  local_plot = gr.HTML(label='Shap:')
91
  htext = gr.HTML(label="NER")
 
 
 
 
 
 
 
92
 
 
93
  legend = gr.HTML(value="<div style='margin-top: 20px;'><strong>Legend:</strong><br>" +
94
  "<mark style='background-color:#E63946;'>Severity</mark> " +
95
  "<mark style='background-color:#2A9D8F;'>Sign/Symptom</mark> " +
@@ -97,13 +109,6 @@ with gr.Blocks(css=css) as demo:
97
  "<mark style='background-color:#F4A261;'>Age/Sex</mark> " +
98
  "<mark style='background-color:#9C6644;'>Diagnostic Procedure</mark> " +
99
  "<mark style='background-color:#BDB2FF;'>Biological Structure</mark></div>")
100
- submit_btn.click(
101
- main,
102
- [prob1],
103
- [label, local_plot, htext],
104
- api_name="adr"
105
- )
106
-
107
  gr.Row([legend])
108
 
109
  with gr.Row():
@@ -111,5 +116,3 @@ with gr.Blocks(css=css) as demo:
111
  gr.Examples([["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
112
  ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]],
113
  [prob1], [label, local_plot, htext], main, cache_examples=True)
114
-
115
- demo.launch()
 
5
  import scipy as sp
6
  import torch
7
  import transformers
8
+ from transformers import pipeline
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
10
 
11
  import matplotlib.pyplot as plt
 
19
  tokenizer = AutoTokenizer.from_pretrained("jschwaller/ADRv2024")
20
  model = AutoModelForSequenceClassification.from_pretrained("jschwaller/ADRv2024")
21
 
22
+ # Build a pipeline object for predictions
23
+ pred = transformers.pipeline("text-classification", model=model,
24
+ tokenizer=tokenizer, return_all_scores=True)
25
 
26
  explainer = shap.Explainer(pred)
27
 
28
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
29
  ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
30
 
31
+ ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple") # pass device=0 if using gpu
32
+ #
33
 
34
  def adr_predict(x):
35
  encoded_input = tokenizer(x, return_tensors='pt')
 
42
 
43
  res = ner_pipe(x)
44
  entity_colors = {
45
+ 'Severity': '#E63946', # a vivid red
46
+ 'Sign_symptom': '#2A9D8F', # a deep teal
47
+ 'Medication': '#457B9D', # a dusky blue
48
+ 'Age': '#F4A261', # a sandy orange
49
+ 'Sex': '#F4A261', # same sandy orange for consistency with 'Age'
50
+ 'Diagnostic_procedure': '#9C6644', # a brown
51
+ 'Biological_structure': '#BDB2FF', # a light pastel purple
52
  }
53
 
54
  htext = ""
 
75
 
76
  css = """
77
  body { font-family: 'Roboto', sans-serif; background-color: #333; color: #87CEEB; }
78
+ h1, h2, h3, h4, h5, h6, p, label, .markdown { color: #87CEEB; } /* Ensuring that all text elements are consistently light blue */
79
  .textbox { width: 100%; border-radius: 10px; border: 1px solid #ccc; background-color: white; color: black; }
80
  .button { background-color: #FF6347; color: white; border: none; border-radius: 10px; padding: 10px 20px; cursor: pointer; }
81
  """
 
93
  with gr.Column(visible=True):
94
  local_plot = gr.HTML(label='Shap:')
95
  htext = gr.HTML(label="NER")
96
+
97
+ submit_btn.click(
98
+ main,
99
+ [prob1],
100
+ [label, local_plot, htext],
101
+ api_name="adr"
102
+ )
103
 
104
+ gr.Markdown("### Legend")
105
  legend = gr.HTML(value="<div style='margin-top: 20px;'><strong>Legend:</strong><br>" +
106
  "<mark style='background-color:#E63946;'>Severity</mark> " +
107
  "<mark style='background-color:#2A9D8F;'>Sign/Symptom</mark> " +
 
109
  "<mark style='background-color:#F4A261;'>Age/Sex</mark> " +
110
  "<mark style='background-color:#9C6644;'>Diagnostic Procedure</mark> " +
111
  "<mark style='background-color:#BDB2FF;'>Biological Structure</mark></div>")
 
 
 
 
 
 
 
112
  gr.Row([legend])
113
 
114
  with gr.Row():
 
116
  gr.Examples([["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
117
  ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]],
118
  [prob1], [label, local_plot, htext], main, cache_examples=True)