HCFA_Process / app.py
Laskari-Naveen's picture
Update app.py
b1c1619 verified
raw
history blame
1.61 kB
import json

import pandas as pd
import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel
st.write("Processing HCFA claims")

# Hugging Face Hub checkpoint used for both the processor and the model.
MODEL_ID = "Laskari-Naveen/HCFA_99"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_processor_and_model(model_id, device_name):
    """Load the processor and the model (moved to ``device_name``) once.

    Streamlit re-executes this whole script on every widget interaction, so
    without ``st.cache_resource`` the checkpoint would be reloaded on each
    rerun. Arguments are plain strings so they are trivially hashable as the
    cache key.

    Args:
        model_id: Hugging Face Hub repo id passed to ``from_pretrained``.
        device_name: torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        ``(processor, model)`` tuple.
    """
    loaded_processor = AutoProcessor.from_pretrained(model_id)
    loaded_model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device_name)
    return loaded_processor, loaded_model


processor, model = _load_processor_and_model(MODEL_ID, str(device))
def run_prediction(image, model, processor):
    """Generate a structured prediction for one image.

    Args:
        image: image accepted by ``processor`` (the caller passes an RGB
            PIL image).
        model: ``VisionEncoderDecoderModel`` used for generation; inputs are
            moved to whatever device it lives on.
        processor: processor exposing image preprocessing, ``tokenizer``,
            ``batch_decode`` and ``token2json``.

    Returns:
        ``(prediction, outputs)``: ``prediction`` is ``processor.token2json``
        applied to the first decoded sequence; ``outputs`` is the raw
        ``model.generate`` result (``return_dict_in_generate=True``).
    """
    # Use the model's own device rather than a module-level global so the
    # function works for any model/processor pair handed to it.
    model_device = model.device
    tokenizer = processor.tokenizer

    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Decoding is primed with the bare "<s>" token as the task prompt.
    task_prompt = "<s>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(model_device),
        decoder_input_ids=decoder_input_ids.to(model_device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=2,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # The decoded text still carries the prompt, EOS and any padding; strip
    # them before parsing so they cannot leak into the parsed output.
    sequence = processor.batch_decode(outputs.sequences)[0]
    for special_token in (tokenizer.eos_token, tokenizer.pad_token):
        if special_token:
            sequence = sequence.replace(special_token, "")
    sequence = sequence.strip()
    if sequence.startswith(task_prompt):
        sequence = sequence[len(task_prompt):]

    prediction = processor.token2json(sequence.strip())
    return prediction, outputs
def _prediction_to_frame(prediction):
    """Flatten a ``token2json`` result into a two-column Field/Value table.

    ``st.dataframe`` cannot reliably convert a raw dict whose values mix
    strings with nested lists/dicts, so every value is rendered as text:
    strings as-is, anything else as JSON.

    Args:
        prediction: parsed model output, normally a dict. Lists are also
            handled (keys are prefixed with the list index) since
            ``token2json`` may return one — confirm against the checkpoint.

    Returns:
        ``pandas.DataFrame`` with columns ``Field`` and ``Value``.
    """
    if isinstance(prediction, dict):
        items = list(prediction.items())
    elif isinstance(prediction, list):
        items = []
        for index, group in enumerate(prediction):
            if isinstance(group, dict):
                items.extend((f"{index}.{key}", value) for key, value in group.items())
            else:
                items.append((str(index), group))
    else:
        items = [("value", prediction)]

    rows = [
        (str(key), value if isinstance(value, str) else json.dumps(value, ensure_ascii=False))
        for key, value in items
    ]
    return pd.DataFrame(rows, columns=["Field", "Value"])


uploaded_file = st.file_uploader("Choose a file")
if uploaded_file is not None:
    try:
        # convert() forces a full decode, so truncated files fail here too.
        # PIL's UnidentifiedImageError is a subclass of OSError.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("The uploaded file could not be read as an image.")
    else:
        # Show exactly the image that is fed to the model.
        st.image(image)
        with st.spinner("Extracting claim fields..."):
            prediction, _ = run_prediction(image, model, processor)
        st.dataframe(_prediction_to_frame(prediction), width=600)