File size: 3,404 Bytes
1efea19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import base64

st.set_page_config(page_title="Gemma Demo", layout="wide")
# Model selection (STUBBED behavior)
model_option = st.selectbox(
    "Choose a Gemma to reveal hidden truths:",
    ["gemma-2b-it (Instruct)", "gemma-2b", "gemma-7b", "gemma-7b-it"],
    index=0,
    help="Stubbed selection – only gemma-2b-it will load for now."
)
st.markdown("<h1 style='text-align: center;'>Portal to Gemma</h1>", unsafe_allow_html=True)

# Load both GIFs in base64 format
def load_gif_base64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

still_gem_b64 = load_gif_base64("assets/stillGem.gif")
rotating_gem_b64 = load_gif_base64("assets/rotatingGem.gif")

# Placeholder for GIF HTML
gif_html = st.empty()
caption = st.empty()

# Initially show still gem
gif_html.markdown(
    f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
    unsafe_allow_html=True,
)

@st.cache_resource
def load_model():
    model_id = "google/gemma-2b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=None,
        torch_dtype=torch.float32
    )
    model.to("cpu")
    return tokenizer, model

tokenizer, model = load_model()
prompt = st.text_area("Enter your prompt:", "What is Gemma?")
# # Example prompt selector
# examples = {
#     "🧠 Summary": "Summarize the history of AI in 5 bullet points.",
#     "💻 Code": "Write a Python function to sort a list using bubble sort.",
#     "📜 Poem": "Write a haiku about large language models.",
#     "🤖 Explain": "Explain what a transformer is in simple terms.",
#     "🔍 Fact": "Who won the FIFA World Cup in 2022?"
# }

# selected_example = st.selectbox("Choose a Gemma to consult:", list(examples.keys()) + ["✍️ Custom input"])
# Add before generation
col1, col2, col3 = st.columns(3)

with col1:
    temperature = st.slider("Temperature", 0.1, 1.5, 1.0)

with col2:
    max_tokens = st.slider("Max tokens", 50, 500, 100)

with col3:
    top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95)
# if selected_example != "✍️ Custom input":
#     prompt = examples[selected_example]
# else:
#     prompt = st.text_area("Enter your prompt:")

if st.button("Generate"):
    # Swap to rotating GIF
    gif_html.markdown(
        f"<div style='text-align:center;'><img src='data:image/gif;base64,{rotating_gem_b64}' width='300'></div>",
        unsafe_allow_html=True,
    )
    caption.markdown("<p style='text-align: center;'>Gemma is thinking... 🌀</p>", unsafe_allow_html=True)


    # Generate text

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)

    # Back to still
    gif_html.markdown(
        f"<div style='text-align:center;'><img src='data:image/gif;base64,{still_gem_b64}' width='300'></div>",
        unsafe_allow_html=True,
    )
    caption.empty()


    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.markdown("### ✨ Output:")
    st.write(result)