wavesoumen commited on
Commit
412e4aa
1 Parent(s): eaea803
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import nltk
4
+
5
+ # Download NLTK data
6
+ nltk.download('punkt')
7
+
8
+ # Initialize the image captioning pipeline
9
+ captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
10
+
11
+ # Load the tokenizer and model for tag generation
12
+ tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
13
+ model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
14
+
15
+ # Streamlit app title
16
+ st.title("Multi-purpose Machine Learning App")
17
+
18
+ # Create tabs for different functionalities
19
+ tab1, tab2 = st.tabs(["Image Captioning", "Text Tag Generation"])
20
+
21
+ # Image Captioning Tab
22
+ with tab1:
23
+ st.header("Image Captioning")
24
+
25
+ # Input for image URL
26
+ image_url = st.text_input("Enter the URL of the image:")
27
+
28
+ # If an image URL is provided
29
+ if image_url:
30
+ try:
31
+ # Display the image
32
+ st.image(image_url, caption="Provided Image", use_column_width=True)
33
+
34
+ # Generate the caption
35
+ caption = captioner(image_url)
36
+
37
+ # Display the caption
38
+ st.write("**Generated Caption:**")
39
+ st.write(caption[0]['generated_text'])
40
+ except Exception as e:
41
+ st.error(f"An error occurred: {e}")
42
+
43
+ # Text Tag Generation Tab
44
+ with tab2:
45
+ st.header("Text Tag Generation")
46
+
47
+ # Text area for user input
48
+ text = st.text_area("Enter the text for tag extraction:", height=200)
49
+
50
+ # Button to generate tags
51
+ if st.button("Generate Tags"):
52
+ if text:
53
+ try:
54
+ # Tokenize and encode the input text
55
+ inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
56
+
57
+ # Generate tags
58
+ output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
59
+
60
+ # Decode the output
61
+ decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
62
+
63
+ # Extract unique tags
64
+ tags = list(set(decoded_output.strip().split(", ")))
65
+
66
+ # Display the tags
67
+ st.write("**Generated Tags:**")
68
+ st.write(tags)
69
+ except Exception as e:
70
+ st.error(f"An error occurred: {e}")
71
+ else:
72
+ st.warning("Please enter some text to generate tags.")
73
+
74
+ # To run this app, save this code to a file (e.g., `app.py`) and run `streamlit run app.py` in your terminal.