| | import streamlit as st |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import torch |
| | import re |
| |
|
| | |
| |
|
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(api_key, model_id):
    """Load the tokenizer and model once per ``(api_key, model_id)`` pair.

    Streamlit re-executes the script on every interaction, so without a
    cache the weights would be reloaded on every button click. Exceptions
    are allowed to propagate: a call that raises is not cached, so a
    transient failure can be retried.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=api_key)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=api_key,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return model, tokenizer


def initialize_gemini_api(api_key, model_id="google/gemini-1.5-pro-001"):
    """Return ``(model, tokenizer)`` for *model_id*, or ``(None, None)`` on failure.

    Args:
        api_key: Access token forwarded to ``from_pretrained`` as ``token``.
        model_id: Hub repository id to load. Defaults to the id this app
            has always used.
            NOTE(review): confirm this repository actually exists on the
            Hugging Face Hub -- Gemini models are normally served through
            Google's hosted API rather than as downloadable weights.

    Returns:
        A ``(model, tokenizer)`` tuple. If loading raises for any reason the
        error is shown with ``st.error`` and ``(None, None)`` is returned.
    """
    try:
        return _load_model_and_tokenizer(api_key, model_id)
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing.
        st.error(f"Error initializing model: {e}")
        return None, None
| |
|
| |
|
def preprocess_input(user_input, input_type):
    """Wrap the user's text in the prompt template for *input_type*.

    Args:
        user_input: Free-form text substituted into the template's ``{}``
            placeholder.
        input_type: One of ``"recipe_suggestion"``, ``"promotion_idea"``,
            ``"waste_reduction_tip"`` or ``"event_planning"``.

    Returns:
        The filled-in prompt, or the literal string
        ``"Invalid input type."`` when *input_type* is not recognised.
    """
    templates_by_type = {
        "recipe_suggestion": "I have the following ingredients: {}. Suggest a recipe, and the recipe must include the ingredients I provided. Provide steps",
        "promotion_idea": "Suggest a promotion to increase customer engagement based on these goals/themes: {}.",
        "waste_reduction_tip": "Suggest strategies, including numbered steps, to minimize food waste based on this context/these ingredients: {}.",
        "event_planning": "I want to plan an event. Here's the description/goals/requirements: {}. Give detailed, step-by-step instructions and important considerations.",
    }

    template = templates_by_type.get(input_type)
    if not template:
        # Unknown type: callers receive this sentinel instead of a prompt.
        return "Invalid input type."

    # user_input is the format *argument*, so braces inside it are safe.
    return template.format(user_input)
| |
|
def generate_suggestion(model, tokenizer, processed_input):
    """Sample a completion for *processed_input* and return only the new text.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching *model*.
        processed_input: The fully formatted prompt string.

    Returns:
        The decoded completion with the prompt removed, or the string
        ``"An error occurred during suggestion generation."`` if anything
        raises (the error is also shown via ``st.error``).
    """
    try:
        # The tokenizer returns a full encoding (input_ids, attention_mask,
        # ...), which is moved to the model's device as a whole.
        encoded = tokenizer(processed_input, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **encoded,
            max_new_tokens=512,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            do_sample=True,
        )
        # For decoder-only models the returned sequence starts with the
        # prompt tokens; drop them so the prompt is not echoed to the user.
        prompt_length = encoded["input_ids"].shape[-1]
        completion_ids = outputs[0][prompt_length:]
        return tokenizer.decode(completion_ids, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure and return a displayable message.
        st.error(f"Error during generation: {e}")
        return "An error occurred during suggestion generation."
| |
|
| |
|
# Whitespace that follows "." or "?" and is not part of an abbreviation such
# as "e.g." or "Mr." -- treated as a sentence boundary.
_SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s')

# A fragment whose final line is nothing but a list marker such as "3.".
_TRAILING_LIST_MARKER = re.compile(r'^[ \t]*\d{1,3}\.\Z', re.MULTILINE)


def postprocess_output(raw_response, input_type):
    """Format generated text as paragraphs, one sentence per paragraph.

    The text is stripped, split at sentence boundaries and re-joined with
    blank lines. A numbered-list marker ("1.", "2.", ...) that stands at the
    start of a line is kept attached to the sentence that follows it, so
    numbered steps are not torn apart into "1." and the step text.

    Args:
        raw_response: Text produced by the model.
        input_type: Kind of suggestion requested. Currently unused; kept so
            existing callers continue to work and type-specific formatting
            can be added later.

    Returns:
        The reformatted text.
    """
    cleaned_response = raw_response.strip()
    fragments = _SENTENCE_BOUNDARY.split(cleaned_response)

    paragraphs = []
    pending_marker = ""
    for fragment in fragments:
        if pending_marker:
            # Re-attach the list marker that the split separated off.
            fragment = f"{pending_marker} {fragment}"
            pending_marker = ""
        if _TRAILING_LIST_MARKER.search(fragment):
            pending_marker = fragment
        else:
            paragraphs.append(fragment)
    if pending_marker:
        # Text ended on a bare marker; keep it rather than dropping content.
        paragraphs.append(pending_marker)

    return "\n\n".join(paragraphs)
def get_ai_suggestion(user_input, input_type, api_key):
    """Run the full pipeline: build prompt, load model, generate, format.

    Args:
        user_input: Free-form text entered by the user.
        input_type: Kind of suggestion requested; selects the prompt template.
        api_key: Access token used to load the model.

    Returns:
        The formatted suggestion, or a short human-readable message when the
        input type is not recognised or the model cannot be initialised.
    """
    # Build the prompt first: it is cheap, and an unknown input type should
    # neither trigger a model load nor be sent to the model as a prompt.
    processed_input = preprocess_input(user_input, input_type)
    if processed_input == "Invalid input type.":
        return processed_input

    model, tokenizer = initialize_gemini_api(api_key)
    if model is None or tokenizer is None:
        return "Failed to initialize the model. Check your API key."

    raw_response = generate_suggestion(model, tokenizer, processed_input)
    return postprocess_output(raw_response, input_type)
| |
|
| |
|
| |
|
| | |
| |
|
# ---------------------------------------------------------------------------
# Streamlit UI. This section runs top to bottom on every user interaction.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="AI Restaurant Assistant", layout="wide")
st.sidebar.title("AI Restaurant Assistant")

# Keep the key in session state so it survives reruns of the script.
if 'api_key' not in st.session_state:
    st.session_state.api_key = ''

api_key_input = st.sidebar.text_input(
    "Enter your Hugging Face API key:",
    type="password",
    value=st.session_state.api_key,
)

if api_key_input:
    st.session_state.api_key = api_key_input

if not st.session_state.api_key:
    # NOTE(review): the message previously shown here was a hard-coded string
    # that looked like a Google API key. If it was a real credential it
    # should be revoked; the UI must only ever ask for a key, not show one.
    st.sidebar.warning("Please enter your Hugging Face API key to continue.")
    st.stop()

input_type = st.sidebar.selectbox(
    "What kind of suggestion do you need?",
    ["recipe_suggestion", "promotion_idea", "waste_reduction_tip", "event_planning"],
)

st.title("Get AI-Powered Suggestions")
st.write("This tool leverages the power of the Gemini 1.5 Pro model to assist with various restaurant management tasks.")
user_input = st.text_area("Enter your input here:", height=150, key="user_input")

if st.button("Generate Suggestion"):
    if user_input:
        with st.spinner("Generating suggestion..."):
            suggestion = get_ai_suggestion(user_input, input_type, st.session_state.api_key)
        # Static heading: plain markdown is enough, no raw HTML is needed.
        st.markdown("### AI Suggestion:")
        st.write(suggestion)
    else:
        st.warning("Please enter some input.")