File size: 5,149 Bytes
1efea19
6d3a2f2
1efea19
 
 
d8f007f
1efea19
d8f007f
 
 
 
 
 
 
1efea19
 
 
 
 
 
0c1dfa2
 
1efea19
 
 
 
 
 
8e1d1e9
 
 
 
1efea19
d8f007f
1efea19
 
 
 
 
9b7d4f4
1b9ef5d
9b7d4f4
8020e39
ea616ac
f52ef7c
6d3a2f2
 
9b7d4f4
1efea19
 
8020e39
 
 
aeeecba
6d3a2f2
 
1efea19
8020e39
1efea19
 
 
d8f007f
 
1efea19
 
 
 
 
 
 
 
 
 
d8f007f
1efea19
d8f007f
 
1efea19
d8f007f
 
1efea19
d8f007f
 
1efea19
 
 
 
 
 
 
8e1d1e9
 
 
 
1efea19
d8f007f
1efea19
 
d8f007f
1efea19
 
 
 
 
 
8020e39
 
 
 
 
1efea19
 
8e1d1e9
 
 
 
1efea19
d8f007f
1efea19
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config
import torch
import base64

# Configure the browser tab title and use the wide page layout.
st.set_page_config(layout="wide", page_title="LIA Demo")

# NOTE: a model-selection dropdown (several Gemma variants) used to be stubbed
# out here; the model is currently fixed inside load_model().

# Centered page heading, rendered as raw HTML.
PAGE_HEADING_HTML = "<h1 style='text-align: center;'>Ask LeoNardo!</h1>"
st.markdown(PAGE_HEADING_HTML, unsafe_allow_html=True)

# Load both GIFs in base64 format
def load_gif_base64(path):
    """Return the contents of the file at *path* as a base64 text string.

    The file is opened in binary mode, read in full, base64-encoded, and the
    encoded bytes are decoded as UTF-8 so the result can be embedded in a
    ``data:`` URI.
    """
    with open(path, "rb") as gif_file:
        raw_bytes = gif_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")

# still_gem_b64 = load_gif_base64("assets/stillGem.gif")
# rotating_gem_b64 = load_gif_base64("assets/rotatingGem.gif")

# Two empty slots that later code overwrites in place: one for the animated
# GIF markup, one for the status caption beneath it.
gif_html = st.empty()
caption = st.empty()

# Start with the idle ("still") GIF, hot-linked from Giphy.
# (A base64-embedded local asset was used here previously.)
_idle_gif_url = "https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif"
_idle_gif_markup = (
    "<div style='text-align:center;'>"
    "<img src='" + _idle_gif_url + "' width='300'>"
    "</div>"
)
gif_html.markdown(_idle_gif_markup, unsafe_allow_html=True)

@st.cache_resource
def load_model():
    """Load the tokenizer and causal-LM weights used by the demo.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    objects across script reruns instead of reloading them each time.

    Gemma is gated, so the demo runs on a DeepSeek checkpoint instead
    (``google/gemma-2b-it`` was the original target; smaller DeepSeek
    checkpoints such as ``DeepSeek-R1-Distill-Qwen-1.5B`` and
    ``deepseek-llm-7b-chat`` were tried before the current one).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "deepseek-ai/DeepSeek-V3-0324"

    # Built before the tokenizer to keep the original construction order.
    fp8_config = FineGrainedFP8Config()
    tok = AutoTokenizer.from_pretrained(checkpoint)

    # device_map="auto" replaces the earlier explicit CPU / float32 setup.
    load_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
        "quantization_config": fp8_config,
    }
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    return tok, lm

# Obtain the tokenizer/model pair from load_model().
tokenizer, model = load_model()

# Free-text prompt box, pre-filled with a sample question.
prompt = st.text_area(
    label="Enter your prompt:",
    value="What is Leonardo, the company with the red logo?",
)
# Example prompt selector
# examples = {
#     "🧠 Summary": "Summarize the history of AI in 5 bullet points.",
#     "💻 Code": "Write a Python function to sort a list using bubble sort.",
#     "📜 Poem": "Write a haiku about large language models.",
#     "🤖 Explain": "Explain what a transformer is in simple terms.",
#     "🔍 Fact": "Who won the FIFA World Cup in 2022?"
# }

# selected_example = st.selectbox("Choose a Gemma to consult:", list(examples.keys()) + ["✍️ Custom input"])
# Add before generation
# col1, col2, col3 = st.columns(3)

# with col1:
#     temperature = st.slider("Temperature", 0.1, 1.5, 1.0)

# with col2:
#     max_tokens = st.slider("Max tokens", 50, 500, 100)

# with col3:
#     top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95)
# if selected_example != "✍️ Custom input":
#     prompt = examples[selected_example]
# else:
#     prompt = st.text_area("Enter your prompt:")

if st.button("Generate"):
    # GIFs shown while idle and while the model is generating (hot-linked from
    # Giphy; base64-embedded local assets were used here previously).
    _still_gif_url = "https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExYTRxYzI2bXJmY3N2bXBtMHJtOGV3NW9vZ3l3M3czbGYybGpkeWQ1YSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/3uPWb5EYVvxdfoREQm/giphy.gif"
    _thinking_gif_url = "https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExMXViMm02MnR6bGJ4c2h3ajYzdWNtNXNtYnNic3lnN2xyZzlzbm9seSZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9cw/k32ddF9WVs44OUaZAm/giphy.gif"
    _gif_template = "<div style='text-align:center;'><img src='{url}' width='300'></div>"

    if not prompt.strip():
        # Nothing to generate from: tell the user instead of calling the model
        # with an empty prompt.
        st.warning("Please enter a prompt first.")
    else:
        # Swap to the rotating GIF and show a status caption while generating.
        gif_html.markdown(
            _gif_template.format(url=_thinking_gif_url),
            unsafe_allow_html=True,
        )
        caption.markdown("<p style='text-align: center; margin-top: 20px;'>LeoNardo is thinking... 🌀</p>", unsafe_allow_html=True)

        try:
            # Tokenize on the same device the model lives on.
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=200,
                    # temperature / top_p only take effect when sampling is
                    # enabled, so request it explicitly rather than relying on
                    # the checkpoint's default generation config.
                    do_sample=True,
                    temperature=1.0,
                    top_p=0.95,
                )
        finally:
            # Always restore the idle GIF and clear the caption, even when
            # tokenization or generation raises, so the UI is never left
            # stuck in the "thinking" state.
            gif_html.markdown(
                _gif_template.format(url=_still_gif_url),
                unsafe_allow_html=True,
            )
            caption.empty()

        # NOTE(review): the whole returned sequence is decoded, so for a
        # causal LM the displayed text presumably begins with the prompt
        # itself — confirm whether the prompt should be stripped.
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        st.markdown("### ✨ Output:")
        st.write(result)