ESPnet
Afar
code
File size: 5,779 Bytes
317ffbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7be622
317ffbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re

# --- Backend Functions ---

@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(model_name, api_key):
    """Download and instantiate the causal LM and its tokenizer (cached).

    Decorated with ``st.cache_resource`` so the weights are loaded once per
    ``(model_name, api_key)`` pair and reused across Streamlit reruns instead
    of being reloaded on every button press. If loading raises, the exception
    propagates and nothing is cached, so the next call retries the load.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=api_key,
        device_map="auto",  # let transformers/accelerate place the weights
        torch_dtype=torch.bfloat16,  # half-precision weights to reduce memory
    )
    return model, tokenizer


def initialize_gemini_api(api_key, model_name="google/gemini-1.5-pro-001"):
    """Load the model and tokenizer, reporting any failure in the UI.

    NOTE(review): Gemini 1.5 Pro does not appear to be published as a
    Transformers checkpoint on the Hugging Face Hub, so this default repo id
    will most likely fail to load. Confirm the intended model (e.g. an open
    Gemma checkpoint) and pass it via ``model_name``.

    Args:
        api_key: Hugging Face access token, forwarded as ``token=`` to
            ``from_pretrained``.
        model_name: Hub repository id to load. Defaults to the id this app
            has always used.

    Returns:
        ``(model, tokenizer)`` on success. Returns ``(None, None)`` if loading
        raised, after showing the error with ``st.error``.
    """
    try:
        return _load_model_and_tokenizer(model_name, api_key)
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error initializing model: {e}")
        return None, None


def preprocess_input(user_input, input_type):
    """Wrap the raw user text in the prompt template for ``input_type``.

    Args:
        user_input: Free-form text entered by the user. It is substituted
            verbatim; braces inside it are not interpreted.
        input_type: Task key, one of the four options offered by the sidebar.

    Returns:
        The formatted prompt string, or the literal ``"Invalid input type."``
        when ``input_type`` has no template.
    """
    templates_by_task = {
        "recipe_suggestion": "I have the following ingredients: {}. Suggest a recipe, and the recipe must include the ingredients I provided. Provide steps",
        "promotion_idea": "Suggest a promotion to increase customer engagement based on these goals/themes: {}.",
        "waste_reduction_tip": "Suggest strategies, including numbered steps, to minimize food waste based on this context/these ingredients: {}.",
        "event_planning": "I want to plan an event. Here's the description/goals/requirements: {}.  Give detailed, step-by-step instructions and important considerations.",
    }

    try:
        template = templates_by_task[input_type]
    except KeyError:
        # The selectbox restricts the choices, so this is not expected from the UI.
        return "Invalid input type."
    return template.format(user_input)

def generate_suggestion(model, tokenizer, processed_input):
    """Generate a completion for ``processed_input`` and return only the new text.

    Args:
        model: A loaded causal language model exposing ``device`` and
            ``generate``.
        tokenizer: The matching tokenizer (callable, with ``decode``).
        processed_input: The fully formatted prompt string.

    Returns:
        The decoded continuation with special tokens removed. If tokenization
        or generation raises, returns a fixed error message after showing the
        error with ``st.error``.
    """
    try:
        # Tokenizer output (input_ids + attention_mask), moved to the model's device.
        encoded = tokenizer(processed_input, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **encoded,
            max_new_tokens=512,
            do_sample=True,  # sampling must be on for temperature/top_k/top_p to apply
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
        # For decoder-only models generate() returns prompt + continuation;
        # drop the prompt tokens so the user's prompt is not echoed back.
        prompt_length = encoded["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:  # UI boundary: report and return a placeholder message
        st.error(f"Error during generation: {e}")
        return "An error occurred during suggestion generation."


# Sentence boundary: a single whitespace character that follows ".", "?" or "!",
# except after dotted abbreviations, title abbreviations or list markers.
_SENTENCE_BREAK = re.compile(
    r"(?<!\w\.\w.)"       # not after dotted abbreviations such as "e.g." / "i.e."
    r"(?<![A-Z][a-z]\.)"  # not after title abbreviations such as "Mr." / "Dr."
    r"(?<!\d\.)"          # not after numbered-list markers such as "1." / "10."
    r"(?<=[.?!])"         # only after sentence-ending punctuation
    r"\s"
)


def postprocess_output(raw_response, input_type):
    """Tidy the generated text for display.

    Strips surrounding whitespace, then puts each sentence in its own
    paragraph (joined with blank lines) so the Markdown output is easier to
    read. Numbered-list markers like "1." are not treated as sentence ends,
    so step lists requested by the prompts stay intact.

    Args:
        raw_response: Text returned by the generation step.
        input_type: Task key. Accepted for interface compatibility and future
            per-task filtering; it does not currently change the output.

    Returns:
        The formatted response string ("" for empty/whitespace-only input).
    """
    cleaned_response = raw_response.strip()
    sentences = _SENTENCE_BREAK.split(cleaned_response)
    return "\n\n".join(sentences)
def get_ai_suggestion(user_input, input_type, api_key):
    """Run the full pipeline: load model, build prompt, generate, format.

    Args:
        user_input: Raw text from the main text area.
        input_type: Task key selecting the prompt template.
        api_key: Hugging Face access token used to load the model.

    Returns:
        The formatted suggestion text, or a fixed failure message when the
        model or tokenizer could not be loaded.
    """
    model, tokenizer = initialize_gemini_api(api_key)
    if any(component is None for component in (model, tokenizer)):
        return "Failed to initialize the model.  Check your API key."
    return postprocess_output(
        generate_suggestion(model, tokenizer, preprocess_input(user_input, input_type)),
        input_type,
    )



# --- Streamlit Frontend ---

# Configure the page before any other Streamlit element is rendered.
st.set_page_config(page_title="AI Restaurant Assistant", layout="wide")
st.sidebar.title("AI Restaurant Assistant")

# --- API KEY HANDLING ---

# Keep the key in st.session_state so it persists across reruns within this
# browser session.
if "api_key" not in st.session_state:
    st.session_state.api_key = ""
# SECURITY NOTE: entering the key in the UI is suitable for demos and local
# development only. In production, load it from an environment variable or a
# secrets manager. Never hardcode a key or commit one to version control.
api_key_input = st.sidebar.text_input(
    "Enter your Hugging Face API key:",
    type="password",
    value=st.session_state.api_key,
)

if api_key_input:
    # Store a newly entered key; the input above is pre-filled from
    # session_state on every rerun.
    st.session_state.api_key = api_key_input

if not st.session_state.api_key:
    # Guidance text only; never embed a credential in UI strings.
    st.sidebar.warning("Please enter your Hugging Face API key to continue.")
    st.stop()  # halt this script run until a key is provided

# --- Input Selection ---
input_type = st.sidebar.selectbox(
    "What kind of suggestion do you need?",
    ["recipe_suggestion", "promotion_idea", "waste_reduction_tip", "event_planning"],
)

# --- Main Area ---
st.title("Get AI-Powered Suggestions")
st.write("This tool leverages the power of the Gemini 1.5 Pro model to assist with various restaurant management tasks.")
user_input = st.text_area("Enter your input here:", height=150, key="user_input")

if st.button("Generate Suggestion"):
    # Treat whitespace-only input as empty so a blank prompt is never sent.
    if user_input.strip():
        with st.spinner("Generating suggestion..."):
            suggestion = get_ai_suggestion(user_input, input_type, st.session_state.api_key)
        st.markdown("### AI Suggestion:")  # static heading; no raw HTML needed
        st.write(suggestion)
    else:
        st.warning("Please enter some input.")