# import streamlit as st | |
# import transformers | |
# # Load the pre-trained language model | |
# model_name = "bert-base-uncased" | |
# model = transformers.pipeline("text-classification", model=model_name) | |
# # Streamlit App | |
# def main(): | |
# st.title("Sentence Category Classifier") | |
# # Input search sentence | |
# search_query = st.text_input("Enter a sentence:") | |
# result = "" | |
# # Process the search sentence when the user clicks the Search button | |
# if st.button("Search"): | |
# if search_query: | |
# # Classify the sentence using the pre-trained model | |
# categories = classify_sentence(search_query) | |
# # Display the categories as output | |
# if categories: | |
# result = f"The sentence belongs to the following categories:\n\n" | |
# for category in categories: | |
# result += f"• {category}\n" | |
# else: | |
# result = "No categories found for the sentence." | |
# # Display the result | |
# st.text(result) | |
# # Function to classify the sentence using the pre-trained language model | |
# @st.cache(allow_output_mutation=True) | |
# def classify_sentence(query): | |
# # Classify the sentence using the pre-trained model | |
# categories = model(query) | |
# # Extract the category labels from the model's output | |
# category_labels = [category['label'] for category in categories] | |
# return category_labels | |
# if __name__ == "__main__": | |
# main() | |
import streamlit as st | |
from sentence_transformers import SentenceTransformer | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Load the pre-trained Sentence Transformer model once per server process.
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so an uncached module-level load would rebuild the model on
# each rerun.
_cache_model = getattr(st, "cache_resource", None)
if _cache_model is None:
    # Fall back to the legacy caching API when cache_resource is unavailable.
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def _load_model():
    """Return the shared SentenceTransformer instance used for all encodings."""
    return SentenceTransformer('bert-base-nli-mean-tokens')


model = _load_model()
# Example dataset: each category name maps to the keywords that describe it.
# Insertion order (sports, politics, technology) is preserved.
categories = dict(
    sports=["football", "basketball", "tennis"],
    politics=["election", "government", "policy"],
    technology=["AI", "machine learning", "data science"],
)
# Function to get relevant categories based on user query
def get_relevant_categories(query):
    """Return category names ranked by similarity to *query*.

    Each category in the module-level ``categories`` mapping is scored by the
    sum of the cosine similarities between the query embedding and the
    embeddings of all of that category's keywords.

    Args:
        query: The sentence to categorize.

    Returns:
        A list of category names whose score is positive, highest score first.
    """
    query_embedding = model.encode([query])
    category_scores = {}
    for category, keywords in categories.items():
        if not keywords:
            # A category without keywords has nothing to compare against.
            continue
        keyword_embeddings = model.encode(keywords)
        # One row for the single query, one column per keyword.
        similarity_scores = cosine_similarity(query_embedding, keyword_embeddings)
        # Sum over every keyword with the array's own sum(); the built-in
        # sum() would iterate over rows and yield a per-keyword vector, so
        # indexing it with [0] would score the first keyword only.
        category_scores[category] = float(similarity_scores.sum())
    ranked = sorted(category_scores.items(), key=lambda item: item[1], reverse=True)
    return [category for category, score in ranked if score > 0]
# Streamlit app layout and UI
def main():
    """Render the categorization page and list categories for the input.

    Shows a title, a text input and a "Categorize" button. When the button is
    pressed, the stripped input is passed to ``get_relevant_categories`` and
    the resulting category names are written to the page.
    """
    st.title("Sentence Categorization App")
    # The prompt doubles as the widget label; Streamlit discourages empty
    # labels for accessibility reasons.
    user_input = st.text_input("Enter a sentence to categorize:")
    if st.button('Categorize'):
        # Treat whitespace-only input the same as no input at all.
        sentence = user_input.strip()
        if not sentence:
            st.write("Please enter a sentence for categorization.")
            return
        relevant_categories = get_relevant_categories(sentence)
        if not relevant_categories:
            # Avoid showing a header followed by an empty list.
            st.write("No relevant categories found.")
            return
        st.write("Relevant Categories:")
        for category in relevant_categories:
            st.write(f"- {category}")


# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()