Addaci committed
Commit 9519c42
1 Parent(s): 325d895

Total rewrite of app.py (Moved to Inference Client approach)
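In short, the app no longer loads the mT5 checkpoint locally with transformers; it now sends requests to the Hugging Face serverless Inference API via huggingface_hub. A minimal sketch of the pattern the new app.py adopts (assuming huggingface_hub is installed and the model is reachable through the serverless API; pass token=... to InferenceClient if authentication is required):

# Sketch only: the InferenceClient pattern this commit moves to.
from huggingface_hub import InferenceClient

client = InferenceClient(model="Addaci/mT5-small-experiment-13-checkpoint-2790")

# text_generation() returns the generated text as a plain str by default.
print(client.text_generation("correct this text: a smal tste input"))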

Files changed (1): app.py +23 −55
app.py CHANGED
@@ -1,78 +1,47 @@
-import os
+# Cell 1B: Inference Client
+
 import gradio as gr
+from huggingface_hub import InferenceClient
 import logging
 
 # Setup logging
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Load your fine-tuned mT5 model
-model_name = "Addaci/mT5-small-experiment-13-checkpoint-2790"
-tokenizer = MT5Tokenizer.from_pretrained(model_name)
-model = MT5ForConditionalGeneration.from_pretrained(model_name)
+# Initialize Inference Client
+client = InferenceClient(model="Addaci/mT5-small-experiment-13-checkpoint-2790")
 
 def correct_htr(raw_htr_text):
     try:
-        logging.info("Processing HTR correction...")
-        inputs = tokenizer("correct this text: " + raw_htr_text, return_tensors="pt", max_length=512, truncation=True)
-        logging.debug(f"Tokenized Inputs for HTR Correction: {inputs}")
-
-        # Generate with beam search and sampling
-        outputs = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, temperature=0.6, do_sample=True)
-        logging.debug(f"Generated Output (Tokens) for HTR Correction: {outputs}")
-
-        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        logging.debug(f"Decoded Output for HTR Correction: {corrected_text}")
-
-        # Re-tokenize the output for further inspection
-        logging.debug(f"Re-tokenized output for HTR Correction: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
-
-        return corrected_text
+        logging.info("Processing HTR correction with InferenceClient...")
+        # Sending the input to the hosted model
+        result = client.text_generation(f"correct this text: {raw_htr_text}")
+        logging.debug(f"Generated output for HTR correction: {result}")
+        return result['generated_text']  # Extracting the generated text from the response
     except Exception as e:
-        logging.error(f"Error in HTR Correction: {e}", exc_info=True)
+        logging.error(f"Error in HTR correction: {e}", exc_info=True)
         return str(e)
 
 def summarize_text(legal_text):
     try:
-        logging.info("Processing summarization...")
-        inputs = tokenizer("summarize the following legal text: " + legal_text, return_tensors="pt", max_length=512, truncation=True)
-        logging.debug(f"Tokenized Inputs for Summarization: {inputs}")
-
-        # Generate with beam search and sampling
-        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.8, do_sample=True)
-        logging.debug(f"Generated Summary (Tokens): {outputs}")
-
-        summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        logging.debug(f"Decoded Summary: {summary}")
-
-        # Re-tokenize the output for further inspection
-        logging.debug(f"Re-tokenized output for Summarization: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
-
-        return summary
+        logging.info("Processing summarization with InferenceClient...")
+        # Sending the input to the hosted model
+        result = client.text_generation(f"summarize the following legal text: {legal_text}")
+        logging.debug(f"Generated summary: {result}")
+        return result['generated_text']  # Extracting the generated text from the response
     except Exception as e:
-        logging.error(f"Error in Summarization: {e}", exc_info=True)
+        logging.error(f"Error in summarization: {e}", exc_info=True)
         return str(e)
 
 def answer_question(legal_text, question):
     try:
-        logging.info("Processing question-answering...")
+        logging.info("Processing question-answering with InferenceClient...")
+        # Sending the input to the hosted model
         formatted_input = f"Answer the following question based on the provided context:\n\nQuestion: {question}\n\nContext: {legal_text}"
-        inputs = tokenizer(formatted_input, return_tensors="pt", max_length=512, truncation=True)
-        logging.debug(f"Tokenized Inputs for Question Answering: {inputs}")
-
-        # Generate with beam search and sampling
-        outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True, temperature=0.7, do_sample=True)
-        logging.debug(f"Generated Answer (Tokens): {outputs}")
-
-        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        logging.debug(f"Decoded Answer: {answer}")
-
-        # Re-tokenize the output for further inspection
-        logging.debug(f"Re-tokenized output for Question Answering: {tokenizer.decode(outputs[0], skip_special_tokens=False)}")
-
-        return answer
+        result = client.text_generation(formatted_input)
+        logging.debug(f"Generated answer: {result}")
+        return result['generated_text']  # Extracting the generated text from the response
     except Exception as e:
-        logging.error(f"Error in Question Answering: {e}", exc_info=True)
+        logging.error(f"Error in question-answering: {e}", exc_info=True)
         return str(e)
 
 # Create the Gradio Blocks interface
@@ -80,7 +49,6 @@ with gr.Blocks() as demo:
     gr.Markdown("# mT5 Legal Assistant")
     gr.Markdown("Use this tool to correct raw HTR, summarize legal texts, or answer questions about legal cases.")
 
-    # Add the two clickable buttons with separate boxes and bold text
     with gr.Row():
         gr.HTML('''
             <div style="display: flex; gap: 10px;">
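One caveat in the new handlers: InferenceClient.text_generation() returns the generated text as a plain str by default (a structured response object is only returned when details=True is passed), so result['generated_text'] will raise a TypeError at runtime. A minimal corrected version of one handler, under that assumption:

def correct_htr(raw_htr_text):
    try:
        logging.info("Processing HTR correction with InferenceClient...")
        # text_generation() already returns a str, so pass it through unchanged
        result = client.text_generation(f"correct this text: {raw_htr_text}")
        logging.debug(f"Generated output for HTR correction: {result}")
        return result
    except Exception as e:
        logging.error(f"Error in HTR correction: {e}", exc_info=True)
        return str(e)

The same one-line change (return result instead of result['generated_text']) would apply to summarize_text and answer_question.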