Laskari-Naveen commited on
Commit
0a31f9a
·
verified ·
1 Parent(s): 49c0fc1

create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoProcessor, VisionEncoderDecoderModel
4
+ import requests
5
+ from PIL import Image
6
+ import pandas as pd
7
+
8
+ st.write("Processing HCFA claims")
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+ processor = AutoProcessor.from_pretrained("Laskari-Naveen/HCFA_99")
11
+ model = VisionEncoderDecoderModel.from_pretrained("Laskari-Naveen/HCFA_99").to(device)
12
+
13
+
14
+ def run_prediction(image, model, processor):
15
+ pixel_values = processor(image, return_tensors="pt").pixel_values
16
+ task_prompt = "<s>"
17
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
18
+ outputs = model.generate(
19
+ pixel_values.to(device),
20
+ decoder_input_ids=decoder_input_ids.to(device),
21
+ max_length=model.decoder.config.max_position_embeddings,
22
+ early_stopping=True,
23
+ pad_token_id=processor.tokenizer.pad_token_id,
24
+ eos_token_id=processor.tokenizer.eos_token_id,
25
+ use_cache=True,
26
+ num_beams=2,
27
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
28
+ return_dict_in_generate=True,
29
+ )
30
+ # process output
31
+ prediction = processor.batch_decode(outputs.sequences)[0]
32
+ prediction = processor.token2json(prediction)
33
+ return prediction, outputs
34
+
35
+ def split_and_expand(row):
36
+ if row['Key'] == "33_Missing_Teeth":
37
+ keys = [row['Key']]
38
+ values = row['Value'].split(';')[0]
39
+ else:
40
+ keys = [row['Key']] * len(row['Value'].split(';'))
41
+ values = row['Value'].split(';')
42
+ return pd.DataFrame({'Key': keys, 'Value': values})
43
+
44
+
45
+ uploaded_file = st.file_uploader("Choose a file")
46
+ if uploaded_file is not None:
47
+ content = uploaded_file.read()
48
+ st.image(uploaded_file)
49
+ image = Image.open(uploaded_file).convert("RGB")
50
+ prediction, output = run_prediction(image, model, processor)
51
+
52
+ st.dataframe(prediction, width=600)