import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned classifier and its tokenizer from the Hugging Face Hub.
# Streamlit re-executes this entire script on every widget interaction, so the
# load is wrapped in st.cache_resource: it happens once and the same objects are
# reused on later reruns instead of being re-deserialized each time.
# show_spinner=False keeps the cached call from rendering any element before
# st.set_page_config is called further down.
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_id: str = "Caseyishere/StoryCraft", num_labels: int = 5):
    """Return ``(model, tokenizer)`` for the Hub checkpoint *model_id*.

    The returned objects are cached by ``st.cache_resource`` and shared across
    reruns, so callers must treat them as read-only.

    Args:
        model_id: Hugging Face Hub repository id of the checkpoint.
        num_labels: Number of output classes passed to ``from_pretrained``.
            NOTE(review): presumably must match the checkpoint's head size — confirm.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        model_id, num_labels=num_labels
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer()

# Custom CSS injected once per run. The .big-font, .highlight and .divider
# classes are referenced by the HTML snippets in the result section; .menu
# appears unused in this file.
_PAGE_CSS = """
    <style>
    .big-font {
        font-size:24px !important;
        font-weight:bold;
    }
    .highlight {
        color: #FF4B4B;
    }
    .divider {
        border-top: 2px solid #bbb;
        margin: 20px 0;
    }
    .menu {
        font-size:18px !important;
        line-height: 1.8;
        font-family: 'Arial', sans-serif;
    }
    </style>
    """

# Page configuration is kept ahead of every rendered element (older Streamlit
# versions reject set_page_config otherwise).
st.set_page_config(
    page_title="Story Craft",
    page_icon="🍽️",
    layout="centered",
)

# Header, followed by the stylesheet (raw HTML, hence unsafe_allow_html).
st.title("🍽️ Welcome to Story Craft 🍽️")
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Get user input. Surrounding whitespace is stripped and whitespace-only input
# is treated as empty, so it does not trigger a model call.
user_input = st.text_input("Please tell us what you like today:")

# Human-readable names for the classifier's output indices.
# NOTE(review): the model is loaded with num_labels=5, but only indices 0-2 are
# named here and only 0-3 have a reply below; any other index falls back to
# 'Unknown' / the default reply — confirm against the checkpoint's label set.
label_map = {0: "Negative", 1: "Neutral", 2: "Positive"}

# Menu (Markdown) shown for each predicted label. The trailing double spaces on
# each line are Markdown hard line breaks and must be kept.
responses = {
    0: '''**Appetizer**: Escargots: Snails cooked in garlic butter with herbs  
          **Main Course**: Coq au vin: Chicken braised in red wine with mushrooms and onions  
          **Side Dish**: Pommes frites: French fries  
          **Dessert**: Crème brûlée: Custard topped with caramelized sugar  
          **Beverage**: Bordeaux: A red wine from the Bordeaux region of France  
          **Cheese Course**: Fromage à raclette: Melted cheese served with bread, potatoes, and pickles''',
    1: '''**Appetizer**: Spätzle: Swabian egg noodles with cheese  
          **Main Course**: Wiener schnitzel: Breaded veal cutlet  
          **Side Dish**: Sauerkraut: Fermented cabbage  
          **Dessert**: Schwarzwälder Kirschtorte: Black Forest cake  
          **Beverage**: Kölsch: A light, golden ale from Cologne  
          **Cheese Course**: Käsekuchen: German cheesecake''',
    2: '''**Appetizer**: Creamy Spinach and Artichoke Dip with tortilla chips  
          **Main Course**: Ribeye Steak cooked to your desired temperature (medium-rare, medium, well-done)  
          **Side Dish**: Baked Potato topped with butter, sour cream, and bacon bits  
          **Dessert**: Chocolate Lava Cake with vanilla ice cream  
          **Beverage**: Red Wine (ask your server for a recommendation based on your preferences)  
          **Salad**: Caesar Salad with romaine lettuce, croutons, Parmesan cheese, and Caesar dressing  
**Soup**: French Onion Soup with caramelized onions, Gruyère cheese, and croutons''',
    3: "Oops! Something went wrong!"
}

# `or ""` guards against text_input returning None.
query = (user_input or "").strip()

if query:
    # Tokenize the single query into PyTorch tensors for the model.
    inputs = tokenizer(query, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no_grad stops autograd from recording the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax is monotonic, so the arg-max of the raw logits is the most
    # probable class. Row 0 is the single query; int() yields a plain dict key.
    predicted_label = int(torch.argmax(logits[0]))

    # Show the raw label index and its name.
    st.write(f"Predicted label is {predicted_label} ({label_map.get(predicted_label, 'Unknown')} Sentence)")

    # Result section, framed by dividers styled by the page CSS.
    st.markdown('<div class="divider"></div>', unsafe_allow_html=True)
    st.markdown('<div class="big-font highlight">Here is your curated menu based on your input:</div>', unsafe_allow_html=True)

    st.write(responses.get(predicted_label, "I'm not sure what you're asking for."))

    st.markdown('<div class="divider"></div>', unsafe_allow_html=True)