wjbmattingly commited on
Commit
1561fc5
·
verified ·
1 Parent(s): 19bbcb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -76
app.py CHANGED
@@ -1,7 +1,11 @@
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  import torch
4
- import spaces
 
 
 
 
5
 
6
  # Dictionary of model names and their corresponding HuggingFace model IDs
7
  MODEL_OPTIONS = {
@@ -32,93 +36,78 @@ def load_model(model_name):
32
  current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
33
  current_model_name = model_name
34
 
35
- # Move model to GPU
36
- current_model = current_model.to('cuda')
 
37
 
38
  return current_processor, current_model
39
 
40
- @spaces.GPU
41
  def process_image(image, model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  processor, model = load_model(model_name)
43
-
44
- # Prepare image
45
- pixel_values = processor(image, return_tensors="pt").pixel_values
46
-
47
- # Move input to GPU
48
- pixel_values = pixel_values.to('cuda')
49
-
50
- # Generate (no beam search)
51
- with torch.no_grad():
52
- generated_ids = model.generate(pixel_values)
53
-
54
- # Decode
55
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
56
- return generated_text
57
-
58
- # Base URL for the images
59
- base_url = "https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/"
60
-
61
- # List of example images and their corresponding models
62
- examples = [
63
- [f"{base_url}caroline-1.png", "Medieval Latin Caroline"],
64
- [f"{base_url}caroline-2.png", "Medieval Latin Caroline"],
65
- [f"{base_url}cursiva-1.png", "Medieval Cursiva"],
66
- [f"{base_url}cursiva-2.png", "Medieval Cursiva"],
67
- [f"{base_url}cursiva-3.png", "Medieval Cursiva"],
68
- [f"{base_url}humanistica-1.png", "Medieval Humanistica"],
69
- [f"{base_url}humanistica-2.png", "Medieval Humanistica"],
70
- [f"{base_url}humanistica-3.png", "Medieval Humanistica"],
71
- [f"{base_url}hybrida-1.png", "Medieval Castilian Hybrida"],
72
- [f"{base_url}hybrida-2.png", "Medieval Castilian Hybrida"],
73
- [f"{base_url}hybrida-3.png", "Medieval Castilian Hybrida"],
74
- [f"{base_url}praegothica-1.png", "Medieval Praegothica"],
75
- [f"{base_url}praegothica-2.png", "Medieval Praegothica"],
76
- [f"{base_url}praegothica-3.png", "Medieval Praegothica"],
77
- [f"{base_url}print-1.png", "Medieval Print"],
78
- [f"{base_url}print-2.png", "Medieval Print"],
79
- [f"{base_url}print-3.png", "Medieval Print"],
80
- [f"{base_url}semihybrida-1.png", "Medieval Semihybrida"],
81
- [f"{base_url}semihybrida-2.png", "Medieval Semihybrida"],
82
- [f"{base_url}semihybrida-3.png", "Medieval Semihybrida"],
83
- [f"{base_url}semitextualis-1.png", "Medieval Semitextualis"],
84
- [f"{base_url}semitextualis-2.png", "Medieval Semitextualis"],
85
- [f"{base_url}semitextualis-3.png", "Medieval Semitextualis"],
86
- [f"{base_url}textualis-1.png", "Medieval Textualis"],
87
- [f"{base_url}textualis-2.png", "Medieval Textualis"],
88
- [f"{base_url}textualis-3.png", "Medieval Textualis"],
89
- ]
90
-
91
- # Custom CSS to make the image wider
92
- custom_css = """
93
- #image_upload {
94
- max-width: 100% !important;
95
- width: 100% !important;
96
- height: auto !important;
97
- }
98
- #image_upload > div:first-child {
99
- width: 100% !important;
100
- }
101
- #image_upload img {
102
- max-width: 100% !important;
103
- width: 100% !important;
104
- height: auto !important;
105
- }
106
- """
107
 
108
  # Gradio interface
109
- with gr.Blocks(css=custom_css) as iface:
110
- gr.Markdown("# Medieval TrOCR Model Switcher")
111
- gr.Markdown("Upload an image of medieval text and select a model to transcribe it. Note: This tool is designed to work on a single line of text at a time for optimal results.")
112
 
113
  with gr.Row():
114
- input_image = gr.Image(type="pil", label="Input Image", elem_id="image_upload")
115
  model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
116
 
117
- transcription_output = gr.Textbox(label="Transcription")
 
 
118
 
119
  submit_button = gr.Button("Transcribe")
120
- submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=transcription_output)
121
-
122
- gr.Examples(examples, inputs=[input_image, model_dropdown], outputs=transcription_output)
123
 
124
  iface.launch()
 
1
  import gradio as gr
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  import torch
4
+ import subprocess
5
+ import json
6
+ from PIL import Image, ImageDraw
7
+ import os
8
+ import tempfile
9
 
10
  # Dictionary of model names and their corresponding HuggingFace model IDs
11
  MODEL_OPTIONS = {
 
36
  current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
37
  current_model_name = model_name
38
 
39
+ # Move model to GPU if available
40
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ current_model = current_model.to(device)
42
 
43
  return current_processor, current_model
44
 
 
45
  def process_image(image, model_name):
46
+ # Save the uploaded image to a temporary file
47
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
48
+ image.save(temp_img, format="JPEG")
49
+ temp_img_path = temp_img.name
50
+
51
+ # Run Kraken for line detection
52
+ lines_json_path = "lines.json"
53
+ kraken_command = f"kraken -i {temp_img_path} {lines_json_path} binarize segment -bl"
54
+ subprocess.run(kraken_command, shell=True, check=True)
55
+
56
+ # Load the lines from the JSON file
57
+ with open(lines_json_path, 'r') as f:
58
+ lines_data = json.load(f)
59
+
60
  processor, model = load_model(model_name)
61
+
62
+ # Process each line
63
+ transcriptions = []
64
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
+ for line in lines_data['lines']:
66
+ # Extract line coordinates
67
+ x1, y1 = line['baseline'][0]
68
+ x2, y2 = line['baseline'][-1]
69
+
70
+ # Crop the line from the original image
71
+ line_image = image.crop((x1, y1, x2, y2))
72
+
73
+ # Prepare image for TrOCR
74
+ pixel_values = processor(line_image, return_tensors="pt").pixel_values
75
+ pixel_values = pixel_values.to(device)
76
+
77
+ # Generate (no beam search)
78
+ with torch.no_grad():
79
+ generated_ids = model.generate(pixel_values)
80
+
81
+ # Decode
82
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
83
+ transcriptions.append(generated_text)
84
+
85
+ # Clean up temporary files
86
+ os.unlink(temp_img_path)
87
+ os.unlink(lines_json_path)
88
+
89
+ # Create an image with bounding boxes
90
+ draw = ImageDraw.Draw(image)
91
+ for line in lines_data['lines']:
92
+ coords = line['baseline']
93
+ draw.line(coords, fill="red", width=2)
94
+
95
+ return image, "\n".join(transcriptions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # Gradio interface
98
+ with gr.Blocks() as iface:
99
+ gr.Markdown("# Medieval Document Transcription")
100
+ gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")
101
 
102
  with gr.Row():
103
+ input_image = gr.Image(type="pil", label="Input Image")
104
  model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
105
 
106
+ with gr.Row():
107
+ output_image = gr.Image(type="pil", label="Detected Lines")
108
+ transcription_output = gr.Textbox(label="Transcription", lines=10)
109
 
110
  submit_button = gr.Button("Transcribe")
111
+ submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])
 
 
112
 
113
  iface.launch()