from pathlib import Path

import requests
import streamlit as st
import torch

MODEL_PATH = "pytorch_model.pth"
MODEL_URL = "https://huggingface.co/zongzhuofan/co-detr-vit-large-coco/resolve/main/pytorch_model.pth"

@st.cache_resource
def download_model():
    """Download the Co-DETR checkpoint once and cache the local path across reruns."""
    if not Path(MODEL_PATH).exists():
        with st.spinner("Downloading model... This might take a few minutes..."):
            response = requests.get(MODEL_URL, stream=True)
            response.raise_for_status()  # fail early on a bad HTTP response
            with open(MODEL_PATH, "wb") as f:
                # Stream the large checkpoint to disk in 8 KB chunks.
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
    return MODEL_PATH

# NOTE: YourModelClass is a placeholder; replace it with the actual Co-DETR model
# definition that matches the downloaded checkpoint. For a model this large, you may
# also want to wrap construction and weight loading in an st.cache_resource function
# so it is not repeated on every rerun.
model = YourModelClass()
model_path = download_model()
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()  # inference mode: disables dropout and batch-norm updates

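# preprocess_image() is called below but was not defined in the original script.
# The helper here is only a minimal sketch under assumptions: a standard torchvision
# pipeline (resize to 800x800 plus ImageNet normalization). The real Co-DETR
# preprocessing may differ, so adjust the transforms to match the model's input spec.
from PIL import Image
from torchvision import transforms

def preprocess_image(uploaded_file):
    image = Image.open(uploaded_file).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((800, 800)),  # assumed input resolution
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)  # add a batch dimension
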
st.title("Co-DETR Model")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png"])
if uploaded_file is not None:
    # Convert the uploaded image to a batched tensor and run a single forward pass.
    input_data = preprocess_image(uploaded_file)
    with torch.no_grad():
        output = model(input_data)

    st.write("Output:", output)
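    # Optionally echo the uploaded image back to the page; st.image accepts the
    # uploaded file object directly. This line is an addition, not in the original.
    st.image(uploaded_file, caption="Uploaded image")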