import gradio as gr import io import os import yaml import pyarrow import tokenizers from retro_reader import RetroReader os.environ["TOKENIZERS_PARALLELISM"] = "true" def from_library(): from retro_reader import constants as C return C, RetroReader C, RetroReader = from_library() # Assuming RetroReader.load is a method from your imports def load_model(config_path): return RetroReader.load(config_file=config_path) # Loading models model_base = load_model("configs/inference_en_electra_base.yaml") model_large = load_model("configs/inference_en_electra_large.yaml") def retro_reader_demo(query, context, model_choice): model = model_base if model_choice == "Base" else model_large outputs = model(query=query, context=context, return_submodule_outputs=True) answer = outputs[0]["id-01"] if outputs[0]["id-01"] else "No answer found" return answer # Gradio app interface iface = gr.Interface( fn=retro_reader_demo, inputs=[ gr.Textbox(label="Query", placeholder="Type your query here..."), gr.Textbox(label="Context", placeholder="Provide the context here...", lines=10), gr.Radio(choices=["Base", "Large"], label="Model Choice") ], outputs=gr.Textbox(label="Answer"), title="Retrospective Reader Demo", description="This interface uses the RetroReader model to perform reading comprehension tasks." ) if __name__ == "__main__": iface.launch(share=True)