import re

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# --- Backend Functions ---


def initialize_gemini_api(api_key):
    """Load the causal-LM model and its tokenizer via ``from_pretrained``.

    Args:
        api_key: Forwarded to both ``from_pretrained`` calls as ``token``.

    Returns:
        A ``(model, tokenizer)`` tuple on success. On any exception the error
        is shown with ``st.error`` and ``(None, None)`` is returned instead.
    """
    # NOTE(review): confirm this repo id really resolves to a transformers
    # checkpoint, and that `api_key` is a Hugging Face Hub token rather than a
    # Google/Gemini API key — the two are not interchangeable.
    checkpoint = "google/gemini-1.5-pro-001"
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=api_key)
        # Let transformers choose device placement and load weights in bfloat16.
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            token=api_key,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
    except Exception as exc:
        st.error(f"Error initializing model: {exc}")
        return None, None
    return model, tokenizer


def preprocess_input(user_input, input_type):
    """Build the model prompt by inserting ``user_input`` into a template.

    Args:
        user_input: Text supplied by the user; substituted for ``{}``.
        input_type: Template key — ``"recipe_suggestion"``,
            ``"promotion_idea"``, ``"waste_reduction_tip"`` or
            ``"event_planning"``.

    Returns:
        The formatted prompt, or the literal string ``"Invalid input type."``
        when ``input_type`` is not one of the keys above.
    """
    templates = {
        "recipe_suggestion": "I have the following ingredients: {}. Suggest a recipe, and the recipe must include the ingredients I provided. Provide steps",
        "promotion_idea": "Suggest a promotion to increase customer engagement based on these goals/themes: {}.",
        "waste_reduction_tip": "Suggest strategies, including numbered steps, to minimize food waste based on this context/these ingredients: {}.",
        "event_planning": "I want to plan an event. Here's the description/goals/requirements: {}. Give detailed, step-by-step instructions and important considerations.",
    }
    template = templates.get(input_type)
    if not template:
        # Expected to be unreachable when the UI limits input_type to the keys
        # above. NOTE(review): this sentinel is a plain string, so a caller that
        # does not check for it will pass it on as if it were a real prompt.
        return "Invalid input type."
    return template.format(user_input)
def generate_suggestion(model, tokenizer, processed_input): """Generates text using the Gemini model.""" try: input_ids = tokenizer(processed_input, return_tensors="pt").to(model.device) # Make sure tensors are on same device outputs = model.generate(**input_ids, max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True) # Added important params for generation quality generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return generated_text except Exception as e: st.error(f"Error during generation: {e}") return "An error occurred during suggestion generation." def postprocess_output(raw_response, input_type): """Postprocesses the generated text.""" # Remove any leading/trailing whitespace cleaned_response = raw_response.strip() # Further, specific postprocessing according to context if input_type == 'recipe_suggestion': try: pass # Can add custom filtering except: pass elif input_type == 'promotion_idea': try: pass #Can add custom regex and filters except: pass elif input_type == "waste_reduction_tip" or input_type == 'event_planning': try: # Check to ensure instructions and steps in final output. pass except: pass # Basic example: Split into sentences for better readability (can be improved) sentences = re.split(r'(?