# WAVE_AI / app.py
# Author: wavesoumen — commit 412e4aa "start" (verified)
# File page links: raw | history | blame — virus scan: no virus — size: 2.61 kB
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
# Streamlit re-executes this entire script on every widget interaction, so
# heavyweight setup is wrapped in st.cache_resource loaders: each runs once
# per server process and later reruns reuse the cached objects instead of
# re-downloading data and rebuilding the models.


@st.cache_resource
def _download_nltk_data():
    """Download the NLTK 'punkt' data once per server process."""
    # NOTE(review): nothing in this app calls NLTK; kept only so behaviour
    # matches the original script — consider removing.
    return nltk.download('punkt')


@st.cache_resource
def _load_captioner():
    """Build the BLIP image-captioning pipeline once and reuse it."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


@st.cache_resource
def _load_tag_model():
    """Load the T5 tag-generation tokenizer and model once and reuse them.

    Returns:
        A ``(tokenizer, model)`` tuple for "fabiochiu/t5-base-tag-generation".
    """
    tag_tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
    tag_model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
    return tag_tokenizer, tag_model


# Download NLTK data (cached after the first run).
_download_nltk_data()
# Initialize the image captioning pipeline (cached after the first run).
captioner = _load_captioner()
# Load the tokenizer and model for tag generation (cached after the first run).
tokenizer, model = _load_tag_model()
# Page heading rendered above both feature tabs.
st.title("Multi-purpose Machine Learning App")

# One tab per feature; tab1 and tab2 are unpacked in the same order as the
# labels below.
tab1, tab2 = st.tabs(
    [
        "Image Captioning",
        "Text Tag Generation",
    ]
)
# Image Captioning Tab: shows the image at a user-supplied URL and a
# BLIP-generated caption for it.
with tab1:
    st.header("Image Captioning")
    # Input for image URL. Surrounding whitespace (common when pasting) is
    # stripped so it neither breaks the fetch nor counts as a URL on its own.
    image_url = st.text_input("Enter the URL of the image:").strip()
    # Only run once a non-blank URL is provided.
    if image_url:
        try:
            # Display the image.
            # NOTE(review): newer Streamlit releases deprecate use_column_width
            # in favour of use_container_width; kept here for compatibility
            # with older Streamlit versions — confirm the deployed version.
            st.image(image_url, caption="Provided Image", use_column_width=True)
            # Generate the caption; the pipeline is given the URL directly and
            # the code below expects a list whose first item has a
            # 'generated_text' key.
            caption = captioner(image_url)
            # Display the caption
            st.write("**Generated Caption:**")
            st.write(caption[0]['generated_text'])
        except Exception as e:
            # UI boundary: report fetch/model failures to the user instead of
            # crashing the app.
            st.error(f"An error occurred: {e}")
# Text Tag Generation Tab: generates a de-duplicated list of tags for
# user-supplied text with the T5 tag-generation model.
with tab2:
    st.header("Text Tag Generation")
    # Text area for user input
    text = st.text_area("Enter the text for tag extraction:", height=200)
    # Generation runs only on the rerun triggered by pressing the button.
    if st.button("Generate Tags"):
        # Whitespace-only input is treated the same as empty input.
        if text.strip():
            try:
                # Tokenize and encode the input text, truncating to 512 tokens.
                inputs = tokenizer(
                    [text], max_length=512, truncation=True, return_tensors="pt"
                )
                # Generate tags: 8 beams with sampling enabled, 10-64 tokens.
                # Because do_sample=True, repeated runs on the same text can
                # produce different tags.
                output = model.generate(
                    **inputs,
                    num_beams=8,
                    do_sample=True,
                    min_length=10,
                    max_length=64,
                )
                # Decode the single generated sequence to plain text.
                decoded_output = tokenizer.batch_decode(
                    output, skip_special_tokens=True
                )[0]
                # The decoded text is a comma-separated tag list. Split on the
                # bare comma and strip each piece so "a,b" and "a, b" both
                # parse, drop empty fragments, and de-duplicate with
                # dict.fromkeys, which (unlike set) keeps the model's order.
                tags = list(
                    dict.fromkeys(
                        tag.strip()
                        for tag in decoded_output.split(",")
                        if tag.strip()
                    )
                )
                # Display the tags
                st.write("**Generated Tags:**")
                st.write(tags)
            except Exception as e:
                # UI boundary: surface tokenizer/model failures in the app
                # instead of crashing it.
                st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter some text to generate tags.")
# To run this app, save this code to a file (e.g., `app.py`) and run `streamlit run app.py` in your terminal.