# StoryCraft — Streamlit app that classifies the user's input sentiment
# and serves a themed menu based on the predicted label.
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Page configuration must be the first Streamlit command executed.
st.set_page_config(page_title="Story Craft", page_icon="🍽️", layout="centered")


@st.cache_resource
def load_model_and_tokenizer():
    """Load and cache the sequence classifier and its tokenizer.

    Streamlit reruns this whole script on every user interaction;
    ``st.cache_resource`` ensures the Hub download / weight load
    happens only once per server process.

    Returns:
        tuple: (model, tokenizer) ready for inference.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        "Caseyishere/StoryCraft", num_labels=5
    )
    model.eval()  # inference only — disable dropout / training-mode layers
    tokenizer = AutoTokenizer.from_pretrained("Caseyishere/StoryCraft")
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer()

# Page title and CSS used by the markdown snippets below.
st.title("🍽️ Welcome to Story Craft 🍽️")
st.markdown("""
<style>
.big-font {
    font-size:24px !important;
    font-weight:bold;
}
.highlight {
    color: #FF4B4B;
}
.divider {
    border-top: 2px solid #bbb;
    margin: 20px 0;
}
.menu {
    font-size:18px !important;
    line-height: 1.8;
    font-family: 'Arial', sans-serif;
}
</style>
""", unsafe_allow_html=True)

# Get user input
user_input = st.text_input("Please tell us what you like today:")

if user_input:
    # Tokenize and classify; no_grad avoids building an autograd graph
    # we never backpropagate through.
    inputs = tokenizer(user_input, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # int() so the label is a plain Python int (np.argmax returns numpy.int64).
    predicted_label = int(np.argmax(predictions.cpu().numpy()))

    # NOTE(review): the model is configured with num_labels=5, but only
    # labels 0-2 have a sentiment name and 0-3 have a menu response;
    # out-of-range labels fall through to the .get() defaults below —
    # confirm the intended label mapping against the trained model.
    label_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
    st.write(f"Predicted label is {predicted_label} ({label_map.get(predicted_label, 'Unknown')} Sentence)")

    # Curated menu text keyed by predicted label.
    responses = {
        0: '''**Appetizer**: Escargots: Snails cooked in garlic butter with herbs
**Main Course**: Coq au vin: Chicken braised in red wine with mushrooms and onions
**Side Dish**: Pommes frites: French fries
**Dessert**: Crème brûlée: Custard topped with caramelized sugar
**Beverage**: Bordeaux: A red wine from the Bordeaux region of France
**Cheese Course**: Fromage à raclette: Melted cheese served with bread, potatoes, and pickles''',
        1: '''**Appetizer**: Spätzle: Swabian egg noodles with cheese
**Main Course**: Wiener schnitzel: Breaded veal cutlet
**Side Dish**: Sauerkraut: Fermented cabbage
**Dessert**: Schwarzwälder Kirschtorte: Black Forest cake
**Beverage**: Kölsch: A light, golden ale from Cologne
**Cheese Course**: Käsekuchen: German cheesecake''',
        2: '''**Appetizer**: Creamy Spinach and Artichoke Dip with tortilla chips
**Main Course**: Ribeye Steak cooked to your desired temperature (medium-rare, medium, well-done)
**Side Dish**: Baked Potato topped with butter, sour cream, and bacon bits
**Dessert**: Chocolate Lava Cake with vanilla ice cream
**Beverage**: Red Wine (ask your server for a recommendation based on your preferences)
**Salad**: Caesar Salad with romaine lettuce, croutons, Parmesan cheese, and Caesar dressing
**Soup**: French Onion Soup with caramelized onions, Gruyère cheese, and croutons''',
        3: "Oops! Something went wrong!"
    }

    # Render the menu between two divider rules.
    st.markdown('<div class="divider"></div>', unsafe_allow_html=True)
    st.markdown('<div class="big-font highlight">Here is your curated menu based on your input:</div>', unsafe_allow_html=True)
    st.write(responses.get(predicted_label, "I'm not sure what you're asking for."))
    st.markdown('<div class="divider"></div>', unsafe_allow_html=True)