File size: 5,077 Bytes
4e8500c
 
0c1fe20
dabac75
4e8500c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c1fe20
 
 
 
4e8500c
 
0c1fe20
 
 
 
 
 
 
 
dabac75
 
0c1fe20
 
4e8500c
dabac75
4e8500c
 
dabac75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a631f2b
0c1fe20
dabac75
 
 
b4fde32
 
dabac75
a631f2b
b4fde32
dabac75
b4fde32
 
dabac75
 
 
4e8500c
0c1fe20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import spaces

# Display label -> Hugging Face Hub repository ID, in dropdown order.
# Every medieval checkpoint lives at "medieval-data/trocr-medieval-<slug>",
# where <slug> is the label suffix lower-cased with spaces turned into hyphens
# (e.g. "Latin Caroline" -> "latin-caroline").
_MEDIEVAL_SCRIPTS = (
    "Base",
    "Latin Caroline",
    "Castilian Hybrida",
    "Humanistica",
    "Textualis",
    "Cursiva",
    "Semitextualis",
    "Praegothica",
    "Semihybrida",
    "Print",
)

MODEL_OPTIONS = {"Microsoft Handwritten": "microsoft/trocr-base-handwritten"}
MODEL_OPTIONS.update(
    (
        f"Medieval {script}",
        "medieval-data/trocr-medieval-" + script.lower().replace(" ", "-"),
    )
    for script in _MEDIEVAL_SCRIPTS
)

# Single-entry cache of the most recently loaded model, filled by load_model().
current_model_name = current_processor = current_model = None

def load_model(model_name):
    """Return the ``(processor, model)`` pair for ``model_name``, loading it if needed.

    A single model is cached in the module-level globals ``current_processor``,
    ``current_model`` and ``current_model_name``. Asking for the cached name
    returns the cached pair unchanged. Asking for a different name loads that
    checkpoint (repo ID from ``MODEL_OPTIONS``), moves it to CUDA and replaces
    the cache. The cache is only updated once every step has succeeded.

    Args:
        model_name: A key of ``MODEL_OPTIONS``.

    Returns:
        Tuple ``(processor, model)``; the model has been moved to ``'cuda'``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_OPTIONS``.
    """
    global current_model, current_processor, current_model_name

    if model_name == current_model_name:
        return current_processor, current_model

    model_id = MODEL_OPTIONS[model_name]

    # Load into locals first. If either download fails, the previously cached
    # processor/model/name stay untouched and still belong together.
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)

    # Drop our reference to the old model before moving the new weights to
    # the GPU, so we do not hold both at once. Clearing the name too means a
    # failing .to('cuda') (e.g. out of memory) forces a clean reload on the
    # next call, instead of reusing a model that never reached the GPU.
    current_model = current_processor = current_model_name = None
    model = model.to('cuda')

    current_processor, current_model, current_model_name = processor, model, model_name
    return current_processor, current_model

@spaces.GPU
def process_image(image, model_name):
    """Transcribe the text in ``image`` with the TrOCR model named ``model_name``.

    The UI advises using images that contain a single line of text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user clicks "Transcribe" without uploading anything.
        model_name: A key of ``MODEL_OPTIONS``.

    Returns:
        The decoded transcription of the first (only) generated sequence.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message to the user.
    """
    if image is None:
        # Without this guard the processor fails on None with an opaque traceback.
        raise gr.Error("Please upload an image to transcribe.")

    processor, model = load_model(model_name)

    # The TrOCR docs always feed RGB images. Uploaded PNGs may be grayscale,
    # palette or RGBA, so coerce to 3-channel RGB before preprocessing.
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values

    # Put the inputs on whatever device the model lives on.
    pixel_values = pixel_values.to(model.device)

    # Inference only: skip autograd bookkeeping. Decoding uses the model's own
    # generation defaults (no decoding options are passed here).
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Folder of sample line images hosted in the trocr-medieval-base model repo.
base_url = "https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/"

# (file stem, number of numbered samples, model label to pre-select).
# Each entry expands to "<stem>-1.png" ... "<stem>-<count>.png" under base_url.
_EXAMPLE_SETS = (
    ("caroline", 2, "Medieval Latin Caroline"),
    ("cursiva", 3, "Medieval Cursiva"),
    ("humanistica", 3, "Medieval Humanistica"),
    ("hybrida", 3, "Medieval Castilian Hybrida"),
    ("praegothica", 3, "Medieval Praegothica"),
    ("print", 3, "Medieval Print"),
    ("semihybrida", 3, "Medieval Semihybrida"),
    ("semitextualis", 3, "Medieval Semitextualis"),
    ("textualis", 3, "Medieval Textualis"),
)

# Example rows of [image URL, model label], one row per sample image.
examples = [
    [f"{base_url}{stem}-{number}.png", label]
    for stem, count, label in _EXAMPLE_SETS
    for number in range(1, count + 1)
]

# Custom CSS to make the image wider: stretch the input image component to
# the full width of its container. Every selector targets the component
# created with elem_id="image_upload": the component itself, its first child
# <div>, and the <img> inside it. The !important flags are presumably there
# to override Gradio's default component sizing.
custom_css = """
#image_upload {
    max-width: 100% !important;
    width: 100% !important;
    height: auto !important;
}
#image_upload > div:first-child {
    width: 100% !important;
}
#image_upload img {
    max-width: 100% !important;
    width: 100% !important;
    height: auto !important;
}
"""

# Gradio UI: the image upload and model picker sit side by side, followed by
# the transcription box, a "Transcribe" button wired to process_image, and
# clickable example rows. The page is launched at import time.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# Medieval TrOCR Model Switcher")
    gr.Markdown("Upload an image of medieval text and select a model to transcribe it. Note: This tool is designed to work on a single line of text at a time for optimal results.")

    with gr.Row():
        input_image = gr.Image(
            type="pil",
            label="Input Image",
            elem_id="image_upload",  # styled by custom_css
        )
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_OPTIONS),
            label="Select Model",
            value="Medieval Base",
        )

    transcription_output = gr.Textbox(label="Transcription")

    submit_button = gr.Button("Transcribe")
    # Run OCR on the current image with the currently selected model.
    submit_button.click(
        process_image,
        [input_image, model_dropdown],
        transcription_output,
    )

    # No fn is passed, so clicking an example only fills in the image and
    # model inputs; the user still presses "Transcribe" to run OCR.
    gr.Examples(
        examples,
        [input_image, model_dropdown],
        outputs=transcription_output,
    )

iface.launch()