import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import gc
import torch
import base64
from PIL import Image
import io
# Global variables to track loaded models and status
current_model = None
current_pipe = None
loading_status = "No model loaded yet"
# Model information including base models and whether they support images and streaming
MODEL_INFO = {
"Chan-Y/gemma-3-1b-reasoning-tr-0128": {
"base_model": "google/gemma-3-1b-it",
"supports_image": True,
"supports_streaming": False
},
"Chan-Y/gemma-3-12b-finetune-200steps-0128": {
"base_model": "google/gemma-3-12b-it",
"supports_image": True,
"supports_streaming": False
},
"Chan-Y/llama31-8b-turkish-reasoning-300325_merged_16bit": {
"base_model": None, # This is a merged model, no base needed
"supports_image": False,
"supports_streaming": False
},
"Chan-Y/qwen25-3b-turkish-reasoning-300325_merged_16bit": {
"base_model": None, # This is a merged model, no base needed
"supports_image": False,
"supports_streaming": False
}
}
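# Models selectable in the UI but not listed here fall back to direct loading
# with image and streaming support disabled (see the MODEL_INFO.get defaults below)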
def load_adapter_model(model_name):
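    """Load `model_name` into the module-level globals, yielding human-readable
    status strings as a generator so the UI can show progress.

    Depending on MODEL_INFO, the model is loaded either as a PEFT adapter on top
    of its base model or directly as a standalone/merged checkpoint; the finished
    pipeline is exposed through the global `current_pipe`.
    """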
global current_model, current_pipe, loading_status
# Update loading status
loading_status = f"Loading model: {model_name}..."
yield loading_status
    # If a model is already loaded, release it to free memory
    if current_model is not None:
        loading_status = "Unloading previous model to free memory..."
        yield loading_status
        # Drop the references (safer than del, which would unbind the globals)
        # and release cached GPU memory
        current_model = None
        current_pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Get base model info if it's an adapter model
model_info = MODEL_INFO.get(model_name, {"base_model": None, "supports_image": False})
base_model_name = model_info["base_model"]
# If this is a base model + adapter setup
if base_model_name:
try:
# Update loading status
loading_status = f"Loading adapter model {model_name} on top of {base_model_name}..."
yield loading_status
# Load tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
print(f"Loading adapter model {model_name} on top of {base_model_name}...")
# First load the adapter config
loading_status = "Loading adapter configuration..."
yield loading_status
peft_config = PeftConfig.from_pretrained(model_name)
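            # (peft_config itself is not used further; loading it up front mainly
            # verifies the adapter repo is reachable before the heavier base-model
            # download, and PeftModel.from_pretrained re-reads the config anyway)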
# Then load the base model
loading_status = f"Loading base model {base_model_name}..."
yield loading_status
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
device_map="auto",
torch_dtype="auto"
)
# Load the adapter on top of the base model
loading_status = "Applying adapter to base model..."
yield loading_status
model = PeftModel.from_pretrained(base_model, model_name)
current_model = model
# Create pipeline with the loaded model and tokenizer
loading_status = "Creating generation pipeline..."
yield loading_status
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
current_pipe = pipe
loading_status = f"✅ Model {model_name} loaded successfully!"
yield loading_status
return pipe
except Exception as e:
loading_status = f"⚠️ PEFT loading failed: {e}"
yield loading_status
print(f"PEFT loading failed: {e}")
# Fall through to try other loading methods
# For merged models or if PEFT loading failed
try:
# Try loading directly if it's already merged or a different format
loading_status = f"Trying to load model {model_name} directly..."
yield loading_status
print(f"Trying to load model {model_name} directly...")
# Get tokenizer - if base_model is None, use the model_name itself
loading_status = "Loading tokenizer..."
yield loading_status
tokenizer = AutoTokenizer.from_pretrained(base_model_name or model_name)
loading_status = f"Loading model weights for {model_name}..."
yield loading_status
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype="auto"
)
current_model = model
# Create pipeline with the loaded model and tokenizer
loading_status = "Creating generation pipeline..."
yield loading_status
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
current_pipe = pipe
loading_status = f"✅ Model {model_name} loaded successfully!"
yield loading_status
return pipe
except Exception as e2:
loading_status = f"⚠️ Direct loading failed: {e2}"
yield loading_status
print(f"Direct loading failed: {e2}")
# Fallback to using the model name in pipeline
loading_status = "Falling back to using the model name in pipeline..."
yield loading_status
print("Falling back to using the model name in pipeline...")
pipe = pipeline("text-generation", model=model_name)
current_pipe = pipe
loading_status = f"✅ Model {model_name} loaded (fallback method)"
yield loading_status
return pipe
def encode_image_to_base64(image_path):
    """Convert an image file on disk to a base64 string (not called elsewhere in
    this script; generate_response inlines the equivalent conversion)."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def generate_response(model_name, prompt, system_prompt, image, max_length, temperature, top_p, top_k, stream=True):
"""Generate text using the model based on user input and advanced settings"""
global pipe
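    # Each yield is a (status_message, output_text) pair, matching the
    # [status_indicator, output_text] outputs wired to submit_btn.click below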
# Get model info
model_info = MODEL_INFO.get(model_name, {})
supports_image = model_info.get("supports_image", False)
supports_streaming = model_info.get("supports_streaming", False)
# Check if streaming is requested but not supported
if stream and not supports_streaming:
stream = False # Disable streaming for models that don't support it
# Check if we need to load a model (if none is loaded) or a different model
if pipe is None or model_name != getattr(pipe, 'model_name', None):
        # Stream loading-status updates to the UI; load_adapter_model exposes
        # the finished pipeline through the global current_pipe
        for status in load_adapter_model(model_name):
            yield status, gr.update()  # leave the output box unchanged while loading
        pipe = current_pipe
        pipe.model_name = model_name  # remember which model this pipeline serves
# Check if model supports images and an image is provided
supports_image = MODEL_INFO.get(model_name, {}).get("supports_image", False)
# Format messages for the model
    if supports_image and image is not None:
        # Encode the image as a base64 JPEG for the chat message
        if isinstance(image, str):  # path to an image file
            image = Image.open(image)
        buffered = io.BytesIO()
        # Convert to RGB so PNGs with an alpha channel can be saved as JPEG
        image.convert("RGB").save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        # Create a chat-format message that includes the image; note that a plain
        # "text-generation" pipeline may not actually consume image content, since
        # multimodal use typically goes through the "image-text-to-text" task
messages = [
[
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}]
},
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image", "image": f"data:image/jpeg;base64,{img_str}"}
]
},
],
]
else:
# Text-only message
messages = [
[
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}]
},
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
},
],
]
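    # messages is a batch containing a single conversation, which is why the
    # pipeline output is indexed as output[0][0] further down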
    # Signal that generation has started; leave the output box unchanged for now
    yield "Generating response...", gr.update()
# Generate based on whether streaming is supported
if not stream or not supports_streaming:
# Generate text without streaming
try:
# For models without streaming support
generation_args = {
"max_new_tokens": max_length,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k
}
output = pipe(messages, **generation_args)
            # Extract the assistant's reply from the batched chat output
            generated_text = output[0][0]["generated_text"][-1]["content"]
            yield f"✅ Response generated with {model_name}", generated_text
except Exception as e:
error_msg = f"Error generating response: {str(e)}"
            yield "⚠️ Generation failed", error_msg
else:
# For streaming mode
try:
text = ""
            # Streaming parameters; none of the configured models currently set
            # supports_streaming, and true token streaming with a transformers
            # pipeline normally requires a TextIteratorStreamer rather than a
            # "stream" flag, so treat this branch as experimental
            stream_args = {
                "max_new_tokens": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "stream": True
            }
# Generate with streaming
for response in pipe(messages, **stream_args):
# Extract the generated chunk
if len(response) > 0 and len(response[0]) > 0 and "generated_text" in response[0][0]:
# Get the last message content
last_message = response[0][0]["generated_text"][-1]
if "content" in last_message:
                        # Show the latest accumulated content in the output box
                        text = last_message["content"]
                        yield f"Streaming response from {model_name}...", text
except Exception as e:
error_msg = f"Error during streaming generation: {str(e)}"
            yield "⚠️ Streaming generation failed", error_msg
# Default model name
default_model = "Chan-Y/gemma-3-12b-finetune-200steps-0128"
# Initialize pipeline as None - we'll load the model before launching the interface
pipe = None
# Variable to store initial status for the status indicator
initial_status = "Loading default model. Please wait..."
# Default system prompt in Turkish
default_system_prompt = """Sana bir problem verildi.
Problem hakkında düşün ve çalışmanı göster.
Çalışmanı <start_working_out> ve <end_working_out> arasına yerleştir.
Sonra, çözümünü <SOLUTION> ve </SOLUTION> arasına yerleştir.
Lütfen SADECE Türkçe kullan."""
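# (The Turkish system prompt tells the model: "You are given a problem. Think
# about the problem and show your work. Place your work between
# <start_working_out> and <end_working_out>. Then place your solution between
# <SOLUTION> and </SOLUTION>. Please use ONLY Turkish.")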
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Gemma 3 Reasoning Model Interface")
gr.Markdown("Using Gemma 3 with Turkish reasoning adapters")
# Add status indicator at the top
status_indicator = gr.Textbox(
value=initial_status,
label="Status",
interactive=False
)
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(
lines=5,
placeholder="Enter your prompt here...",
label="Prompt"
)
# Image input (only shown for Gemma models)
image_input = gr.Image(
type="pil",
label="Upload Image (Only works with Gemma models)",
visible=True # Initially visible since default is Gemma
)
# Advanced settings in an expander (accordion)
with gr.Accordion("Advanced Settings", open=False):
# Move model selection here
model_selector = gr.Dropdown(
choices=[
"Chan-Y/gemma-3-12b-finetune-200steps-0128", # Default first
"Chan-Y/gemma-3-reasoning-tr-0.2",
"Chan-Y/gemma-3-1b-reasoning-tr-0128",
"Chan-Y/llama31-8b-turkish-reasoning-300325_merged_16bit",
"Chan-Y/qwen25-3b-turkish-reasoning-300325_merged_16bit"
],
value=default_model,
label="Select Model",
info="Choosing a new model will unload the current one to save memory"
)
system_prompt = gr.Textbox(
lines=5,
value=default_system_prompt,
label="System Prompt"
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.75,
step=0.1,
label="Temperature"
)
max_tokens = gr.Slider(
minimum=16,
maximum=1024,
value=512,
step=10,
label="Max New Tokens"
)
top_p_value = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p"
)
top_k_value = gr.Slider(
minimum=1,
maximum=100,
value=64,
step=1,
label="Top-k"
)
use_streaming = gr.Checkbox(
value=True,
label="Use Streaming Response"
)
submit_btn = gr.Button("Generate Response")
with gr.Column():
output_text = gr.Textbox(lines=15, label="Generated Response")
# Function to update image input visibility based on model selection
    def update_image_visibility(model_name):
        model_info = MODEL_INFO.get(model_name, {})
        supports_image = model_info.get("supports_image", False)
        supports_streaming = model_info.get("supports_streaming", False)
        # Relabel the checkbox so users can see when streaming is unavailable
        streaming_status = "" if supports_streaming else " (Not supported by this model)"
        return (
            gr.update(visible=supports_image),  # image input visibility
            gr.update(label=f"Use Streaming Response{streaming_status}")  # streaming checkbox label
        )
# Function to show model info when selected
def update_model_info(model_name):
model_info = MODEL_INFO.get(model_name, {})
base_model = model_info.get("base_model", "None (standalone model)")
img_support = "Yes" if model_info.get("supports_image", False) else "No"
stream_support = "Yes" if model_info.get("supports_streaming", False) else "No"
return f"Selected model: {model_name}\nBase model: {base_model}\nImage support: {img_support}\nStreaming support: {stream_support}"
# Connect both functions to update image visibility and show model info
model_selector.change(
fn=update_image_visibility,
inputs=[model_selector],
outputs=[image_input, use_streaming]
)
model_selector.change(
fn=update_model_info,
inputs=[model_selector],
outputs=[status_indicator]
)
# Initialize interface components
def initialize_interface():
# Check if model is loaded and update status
if pipe is not None:
model_info = MODEL_INFO.get(default_model, {})
base_model = model_info.get("base_model", "None (standalone model)")
img_support = "Yes" if model_info.get("supports_image", False) else "No"
status = f"Ready with model: {default_model}\nBase model: {base_model}\nImage support: {img_support}"
submit_interactive = True
else:
status = "Model not loaded yet. Please wait..."
submit_interactive = False
return [
status,
gr.update(interactive=submit_interactive)
]
# Use this at launch time
demo.load(
fn=initialize_interface,
outputs=[status_indicator, submit_btn]
)
# Connect the generation function to the interface
submit_btn.click(
fn=generate_response,
inputs=[
model_selector,
prompt_input,
system_prompt,
image_input,
max_tokens,
temperature,
top_p_value,
top_k_value,
use_streaming
],
outputs=[
status_indicator,
output_text
],
show_progress=True
)
# Launch the interface
if __name__ == "__main__":
# Load the default model before launching Gradio
print(f"Preloading default model: {default_model}")
# Load the model
for status in load_adapter_model(default_model):
print(status)
loading_status = status # Update the global loading status
# Set the loaded model name
if current_pipe:
current_pipe.model_name = default_model
pipe = current_pipe
    # Record the final preload state; the status box itself is refreshed by
    # initialize_interface when the page loads
    initial_status = f"✅ Default model {default_model} loaded and ready to use!"
# Launch Gradio interface
demo.queue() # Enable queuing for better handling of multiple requests
demo.launch(share=False) # Set share=True to create a public link