|
import streamlit as st |
|
from PIL import Image |
|
import inference |
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
from PIL import Image |
|
import io |
|
import requests |
|
import copy |
|
import os |
|
from unittest.mock import patch |
|
from transformers.dynamic_module_utils import get_imports |
|
import torch |
|
|
|
|
|
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around Florence-2's hard dependency on ``flash_attn``.

    The remote modeling code for Florence-2 declares ``flash_attn`` as an
    import, which breaks loading on machines without that package. This
    wrapper is patched over ``transformers.dynamic_module_utils.get_imports``
    and strips ``flash_attn`` from the import list for the Florence-2
    modeling file only; every other file is passed through untouched.

    Args:
        filename: Path of the module file whose imports are being collected.

    Returns:
        The import list reported by the original ``get_imports``, with
        ``"flash_attn"`` removed when *filename* is ``modeling_florence2.py``.
    """
    if not str(filename).endswith("modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    # list.remove raises ValueError when the item is absent; guard it so
    # this keeps working if upstream ever drops the flash_attn import.
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
|
|
|
|
|
# Seed the per-session flag that tells us whether the quantized model has
# already been loaded into session state (survives Streamlit reruns).
if 'model_loaded' not in st.session_state:
    st.session_state['model_loaded'] = False
|
|
|
|
|
def load_model():
    """Download, quantize, and cache the Florence-2 model in session state.

    Loads the processor and model for ``microsoft/Florence-2-large``,
    dynamically quantizes the model's Linear layers to int8 to reduce the
    memory footprint for CPU inference, and stores both objects in
    ``st.session_state`` (keys ``processor`` and ``model``) so script
    reruns reuse them. Sets ``st.session_state.model_loaded`` to True on
    completion and echoes a status line to the page.
    """
    model_id = "microsoft/Florence-2-large"

    # NOTE(review): the original passed torch_dtype=torch.qint8 here, but a
    # processor takes no dtype (and qint8 is not a model load dtype), so the
    # argument was meaningless and has been dropped.
    st.session_state.processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True
    )

    # Scratch directory, created idempotently instead of swallowing every
    # exception with a bare except. (Presumably consumed by the inference
    # module — TODO confirm.)
    os.makedirs("temp", exist_ok=True)

    # Patch get_imports so loading does not require flash_attn (see
    # fixed_get_imports); SDPA attention is used instead.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation="sdpa", trust_remote_code=True
        )

    # Dynamic int8 quantization of the Linear layers shrinks the model for
    # CPU inference without retraining.
    Qmodel = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    del model  # release the full-precision weights immediately
    st.session_state.model = Qmodel
    st.session_state.model_loaded = True
    st.write("model loaded complete")
|
|
|
# Trigger the one-time model load, with a spinner while it runs.
if not st.session_state['model_loaded']:
    with st.spinner('Loading model...'):
        load_model()


# Per-session flag; initialized here but not consulted in this section.
if 'has_run' not in st.session_state:
    st.session_state['has_run'] = False


# Page title.
st.markdown('<h3><center><b>VQA</b></center></h3>', unsafe_allow_html=True)

# Sidebar widget for the input image.
uploaded_image = st.sidebar.file_uploader("Upload your image here", type=["jpg", "jpeg", "png"])
|
|
|
if uploaded_image is not None:
    # Normalize the upload: force RGB (drops alpha/palette modes) and
    # downscale to the fixed size the inference pipeline works on.
    image = Image.open(uploaded_image)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize((256, 256))

    # NOTE(review): the original also serialized the image into an in-memory
    # buffer (image_bytes / image_format / image_binary), but that buffer was
    # never used — inference.run_example receives the PIL image directly — so
    # the dead serialization step has been removed. (Image.format is always
    # None after resize(), so the 'PNG' fallback always fired anyway.)

    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Florence-2 task token, e.g. <MORE_DETAILED_CAPTION>.
    task_prompt = st.sidebar.text_input("Task Prompt", value="<MORE_DETAILED_CAPTION>")

    # Free-form text passed alongside the task prompt (the question for VQA).
    text_input = st.sidebar.text_area("Input Questions", value="<MORE_DETAILED_CAPTION>", height=20)

    if st.sidebar.button("Generate Caption", key="Generate"):
        # Run the quantized model and render whatever the helper returns.
        output = inference.run_example(
            image,
            st.session_state.model,
            st.session_state.processor,
            task_prompt,
            text_input,
        )
        st.write(output)
|
|
|
|
|
|