# Scraped page header (not part of the app): "Spaces:" / status "Sleeping".
| import streamlit as st | |
| from PIL import Image | |
| import pandas as pd | |
| import numpy as np | |
| from model_methods import predict | |
| import base64 # for title image | |
| from load_css import local_css # for highlighting text | |
# Make sure the NLTK corpora used downstream are installed. Streamlit re-runs
# this whole script on every interaction, so only download what is missing.
def import_nltk():
    """Download the NLTK ``wordnet`` and ``omw-1.4`` corpora if not installed.

    Each corpus is first looked up with ``nltk.data.find``; ``nltk.download``
    is only called when that lookup raises ``LookupError``. Reruns of the app
    therefore skip the download step once the corpora are present locally.
    """
    import nltk  # local import: nltk is only needed for this setup step

    # NOTE(review): presumably model_methods.predict lemmatizes with WordNet,
    # which is why these corpora are required — confirm.
    for package in ('wordnet', 'omw-1.4'):
        try:
            # find() also resolves the zipped form (corpora/<name>.zip).
            nltk.data.find(f'corpora/{package}')
        except LookupError:
            nltk.download(package)
# Browser-tab title/icon and overall layout of the app page.
st.set_page_config(
    page_title='Marvel vs. DC comics',
    page_icon=Image.open('subreddit_icon.png'),
    layout='centered',
    initial_sidebar_state='auto',
)
# Title image that links to reddit.com. The PNG is inlined as a base64 data URI
# so the raw <img> tag can be rendered through st.markdown.
# reference: https://discuss.streamlit.io/t/how-to-show-local-gif-image/3408/4
# reference: https://discuss.streamlit.io/t/local-image-button/5409/4
with open("subreddit_icon.png", "rb") as im:  # closed even if read() fails
    contents = im.read()
im_base64 = base64.b64encode(contents).decode("utf-8")
html = f'''<a href='https://www.reddit.com/'>
<img src='data:image/png;base64,{im_base64}' width='100'>
</a><figcaption>Credit: reddit.com</figcaption>'''
st.markdown(html, unsafe_allow_html=True)
# Page heading followed by a short blurb naming the two subreddits the model
# can tell apart. local_css loads highlight_text.css, which presumably styles
# the `highlight blue` spans used in the blurb.
st.title(body='Subreddit Post classifier')
local_css("highlight_text.css")
text = '''The algorithm driving this app is built using subreddit posts published
between April and July 2022. It is only able to classify between
<span class='highlight blue'> **Marvel** </span>
and
<span class='highlight blue'> **DC Comics** </span>
subreddits.'''
st.markdown(body=text, unsafe_allow_html=True)
# Fetch the NLTK corpora (wordnet, omw-1.4) before showing the input box.
# predict_post() reads the box's current value from `new_post`.
import_nltk()
new_post = st.text_input('Please copy and paste the subreddit post here', value='')
# Classify the post typed into the text box and show the predicted subreddit.
def predict_post(post_text=None):
    """Predict which subreddit a post belongs to and render the answer.

    Parameters
    ----------
    post_text : str, optional
        Text to classify. When omitted, the module-level ``new_post`` value
        taken from the text box is used (this is how the Submit handler
        calls it).

    Raises
    ------
    ValueError
        If the text is empty or whitespace only, or if ``predict`` does not
        return exactly one label equal to 0 or 1.
    """
    if post_text is None:
        post_text = new_post
    if not post_text.strip():
        # Reject blank input explicitly instead of handing it to the model.
        raise ValueError('No text entered for prediction.')

    # predict() recognises a pd.Series of posts, so wrap the single post.
    result = predict(pd.Series(post_text))

    # NOTE(review): assumes predict() returns one 0/1 label per post (scalar
    # or length-1 array) — confirm against model_methods.
    labels = np.asarray(result).ravel()
    if labels.size != 1:
        raise ValueError(f'Expected exactly one prediction, got {labels.size}.')
    post = {1: 'Marvel', 0: 'DC comics'}.get(labels[0])
    if post is None:
        # Previously an unknown label left `post` unbound (UnboundLocalError).
        raise ValueError(f'Unexpected prediction {labels[0]!r}.')

    st.write('### This post belongs to')
    st.success(f'# {post}')
    st.write('### subreddit')
# Submit button: run the classifier and show the result (or a warning) in the
# sidebar.
if st.button('Submit'):
    with st.sidebar:
        try:
            predict_post()
        except Exception:
            # A bare `except:` would also swallow BaseException subclasses such
            # as KeyboardInterrupt and Streamlit's script-control exceptions
            # (used for stop/rerun); only ordinary errors are turned into the
            # warning below.
            # NOTE(review): "π" looks like a mis-encoded emoji — confirm.
            st.warning('''
Unable to detect text.
Please enter text for prediction.
\n\n Thank you π.
''')