|
import os |
|
from io import BytesIO |
|
|
|
import gradio as gr |
|
import grpc |
|
from PIL import Image |
|
|
|
from inference_pb2 import OmniRequest, OmniResponse |
|
from inference_pb2_grpc import OmniServiceStub |
|
|
|
|
|
|
|
def get_bytes(img):
    """Encode a PIL image as JPEG bytes for sending over gRPC.

    Args:
        img: A PIL ``Image`` instance, or ``None``.

    Returns:
        The JPEG-encoded image as ``bytes``, or ``None`` when *img* is ``None``.
    """
    if img is None:
        return None

    # Pillow's JPEG writer only accepts these modes. Anything else (RGBA, LA,
    # P, I, F, ...) makes ``save`` raise OSError, so flatten it to RGB first.
    # Any alpha channel is dropped in the process.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
|
|
|
|
|
def bytes_to_image(image: bytes) -> Image.Image:
    """Decode an encoded image payload into a PIL ``Image``.

    Args:
        image: Raw encoded image bytes in any format Pillow can identify.

    Returns:
        The ``Image`` returned by ``Image.open``. It reads from an in-memory
        buffer wrapping *image*.
    """
    stream = BytesIO(image)
    return Image.open(stream)
|
|
|
|
|
|
|
def generate_answer(question, image):
    """Ask the inference server a question, optionally about an image.

    Args:
        question: The question text entered in the UI.
        image: A PIL image from the Gradio input, or ``None`` when no image
            was provided.

    Returns:
        The ``answer`` field of the server's ``OmniResponse``.
    """
    encoded = get_bytes(image)
    # The request always carries image bytes. When no image was given, a
    # fixed placeholder payload is sent instead.
    # NOTE(review): presumably the server treats b'image' as "no image" — confirm.
    payload = b'image' if encoded is None else encoded

    # A new channel is opened for every call. The server address is read from
    # the SERVER environment variable.
    with grpc.insecure_channel(os.environ['SERVER']) as channel:
        stub = OmniServiceStub(channel)
        request = OmniRequest(image=payload, question=question)
        response: OmniResponse = stub.get_answer(request)
        answer = response.answer

    return answer
|
|
|
|
|
def get_demo():
    """Build the Gradio interface around ``generate_answer``.

    Inputs are a text box for the question and an image widget that delivers
    a PIL image (or ``None`` when left empty). The single output is text.
    """
    question_input = "text"
    image_input = gr.Image(type="pil")
    return gr.Interface(
        fn=generate_answer,
        inputs=[question_input, image_input],
        outputs=["text"],
    )
|
|
|
|
|
if __name__ == '__main__':
    demo = get_demo()
    # Bind on all interfaces, port 7860.
    # NOTE: share=True also asks Gradio for a public share URL, which makes the
    # UI reachable from outside this host.
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|