File size: 5,259 Bytes
333cd91
74f4d51
333cd91
74f4d51
 
 
0eb38a5
74f4d51
 
 
 
df63252
74f4d51
 
 
df63252
 
333cd91
74f4d51
 
df63252
 
74f4d51
 
 
df63252
 
333cd91
74f4d51
 
df63252
 
74f4d51
 
 
df63252
 
333cd91
74f4d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93d2618
74f4d51
 
df63252
 
74f4d51
 
 
 
 
 
 
 
 
 
 
1b938f5
df63252
74f4d51
 
df63252
 
74f4d51
 
 
 
 
 
 
77122ee
 
74f4d51
 
 
 
 
 
 
 
 
 
 
 
 
 
340628c
1b938f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from transformers import pipeline, T5Tokenizer, T5ForConditionalGeneration

# Load model and tokenizer for mT5-small once at import time; all three task
# functions below share this single checkpoint.
# NOTE(review): the UI header later in this file says "Flan-T5 Legal Assistant",
# but the checkpoint loaded here is google/mt5-small (not instruction-tuned) —
# confirm which model is actually intended.
model = T5ForConditionalGeneration.from_pretrained("google/mt5-small")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

# Define task-specific prompts
def correct_htr_text(input_text, max_new_tokens, temperature):
    """Correct obvious errors in a raw HTR transcription while preserving
    seventeenth-century spelling.

    Args:
        input_text: Raw handwritten-text-recognition output to correct.
        max_new_tokens: Upper bound on tokens generated beyond the prompt.
        temperature: Sampling temperature (>0); higher is more diverse.

    Returns:
        The model's corrected transcription as a plain string.
    """
    prompt = f"Correct the following handwritten transcription for obvious errors while preserving C17th spelling: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        # Forward the attention mask so padding is not attended to.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        # Without do_sample=True, transformers ignores `temperature` entirely
        # (greedy decoding), so the UI temperature slider had no effect.
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def summarize_legal_text(input_text, max_new_tokens, temperature):
    """Summarize a passage of legal text.

    Args:
        input_text: The legal text to summarize.
        max_new_tokens: Upper bound on tokens generated beyond the prompt.
        temperature: Sampling temperature (>0); higher is more diverse.

    Returns:
        The model-generated summary as a plain string.
    """
    prompt = f"Summarize this legal text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        # Forward the attention mask so padding is not attended to.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        # Without do_sample=True, transformers ignores `temperature` entirely
        # (greedy decoding), so the UI temperature slider had no effect.
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def answer_legal_question(input_text, question, max_new_tokens, temperature):
    """Answer a question grounded in the supplied legal text.

    Args:
        input_text: The legal text providing context for the answer.
        question: The question to answer from that text.
        max_new_tokens: Upper bound on tokens generated beyond the prompt.
        temperature: Sampling temperature (>0); higher is more diverse.

    Returns:
        The model-generated answer as a plain string.
    """
    prompt = f"Answer this question based on the legal text: '{question}' Text: {input_text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        # Forward the attention mask so padding is not attended to.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        # Without do_sample=True, transformers ignores `temperature` entirely
        # (greedy decoding), so the UI temperature slider had no effect.
        do_sample=True,
        temperature=temperature,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Define Gradio interface functions
def correct_htr_interface(text, max_new_tokens, temperature):
    """Gradio click adapter: delegate directly to correct_htr_text."""
    corrected = correct_htr_text(text, max_new_tokens, temperature)
    return corrected

def summarize_interface(text, max_new_tokens, temperature):
    """Gradio click adapter: delegate directly to summarize_legal_text."""
    summary = summarize_legal_text(text, max_new_tokens, temperature)
    return summary

def question_interface(text, question, max_new_tokens, temperature):
    """Gradio click adapter: delegate directly to answer_legal_question."""
    answer = answer_legal_question(text, question, max_new_tokens, temperature)
    return answer

def clear_all():
    """Return a pair of empty strings used to reset a tab's two textboxes."""
    blank = ""
    return blank, blank

# External clickable buttons
def clickable_buttons():
    """Return raw HTML rendering two external reference links (glossary and
    ground-truth transcript) as side-by-side styled anchor buttons."""
    return """
    <div style="display: flex; justify-content: space-between; margin-bottom: 10px;">
        <a href="http://www.marinelives.org/wiki/Tools:_Admiralty_court_legal_glossary" 
        style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
        Admiralty Court Legal Glossary</a>
        <a href="https://github.com/Addaci/HCA/blob/main/HCA_13_70_Full_Volume_Processed_Text_EDITED_Ver.1.2_18062024.txt" 
        style="border: 1px solid black; padding: 5px; text-align: center; width: 48%; background-color: #f0f0f0;">
        HCA 13/70 Ground Truth</a>
    </div>
    """

# Interface layout.
# Fixes versus the previous revision:
#  - the question-tab Clear wired clear_all (2 return values) to 3 outputs,
#    which raises when clicked; it now uses a dedicated 3-value lambda;
#  - the bottom "Clear" button declared no outputs for clear_all's return
#    values; it now clears every textbox in all three tabs;
#  - the header claimed "Flan-T5" while the code loads google/mt5-small;
#  - placeholder "Textbox" labels replaced with descriptive ones.
with gr.Blocks() as demo:
    gr.HTML("<h1>mT5 Legal Assistant</h1>")  # matches the google/mt5-small checkpoint loaded above
    gr.HTML(clickable_buttons())

    with gr.Tab("Correct Raw HTR"):
        input_text = gr.Textbox(lines=10, label="Raw HTR Text")
        output_text = gr.Textbox(label="Corrected Text")
        max_new_tokens = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        correct_button = gr.Button("Correct HTR")
        clear_button = gr.Button("Clear")

        correct_button.click(fn=correct_htr_interface,
                             inputs=[input_text, max_new_tokens, temperature],
                             outputs=output_text)
        clear_button.click(fn=clear_all, outputs=[input_text, output_text])

    with gr.Tab("Summarize Legal Text"):
        input_text_summarize = gr.Textbox(lines=10, label="Legal Text")
        output_text_summarize = gr.Textbox(label="Summary")
        max_new_tokens_summarize = gr.Slider(10, 512, value=256, label="Max New Tokens")
        temperature_summarize = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
        summarize_button = gr.Button("Summarize Text")
        clear_button_summarize = gr.Button("Clear")

        summarize_button.click(fn=summarize_interface,
                               inputs=[input_text_summarize, max_new_tokens_summarize, temperature_summarize],
                               outputs=output_text_summarize)
        clear_button_summarize.click(fn=clear_all, outputs=[input_text_summarize, output_text_summarize])

    with gr.Tab("Answer Legal Question"):
        input_text_question = gr.Textbox(lines=10, label="Legal Text")
        question = gr.Textbox(label="Question")
        output_text_question = gr.Textbox(label="Answer")
        max_new_tokens_question = gr.Slider(10, 512, value=128, label="Max New Tokens")
        temperature_question = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
        question_button = gr.Button("Get Answer")
        clear_button_question = gr.Button("Clear")

        question_button.click(fn=question_interface,
                              inputs=[input_text_question, question, max_new_tokens_question, temperature_question],
                              outputs=output_text_question)
        # clear_all only yields 2 values; this tab has 3 components to reset.
        clear_button_question.click(fn=lambda: ("", "", ""),
                                    outputs=[input_text_question, question, output_text_question])

    # Global clear: reset every textbox across all three tabs.
    all_boxes = [input_text, output_text,
                 input_text_summarize, output_text_summarize,
                 input_text_question, question, output_text_question]
    gr.Button("Clear", elem_id="clear_button").click(
        fn=lambda: ("",) * 7, outputs=all_boxes)

demo.launch()