|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import re |
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _load_model(model_name, api_key):
    """Load ``model_name`` and its tokenizer, cached per (model_name, api_key).

    Streamlit re-executes the whole script on every interaction, so without
    ``st.cache_resource`` the model would be reloaded on every button click.
    This helper lets exceptions propagate: only successful return values are
    cached, so a failed load (bad token, network error) can be retried.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=api_key,
        device_map="auto",  # requires the `accelerate` package to be installed
        torch_dtype=torch.bfloat16,
    )
    return model, tokenizer


def initialize_gemini_api(api_key, model_name="google/gemini-1.5-pro-001"):
    """Initializes the Gemini model and tokenizer.

    Args:
        api_key: Hugging Face access token forwarded as ``token=`` to
            ``from_pretrained``.
        model_name: Hub identifier of the causal LM to load. Defaults to the
            previously hard-coded id.

    Returns:
        ``(model, tokenizer)`` on success, or ``(None, None)`` after showing
        the error in the Streamlit UI.

    NOTE(review): Gemini 1.5 Pro is offered through Google's hosted API; as far
    as I know its weights are not published on the Hugging Face Hub, so the
    default id is expected to fail at load time. Confirm, and consider either
    the Gemini SDK or an open-weights model id.
    """
    try:
        return _load_model(model_name, api_key)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the page.
        st.error(f"Error initializing model: {e}")
        return None, None
|
|
|
|
|
# Prompt template per suggestion category; "{}" is replaced by the user's text.
_PROMPT_TEMPLATES = {
    "recipe_suggestion": "I have the following ingredients: {}. Suggest a recipe, and the recipe must include the ingredients I provided. Provide steps",
    "promotion_idea": "Suggest a promotion to increase customer engagement based on these goals/themes: {}.",
    "waste_reduction_tip": "Suggest strategies, including numbered steps, to minimize food waste based on this context/these ingredients: {}.",
    "event_planning": "I want to plan an event. Here's the description/goals/requirements: {}. Give detailed, step-by-step instructions and important considerations.",
}


def preprocess_input(user_input, input_type):
    """Build the model prompt for ``input_type`` by filling in its template.

    Args:
        user_input: Free text from the user, inserted verbatim into the template.
        input_type: A key of ``_PROMPT_TEMPLATES``.

    Returns:
        The formatted prompt, or the literal string ``"Invalid input type."``
        when ``input_type`` has no template.
    """
    try:
        template = _PROMPT_TEMPLATES[input_type]
    except KeyError:
        return "Invalid input type."
    return template.format(user_input)
|
|
|
def generate_suggestion(
    model,
    tokenizer,
    processed_input,
    *,
    max_new_tokens=512,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
):
    """Generates text using the Gemini model.

    Args:
        model: Causal language model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer paired with ``model``.
        processed_input: The fully formatted prompt string.
        max_new_tokens, temperature, top_k, top_p: Sampling settings forwarded
            to ``model.generate``; defaults are the previously hard-coded values.

    Returns:
        Only the newly generated text (the prompt is not echoed back), or a
        fixed error message if tokenization or generation raises.
    """
    try:
        inputs = tokenizer(processed_input, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
        )
        # Decoder-only models return prompt tokens followed by the continuation;
        # drop the prompt so the user sees only the model's answer.
        prompt_length = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the error and return a displayable message.
        st.error(f"Error during generation: {e}")
        return "An error occurred during suggestion generation."
|
|
|
|
|
# Paragraph break points: whitespace following a sentence terminator, except
# after abbreviations ("e.g.", "Dr.") and after a 1-2 digit list marker at the
# start of a line ("1.", "12."), which would otherwise split a numbered step
# away from its own text.
_SENTENCE_BREAK = re.compile(
    r"(?<!\w\.\w.)"       # abbreviations such as "e.g." / "i.e."
    r"(?<![A-Z][a-z]\.)"  # titles such as "Dr." / "Mr."
    r"(?<!^\d\.)"         # list markers "1." .. "9." at line start
    r"(?<!^\d\d\.)"       # list markers "10." .. "99." at line start
    r"(?<=[.?!])"         # only break after ".", "?" or "!"
    r"\s+",               # consume the whole whitespace run
    re.MULTILINE,
)


def postprocess_output(raw_response, input_type):
    """Postprocesses the generated text for display.

    Strips surrounding whitespace and puts each sentence in its own markdown
    paragraph, keeping numbered-list markers attached to their step text.

    Args:
        raw_response: Text returned by the model.
        input_type: Suggestion category; currently unused and kept for
            interface compatibility.

    Returns:
        The reformatted text, with sentences separated by blank lines.
    """
    cleaned_response = raw_response.strip()
    sentences = _SENTENCE_BREAK.split(cleaned_response)
    return "\n\n".join(sentences)
|
def get_ai_suggestion(user_input, input_type, api_key):
    """Run the prompt -> generate -> format pipeline for one request.

    Args:
        user_input: Free text entered by the user.
        input_type: Suggestion category understood by ``preprocess_input``.
        api_key: Token passed through to ``initialize_gemini_api``.

    Returns:
        The formatted suggestion, or a short user-facing message explaining
        why no suggestion was produced.
    """
    processed_input = preprocess_input(user_input, input_type)
    # preprocess_input signals an unknown category by returning this exact
    # string. Stop here rather than sending it to the model as a prompt, and
    # before paying for a model load.
    if processed_input == "Invalid input type.":
        return processed_input

    model, tokenizer = initialize_gemini_api(api_key)
    if model is None or tokenizer is None:
        return "Failed to initialize the model. Check your API key."

    raw_response = generate_suggestion(model, tokenizer, processed_input)
    return postprocess_output(raw_response, input_type)
|
|
|
|
|
|
|
|
|
|
|
st.set_page_config(page_title="AI Restaurant Assistant", layout="wide")
st.sidebar.title("AI Restaurant Assistant")

# Seed the session-state slot that the API-key widget below is bound to.
if "api_key" not in st.session_state:
    st.session_state.api_key = ""

# key="api_key" binds the field directly to st.session_state.api_key, so the
# stored key always matches the field, including when the user clears it.
api_key_input = st.sidebar.text_input(
    "Enter your Hugging Face API key:", type="password", key="api_key"
)

if not st.session_state.api_key:
    # Never put credentials in source code or UI text. This message previously
    # displayed a hard-coded value that looks like a Google API key to every
    # visitor; that key should be treated as leaked and revoked.
    st.sidebar.warning("Please enter your API key to continue.")
    st.stop()

# The selectbox returns the raw option key; format_func only changes the label shown.
input_type = st.sidebar.selectbox(
    "What kind of suggestion do you need?",
    ["recipe_suggestion", "promotion_idea", "waste_reduction_tip", "event_planning"],
    format_func=lambda option: option.replace("_", " ").capitalize(),
)

st.title("Get AI-Powered Suggestions")
st.write("This tool leverages the power of the Gemini 1.5 Pro model to assist with various restaurant management tasks.")
user_input = st.text_area("Enter your input here:", height=150, key="user_input")

if st.button("Generate Suggestion"):
    # Whitespace-only input is treated as empty rather than sent to the model.
    if user_input.strip():
        with st.spinner("Generating suggestion..."):
            suggestion = get_ai_suggestion(user_input, input_type, st.session_state.api_key)
        # Plain markdown heading; no raw HTML is needed, so don't enable it.
        st.markdown("### AI Suggestion:")
        st.write(suggestion)
    else:
        st.warning("Please enter some input.")