# Medieval TrOCR Model Switcher — HuggingFace Spaces app (Gradio + TrOCR)
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import spaces
# Dictionary of model names and their corresponding HuggingFace model IDs.
# NOTE: insertion order matters — the dropdown choices below are built from
# list(MODEL_OPTIONS.keys()), so reordering entries reorders the UI.
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}

# Module-level cache holding the most recently loaded model/processor pair;
# load_model() swaps these out only when a different model name is requested.
current_model = None
current_processor = None
current_model_name = None
def load_model(model_name):
    """Load (and cache) the TrOCR processor/model pair for ``model_name``.

    The pair is kept in module-level globals so that repeated calls with the
    same name — e.g. successive transcriptions with one model — skip the
    expensive ``from_pretrained`` download/initialization.

    Args:
        model_name: A key of ``MODEL_OPTIONS``.

    Returns:
        ``(processor, model)`` tuple for the requested checkpoint; the model
        is placed on GPU when CUDA is available, otherwise CPU.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    global current_model, current_processor, current_model_name
    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(model_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_model_name = model_name
        # Fall back to CPU so the app still runs on machines without CUDA
        # (the original hard-coded 'cuda' and crashed on CPU-only hosts).
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        current_model = current_model.to(device)
    return current_processor, current_model
@spaces.GPU
def process_image(image, model_name):
    """Transcribe one line of text from ``image`` with the selected model.

    Args:
        image: PIL image of a single text line (the UI passes type="pil").
        model_name: Key of ``MODEL_OPTIONS`` selecting the TrOCR checkpoint.

    Returns:
        The decoded transcription as a string.
    """
    processor, model = load_model(model_name)
    # Preprocess the image into pixel tensors.
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Place the input on whatever device the model lives on — the original
    # hard-coded 'cuda', which breaks on CPU-only machines.
    device = next(model.parameters()).device
    pixel_values = pixel_values.to(device)
    # Greedy decoding (no beam search); no_grad avoids autograd overhead.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    # Decode token ids back to text, dropping special tokens.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
# Base URL where the example line images are hosted.
base_url = "https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/"

# (image file name, matching model) pairs; expanded into full example rows below.
_EXAMPLE_SPECS = [
    ("caroline-1.png", "Medieval Latin Caroline"),
    ("caroline-2.png", "Medieval Latin Caroline"),
    ("cursiva-1.png", "Medieval Cursiva"),
    ("cursiva-2.png", "Medieval Cursiva"),
    ("cursiva-3.png", "Medieval Cursiva"),
    ("humanistica-1.png", "Medieval Humanistica"),
    ("humanistica-2.png", "Medieval Humanistica"),
    ("humanistica-3.png", "Medieval Humanistica"),
    ("hybrida-1.png", "Medieval Castilian Hybrida"),
    ("hybrida-2.png", "Medieval Castilian Hybrida"),
    ("hybrida-3.png", "Medieval Castilian Hybrida"),
    ("praegothica-1.png", "Medieval Praegothica"),
    ("praegothica-2.png", "Medieval Praegothica"),
    ("praegothica-3.png", "Medieval Praegothica"),
    ("print-1.png", "Medieval Print"),
    ("print-2.png", "Medieval Print"),
    ("print-3.png", "Medieval Print"),
    ("semihybrida-1.png", "Medieval Semihybrida"),
    ("semihybrida-2.png", "Medieval Semihybrida"),
    ("semihybrida-3.png", "Medieval Semihybrida"),
    ("semitextualis-1.png", "Medieval Semitextualis"),
    ("semitextualis-2.png", "Medieval Semitextualis"),
    ("semitextualis-3.png", "Medieval Semitextualis"),
    ("textualis-1.png", "Medieval Textualis"),
    ("textualis-2.png", "Medieval Textualis"),
    ("textualis-3.png", "Medieval Textualis"),
]

# Each Gradio example row is [full image URL, model-dropdown value].
examples = [[base_url + fname, model] for fname, model in _EXAMPLE_SPECS]
# Custom CSS to make the image wider.
# Targets the component given elem_id="image_upload" (the gr.Image below) so
# the uploaded image and its container stretch to the full column width.
custom_css = """
#image_upload {
max-width: 100% !important;
width: 100% !important;
height: auto !important;
}
#image_upload > div:first-child {
width: 100% !important;
}
#image_upload img {
max-width: 100% !important;
width: 100% !important;
height: auto !important;
}
"""
# Gradio interface: image upload + model selector -> transcription textbox.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# Medieval TrOCR Model Switcher")
    gr.Markdown("Upload an image of medieval text and select a model to transcribe it. Note: This tool is designed to work on a single line of text at a time for optimal results.")
    with gr.Row():
        # elem_id ties this component to the #image_upload rules in custom_css.
        input_image = gr.Image(type="pil", label="Input Image", elem_id="image_upload")
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
    transcription_output = gr.Textbox(label="Transcription")
    submit_button = gr.Button("Transcribe")
    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=transcription_output)
    # Clicking an example row fills both inputs and runs the transcription.
    gr.Examples(examples, inputs=[input_image, model_dropdown], outputs=transcription_output)

# Fixed: the original line ended with a stray ' |' scrape artifact that made
# this statement invalid Python.
iface.launch()