File size: 7,235 Bytes
9a1ab03
de53991
43d8095
9a1ab03
44d180e
43d8095
44d180e
f9f4138
 
 
9779cd8
 
9a1ab03
 
5e8ccd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5275a8d
57ace24
01a2ce5
43d8095
9a1ab03
 
 
 
 
 
 
 
 
 
 
 
 
43d8095
 
9a1ab03
 
26b45a8
1921336
 
 
 
 
 
 
 
 
 
 
 
 
26b45a8
 
1921336
26b45a8
f9f4138
 
5e8ccd5
26b45a8
a232b31
5e8ccd5
66846f0
44d180e
9a1ab03
5e8ccd5
de53991
68c64e4
dea4ce7
6c0544b
9a1ab03
de53991
 
 
43d8095
c5b9462
54c6336
de53991
 
 
5e8ccd5
b97cda3
1921336
de53991
5a4b599
143b62d
de53991
5e8ccd5
f253a0d
54c6336
 
29c23fe
c52847e
 
2ffcfd9
 
c52847e
2ffcfd9
 
 
5f75abd
de53991
 
5e8ccd5
de53991
 
 
 
 
43d8095
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from dotenv import load_dotenv
import gradio as gr
import random

from utils.model import Model
from utils.data import dataset

import gc
import torch

import logging

# Load variables from a local .env file (when one exists) into the process
# environment at import time, without overriding values that are already set.
load_dotenv(override=False)

# Stylesheet handed to gr.Blocks(css=...). Of these classes, only .parameter-text
# is referenced (via elem_classes) in the UI code visible in this module.
# Fixes: the missing ';' after `background-color: white !important` made the
# browser discard both that declaration and the following `color`, and
# `font-weight: maroon` was invalid (maroon is a colour, so it is set via `color`).
custom_css = """
gradio-app {
    background: #eeeefc !important;
}
.bordered-text {
    border-style: solid;
    border-width: 1px;
    padding: 5px;
    margin-bottom: 0px;
    border-radius: 1px;
    font-family: Verdana;
    font-size: 20px !important;
    font-weight: bold ;
    color:#000000;
}
.parameter-text {
    border-style: solid;
    border-width: 1px;
    padding: 5px;
    margin-bottom: 0px;
    border-radius: 1px;
    font-family: Verdana;
    font-size: 10px !important;
    font-weight: bold ;
    color:#000000;
}
.title {
    font-size: 35px;
    color: maroon;
    font-family: Helvetica;
}
input-label {
    font-size: 20px;
    font-weight: bold;
    font-family: Papyrus;
}
.custom-button {
    background-color: white !important; /* White background */
    color: black; /* Black text */
    border: none; /* Remove border */
    padding: 10px 20px; /* Add padding */
    text-align: center; /* Center text */
    display: inline-block; /* Inline block */
    font-size: 22px; /* Font size */
    margin: 4px 2px; /* Margin */
    cursor: pointer; /* Pointer cursor on hover */
    border-radius: 4px; /* Rounded corners */
}
.custom-button:hover {
    background-color: black;
    color: white;
}
"""

# Name of the model most recently loaded by model_device_check ('' = none yet).
__model_on_gpu__ = ''
# One slot per known model name, all starting empty (None) until loaded on demand.
model = dict.fromkeys(Model.__model_list__)

# Dropdown entry that asks update_input for a random datapoint from `dataset`.
random_label = '🔀 Random dialogue from dataset'

# Canned sample inputs for the example dropdown.
# NOTE(review): these are basketball articles, not clinical dialogues as the
# page heading suggests — confirm they are intended as placeholders.
_EXAMPLE_1 = """Boston's injury reporting for Kristaps Porziņģis has been fairly coy. He missed Game 3, but his coach told reporters just before Game 4 that was technically available, but with a catch.
Joe Mazzulla said Porziņģis would "only be used in specific instances, if necessary." That sounds like the team doesn't want to risk further injury to his dislocated Posterior Tibialis (or some other body part, due to overcompensation for the ankle), unless it's in a desperate situation.
Being up 3-1, with Game 5 at home, doesn't qualify as desperate. So, expect the Celtics to continue slow-playing KP's return.
It'd obviously be nice for Boston to have his rim protection and jump shooting back. It was missed in the Game 4 blowout, but the Celtics have also demonstrated they can win without the big man throughout this campaign.
On top of winning Game 3 of this series, Boston is plus-10.9 points per 100 possessions when Porziņģis has been off the floor this regular and postseason."""

_EXAMPLE_2 = """Prior to the Finals, we predicted that Dereck Lively II's minutes would swell over the course of the series, and that's starting to play out.
He averaged 18.8 minutes in Games 1 and 2 and was up to 26.2 in Games 3 and 4. That's with the regulars being pulled long before the final buzzer in Friday's game, too.
Expect the rookie's playing time to continue to climb in Game 5. It seems increasingly clear that coach Jason Kidd trusts him over the rest of Dallas' bigs, and it's not hard to see why.
Lively has been absolutely relentless on the offensive glass all postseason. He makes solid decisions as a passer when his rolls don't immediately lead to dunks. And he's not a liability when caught defending guards or wings outside.
All of that has led to postseason averages of 8.2 points, 7.6 rebounds, 1.4 assists and 1.0 blocks in just 21.9 minutes, as well as a double-double in 22 minutes of Game 4.
Back in Boston, Kidd is going to rely on Lively even more. He'll play close to 30 minutes and reach double-figures in both scoring and rebounding again."""

# Dropdown label -> input text. The random entry maps to "" because its text
# is generated on demand rather than stored here.
examples = {
    "example 1": _EXAMPLE_1,
    "example 2": _EXAMPLE_2,
    random_label: "",
}

def model_device_check(model_name):
    """Make ``model_name`` the single loaded model, releasing any other first.

    The module-level ``__model_on_gpu__`` records which model is currently
    loaded. When a different name is requested, the previous model's slot in
    ``model`` is reset to ``None``, Python garbage is collected, and PyTorch's
    CUDA caching allocator is asked to release unused cached memory before the
    new ``Model`` is constructed.

    Args:
        model_name: a key of ``model`` (an entry of ``Model.__model_list__``).
    """
    global __model_on_gpu__

    if __model_on_gpu__ == model_name:
        return

    if __model_on_gpu__:
        logging.info("delete model %s", __model_on_gpu__)
        # Reset the slot rather than deleting the key, so `model` keeps one
        # entry per known model name just as it does at start-up.
        model[__model_on_gpu__] = None
        # Clear the tracker before loading: if Model(...) below raises, the
        # next call must not try to release a model that is already gone.
        __model_on_gpu__ = ''
        gc.collect()
        torch.cuda.empty_cache()

    model[model_name] = Model(model_name)
    __model_on_gpu__ = model_name


def get_model_batch_generation(model_name):
    """Return the ``Model`` instance for ``model_name``, loading it on demand.

    Delegates loading/swapping to ``model_device_check`` and then reads the
    instance from the module-level ``model`` cache.
    """
    model_device_check(model_name)
    loaded = model[model_name]
    return loaded


def generate_answer(sources, model_name, prompt, temperature, max_new_tokens, do_sample):
    """Summarize ``sources`` with ``model_name`` using the given prompt.

    The model input is the prompt, a newline, the source text wrapped in
    braces, then a blank line and a ``summary:`` cue. The generation settings
    are forwarded positionally to ``Model.gen``.

    Returns:
        The first element of ``gen``'s result with surrounding whitespace
        stripped (presumably the generated text — confirm against ``Model.gen``).
    """
    model_device_check(model_name)
    content = ''.join((prompt, '\n{', sources, '}\n\nsummary:'))
    generations = model[model_name].gen(content, temperature, max_new_tokens, do_sample)
    first = generations[0]
    return first.strip()

def process_input(input_text, model_selection, prompt, temperature, max_new_tokens, do_sample):
    if input_text:
        logging.info("Start generation")
        response = generate_answer(input_text, model_selection, prompt)
        return f"## Original Dialogue:\n\n{input_text}\n\n## Summarization:\n\n{response}"
    else:
        return "Please fill the input to generate outputs."

def update_input(example):
    """Return the input-box text for the selected example dropdown entry.

    For ``random_label``, a datapoint is drawn at random from ``dataset`` and
    rendered as its ``section_text``, a blank line, ``Dialogue:`` and its
    ``dialogue``. Any other label is looked up in ``examples``.
    """
    if example != random_label:
        return examples[example]
    sample = random.choice(dataset)
    return ''.join((sample['section_text'], '\n\nDialogue:\n', sample['dialogue']))

def create_summarization_interface():
    """Build the Gradio Blocks UI for the summarization prompt playground.

    Layout: example and model dropdowns, a prompt-template box, an input box
    pre-filled with a random dataset entry, a submit button, generation
    parameter controls and a Markdown output pane. Changing the example
    dropdown refills the input via ``update_input``; the submit button runs
    ``process_input`` and writes its result to the output pane.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm", text_size="sm"), css=custom_css) as demo:
        gr.Markdown("## This is a playground to test prompts for clinical dialogue summarizations")

        with gr.Row():
            example_dropdown = gr.Dropdown(choices=list(examples.keys()), label="Choose an example", value=random_label)
            model_dropdown = gr.Dropdown(choices=Model.__model_list__, label="Choose a model", value=Model.__model_list__[0])

        gr.Markdown("<div style='border: 4px solid white; padding: 3px; border-radius: 5px;width:100px;padding-top: 0.5px;padding-bottom: 10px;'><h3>Prompt 👥</h3></div>")
        template_text = gr.Textbox(value="""Summarize the following dialogue""", label='Input Prompting Template', lines=8, placeholder='Input your prompts')
        # The example dropdown defaults to random_label, so start with a random
        # dataset entry; update_input produces exactly that text.
        input_text = gr.Textbox(label="Input Dialogue", lines=10, placeholder="Enter text here...", value=update_input(random_label))
        submit_button = gr.Button("✨ Submit ✨")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("<div style='border: 4px solid white; padding: 2px; border-radius: 5px;width:130px;padding-bottom: 10px;'><b><h3>Parameters 📈</h3></b></div>")
                with gr.Column():
                    # Textbox values reach the handler as strings; process_input must convert them.
                    temperature = gr.Textbox(label="Temperature", elem_classes="parameter-text", value=0.0)
                    max_new_tokens = gr.Textbox(label="Max New Tokens", elem_classes="parameter-text", value=500)
                    do_sample = gr.Dropdown([True, False], label="Do Sample", elem_classes="parameter-text", value=True)
            output = gr.Markdown(line_breaks=True)

        example_dropdown.change(update_input, inputs=[example_dropdown], outputs=[input_text])
        submit_button.click(
            process_input,
            inputs=[input_text, model_dropdown, template_text, temperature, max_new_tokens, do_sample],
            outputs=[output],
        )

    return demo

if __name__ == "__main__":
    # The root logger defaults to WARNING with no handler configured, so the
    # logging.info(...) calls in this module would otherwise never be shown.
    logging.basicConfig(level=logging.INFO)
    demo = create_summarization_interface()
    demo.launch()