ardavey committed (verified)
Commit 7ab8248 · 1 Parent(s): 25835df

Update app.py

Files changed (1)
  1. app.py +21 -14
app.py CHANGED
@@ -23,6 +23,9 @@ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png
 # Text input for prompt
 input_text = st.text_input("Enter your prompt:", "Identify the skin condition?")
 
+# Slider for max tokens
+max_new_tokens = st.slider("Maximum Output Tokens:", 10, 100, 50)
+
 # Process and display the result when the button is clicked
 if uploaded_file is not None and st.button("Analyze"):
     try:
@@ -30,21 +33,25 @@ if uploaded_file is not None and st.button("Analyze"):
         input_image = Image.open(uploaded_file).convert("RGB")
         st.image(input_image, caption="Uploaded Image", use_column_width=True)
 
+        # Resize image for efficiency
+        max_size = (512, 512)
+        input_image = input_image.resize(max_size)
+
         # Prepare inputs
-        inputs = processor(
-            text=input_text,
-            images=input_image,
-            return_tensors="pt",
-            padding="longest"
-        ).to(device)
-
-        # Generate output
-        max_new_tokens = 50
-        with torch.no_grad():
-            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
-
-        # Decode output
-        decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
+        with st.spinner("Processing..."):
+            inputs = processor(
+                text=input_text,
+                images=input_image,
+                return_tensors="pt",
+                padding="longest"
+            ).to(device)
+
+            # Generate output
+            with torch.no_grad():
+                outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+
+            # Decode output
+            decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
 
         # Display result
         st.success("Analysis Complete!")
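
The hunks above reference processor, model, and device, which are defined earlier in app.py, outside the lines shown in this diff. As a rough orientation only, here is a minimal sketch of what that setup could look like, assuming a Hugging Face vision-language checkpoint loaded through transformers; the model_id, the AutoModelForVision2Seq class choice, and the page title are illustrative assumptions, not details confirmed by this commit.

# Hypothetical setup assumed by the diff above; the actual app.py may differ.
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder checkpoint for illustration; the commit does not name the model.
model_id = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id).to(device)
model.eval()

st.title("Skin Condition Analyzer")  # hypothetical title

# File uploader (this line is quoted in the first hunk header).
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

In terms of behavior, the commit moves max_new_tokens from a hard-coded 50 to a user-facing slider, resizes the uploaded image to 512×512 before preprocessing, and runs the processor/generate/decode steps inside st.spinner so the UI shows progress during inference; decoded_output is presumably rendered further down in app.py, below st.success, outside the hunks shown here.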