import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import gc
import torch
import base64
from PIL import Image
import io

# Global variables to track loaded models and status
current_model = None
current_pipe = None
loading_status = "No model loaded yet"

# Model information including base models and whether they support images and streaming
MODEL_INFO = {
    "Chan-Y/gemma-3-1b-reasoning-tr-0128": {
        "base_model": "google/gemma-3-1b-it",
        "supports_image": True,
        "supports_streaming": False
    },
    "Chan-Y/gemma-3-12b-finetune-200steps-0128": {
        "base_model": "google/gemma-3-12b-it",
        "supports_image": True,
        "supports_streaming": False
    },
    "Chan-Y/llama31-8b-turkish-reasoning-300325_merged_16bit": {
        "base_model": None,  # This is a merged model, no base needed
        "supports_image": False,
        "supports_streaming": False
    },
    "Chan-Y/qwen25-3b-turkish-reasoning-300325_merged_16bit": {
        "base_model": None,  # This is a merged model, no base needed
        "supports_image": False,
        "supports_streaming": False
    }
}
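# Note (assumption): the dropdown below also offers "Chan-Y/gemma-3-reasoning-tr-0.2",
# which has no entry here and therefore falls back to the defaults returned by
# MODEL_INFO.get(...): no image support, no streaming, direct loading.
# Also, per Google's model cards only the 4B+ Gemma 3 variants are multimodal,
# so "supports_image": True on the 1B entry may not hold in practice.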
def load_adapter_model(model_name):
    global current_model, current_pipe, loading_status

    # Update loading status
    loading_status = f"Loading model: {model_name}..."
    yield loading_status

    # If there's a model already loaded, delete it to free memory
    if current_model is not None:
        loading_status = "Unloading previous model to free memory..."
        yield loading_status
        del current_model
        del current_pipe
        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache()

    # Get base model info if it's an adapter model
    model_info = MODEL_INFO.get(model_name, {"base_model": None, "supports_image": False})
    base_model_name = model_info["base_model"]

    # If this is a base model + adapter setup
    if base_model_name:
        try:
            # Update loading status
            loading_status = f"Loading adapter model {model_name} on top of {base_model_name}..."
            yield loading_status

            # Load tokenizer from the base model
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            print(f"Loading adapter model {model_name} on top of {base_model_name}...")

            # First load the adapter config
            loading_status = "Loading adapter configuration..."
            yield loading_status
            peft_config = PeftConfig.from_pretrained(model_name)

            # Then load the base model
            loading_status = f"Loading base model {base_model_name}..."
            yield loading_status
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                device_map="auto",
                torch_dtype="auto"
            )

            # Load the adapter on top of the base model
            loading_status = "Applying adapter to base model..."
            yield loading_status
            model = PeftModel.from_pretrained(base_model, model_name)
            current_model = model

            # Create pipeline with the loaded model and tokenizer
            loading_status = "Creating generation pipeline..."
            yield loading_status
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
            current_pipe = pipe

            loading_status = f"✅ Model {model_name} loaded successfully!"
            yield loading_status
            return pipe
        except Exception as e:
            loading_status = f"⚠️ PEFT loading failed: {e}"
            yield loading_status
            print(f"PEFT loading failed: {e}")
            # Fall through to try other loading methods

    # For merged models or if PEFT loading failed
    try:
        # Try loading directly if it's already merged or a different format
        loading_status = f"Trying to load model {model_name} directly..."
        yield loading_status
        print(f"Trying to load model {model_name} directly...")

        # Get tokenizer - if base_model is None, use the model_name itself
        loading_status = "Loading tokenizer..."
        yield loading_status
        tokenizer = AutoTokenizer.from_pretrained(base_model_name or model_name)

        loading_status = f"Loading model weights for {model_name}..."
        yield loading_status
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype="auto"
        )
        current_model = model

        # Create pipeline with the loaded model and tokenizer
        loading_status = "Creating generation pipeline..."
        yield loading_status
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        current_pipe = pipe

        loading_status = f"✅ Model {model_name} loaded successfully!"
        yield loading_status
        return pipe
    except Exception as e2:
        loading_status = f"⚠️ Direct loading failed: {e2}"
        yield loading_status
        print(f"Direct loading failed: {e2}")

        # Fallback to using the model name in pipeline
        loading_status = "Falling back to using the model name in pipeline..."
        yield loading_status
        print("Falling back to using the model name in pipeline...")
        pipe = pipeline("text-generation", model=model_name)
        current_pipe = pipe

        loading_status = f"✅ Model {model_name} loaded (fallback method)"
        yield loading_status
        return pipe
def encode_image_to_base64(image_path):
    """Convert image to base64 string for model input"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
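# Note: this helper is not referenced anywhere below — generate_response re-encodes
# PIL images inline instead. It is kept here as a convenience for file-path inputs.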
def generate_response(model_name, prompt, system_prompt, image, max_length, temperature, top_p, top_k, stream=True):
    """Generate text using the model based on user input and advanced settings.

    Yields (status_text, response_text) pairs that map onto the status_indicator
    and output_text components wired up below.
    """
    global pipe

    # Get model info
    model_info = MODEL_INFO.get(model_name, {})
    supports_image = model_info.get("supports_image", False)
    supports_streaming = model_info.get("supports_streaming", False)

    # Check if streaming is requested but not supported
    if stream and not supports_streaming:
        stream = False  # Disable streaming for models that don't support it

    # Check if we need to load a model (if none is loaded) or a different model
    if pipe is None or model_name != getattr(pipe, "model_name", None):
        # Load the model and forward its status updates to the UI while loading
        for status in load_adapter_model(model_name):
            yield status, ""

        # Pick up the pipeline created by load_adapter_model and tag it
        pipe = current_pipe
        pipe.model_name = model_name

    # Format messages for the model
    if supports_image and image is not None:
        # Convert image to base64 string
        if isinstance(image, str):  # Path to image
            img = Image.open(image)
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
        else:  # Already a PIL image object from gradio
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

        # Create message with image
        messages = [
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image", "image": f"data:image/jpeg;base64,{img_str}"}
                    ]
                },
            ],
        ]
    else:
        # Text-only message
        messages = [
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": prompt}]
                },
            ],
        ]

    yield "Generating response...", ""

    # Generate based on whether streaming is supported
    if not stream or not supports_streaming:
        # Generate text without streaming
        try:
            generation_args = {
                "max_new_tokens": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k
            }
            output = pipe(messages, **generation_args)
            # messages is a batch of one conversation, so output[0][0]["generated_text"]
            # holds the full chat history and its last entry is the assistant reply
            generated_text = output[0][0]["generated_text"][-1]["content"]
            yield f"✅ Response generated with {model_name}", generated_text
        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            yield "⚠️ Generation failed", error_msg
    else:
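        # Note (assumption): this branch relies on the text-generation pipeline
        # accepting a "stream" kwarg, which the standard transformers pipeline does
        # not provide; true token streaming would normally go through a
        # TextIteratorStreamer. Since every entry in MODEL_INFO sets
        # "supports_streaming" to False, this path is currently unreachable.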
        # For streaming mode
        try:
            text = ""
            # Stream generation parameters
            stream_args = {
                "max_new_tokens": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "stream": True
            }
            # Generate with streaming
            for response in pipe(messages, **stream_args):
                # Extract the generated chunk
                if len(response) > 0 and len(response[0]) > 0 and "generated_text" in response[0][0]:
                    # Get the last message content
                    last_message = response[0][0]["generated_text"][-1]
                    if "content" in last_message:
                        # Update the text with the new content
                        text = last_message["content"]
                        yield "Generating response (streaming)...", text
        except Exception as e:
            error_msg = f"Error during streaming generation: {str(e)}"
            yield "⚠️ Streaming generation failed", error_msg
# Default model name
default_model = "Chan-Y/gemma-3-12b-finetune-200steps-0128"

# Initialize pipeline as None - we'll load the model before launching the interface
pipe = None

# Variable to store initial status for the status indicator
initial_status = "Loading default model. Please wait..."

# Default system prompt in Turkish
default_system_prompt = """Sana bir problem verildi.
Problem hakkında düşün ve çalışmanı göster.
Çalışmanı <start_working_out> ve <end_working_out> arasına yerleştir.
Sonra, çözümünü <SOLUTION> ve </SOLUTION> arasına yerleştir.
Lütfen SADECE Türkçe kullan."""
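# English gloss of the system prompt above (kept in Turkish because the models are
# tuned for Turkish reasoning): "You are given a problem. Think about the problem
# and show your work. Place your work between <start_working_out> and
# <end_working_out>. Then place your solution between <SOLUTION> and </SOLUTION>.
# Please use ONLY Turkish."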
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma 3 Reasoning Model Interface")
    gr.Markdown("Using Gemma 3 with Turkish reasoning adapters")

    # Add status indicator at the top
    status_indicator = gr.Textbox(
        value=initial_status,
        label="Status",
        interactive=False
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Enter your prompt here...",
                label="Prompt"
            )

            # Image input (only shown for Gemma models)
            image_input = gr.Image(
                type="pil",
                label="Upload Image (Only works with Gemma models)",
                visible=True  # Initially visible since default is Gemma
            )

            # Advanced settings in an expander (accordion)
            with gr.Accordion("Advanced Settings", open=False):
                # Model selection
                model_selector = gr.Dropdown(
                    choices=[
                        "Chan-Y/gemma-3-12b-finetune-200steps-0128",  # Default first
                        "Chan-Y/gemma-3-reasoning-tr-0.2",
                        "Chan-Y/gemma-3-1b-reasoning-tr-0128",
                        "Chan-Y/llama31-8b-turkish-reasoning-300325_merged_16bit",
                        "Chan-Y/qwen25-3b-turkish-reasoning-300325_merged_16bit"
                    ],
                    value=default_model,
                    label="Select Model",
                    info="Choosing a new model will unload the current one to save memory"
                )
                system_prompt = gr.Textbox(
                    lines=5,
                    value=default_system_prompt,
                    label="System Prompt"
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.75,
                    step=0.1,
                    label="Temperature"
                )
                max_tokens = gr.Slider(
                    minimum=16,
                    maximum=1024,
                    value=512,
                    step=10,
                    label="Max New Tokens"
                )
                top_p_value = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p"
                )
                top_k_value = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=64,
                    step=1,
                    label="Top-k"
                )
                use_streaming = gr.Checkbox(
                    value=True,
                    label="Use Streaming Response"
                )

            submit_btn = gr.Button("Generate Response")

        with gr.Column():
            output_text = gr.Textbox(lines=15, label="Generated Response")
    # Function to update image input visibility and the streaming label based on model selection
    def update_image_visibility(model_name):
        model_info = MODEL_INFO.get(model_name, {})
        is_gemma = model_info.get("supports_image", False)
        supports_streaming = model_info.get("supports_streaming", False)

        # Update checkbox label based on streaming support
        streaming_status = "" if supports_streaming else " (Not supported by this model)"
        return (
            gr.update(visible=is_gemma),  # Show the image input only for image-capable models
            gr.update(label=f"Use Streaming Response{streaming_status}")  # Relabel the streaming checkbox
        )
    # Function to show model info when selected
    def update_model_info(model_name):
        model_info = MODEL_INFO.get(model_name, {})
        # Use "or" so that an explicit None base_model also falls back to the label
        base_model = model_info.get("base_model") or "None (standalone model)"
        img_support = "Yes" if model_info.get("supports_image", False) else "No"
        stream_support = "Yes" if model_info.get("supports_streaming", False) else "No"
        return (
            f"Selected model: {model_name}\n"
            f"Base model: {base_model}\n"
            f"Image support: {img_support}\n"
            f"Streaming support: {stream_support}"
        )

    # Connect both functions to update image visibility and show model info
    model_selector.change(
        fn=update_image_visibility,
        inputs=[model_selector],
        outputs=[image_input, use_streaming]
    )
    model_selector.change(
        fn=update_model_info,
        inputs=[model_selector],
        outputs=[status_indicator]
    )
    # Initialize interface components
    def initialize_interface():
        # Check if model is loaded and update status
        if pipe is not None:
            model_info = MODEL_INFO.get(default_model, {})
            base_model = model_info.get("base_model") or "None (standalone model)"
            img_support = "Yes" if model_info.get("supports_image", False) else "No"
            status = f"Ready with model: {default_model}\nBase model: {base_model}\nImage support: {img_support}"
            submit_interactive = True
        else:
            status = "Model not loaded yet. Please wait..."
            submit_interactive = False
        return [
            status,
            gr.update(interactive=submit_interactive)
        ]

    # Use this at launch time
    demo.load(
        fn=initialize_interface,
        outputs=[status_indicator, submit_btn]
    )

    # Connect the generation function to the interface
    submit_btn.click(
        fn=generate_response,
        inputs=[
            model_selector,
            prompt_input,
            system_prompt,
            image_input,
            max_tokens,
            temperature,
            top_p_value,
            top_k_value,
            use_streaming
        ],
        outputs=[
            status_indicator,
            output_text
        ],
        show_progress=True
    )
# Launch the interface
if __name__ == "__main__":
    # Load the default model before launching Gradio
    print(f"Preloading default model: {default_model}")

    # Load the model
    for status in load_adapter_model(default_model):
        print(status)
        loading_status = status  # Update the global loading status

    # Set the loaded model name
    if current_pipe:
        current_pipe.model_name = default_model
        pipe = current_pipe

    # Update initial status for the UI
    initial_status = f"✅ Default model {default_model} loaded and ready to use!"

    # Launch Gradio interface
    demo.queue()  # Enable queuing for better handling of multiple requests
    demo.launch(share=False)  # Set share=True to create a public link