import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import gc
import torch
import base64
from PIL import Image
import io

# Global variables to track the loaded model and status
current_model = None
current_pipe = None
loading_status = "No model loaded yet"

# Model information: base model (None for merged checkpoints) and whether the model
# supports image input and streaming
MODEL_INFO = {
    "Chan-Y/gemma-3-1b-reasoning-tr-0128": {
        "base_model": "google/gemma-3-1b-it",
        "supports_image": True,
        "supports_streaming": False
    },
    "Chan-Y/gemma-3-12b-finetune-200steps-0128": {
        "base_model": "google/gemma-3-12b-it",
        "supports_image": True,
        "supports_streaming": False
    },
    "Chan-Y/llama31-8b-turkish-reasoning-300325_merged_16bit": {
        "base_model": None,  # This is a merged model, no base needed
        "supports_image": False,
        "supports_streaming": False
    },
    "Chan-Y/qwen25-3b-turkish-reasoning-300325_merged_16bit": {
        "base_model": None,  # This is a merged model, no base needed
        "supports_image": False,
        "supports_streaming": False
    }
}


def load_adapter_model(model_name):
    """Load the requested model, yielding human-readable status updates along the way.

    Tries, in order: a PEFT adapter on top of its base model, direct loading of a merged
    checkpoint, and finally a plain pipeline built from the model name. The resulting
    pipeline is stored in the global `current_pipe`.
    """
    global current_model, current_pipe, loading_status

    loading_status = f"Loading model: {model_name}..."
    yield loading_status

    # If a model is already loaded, drop it to free memory
    if current_model is not None:
        loading_status = "Unloading previous model to free memory..."
        yield loading_status
        # Drop the references so the old weights can be garbage collected
        current_model = None
        current_pipe = None
        # Force garbage collection and release cached GPU memory
        gc.collect()
        torch.cuda.empty_cache()

    # Get base model info if this is an adapter model
    model_info = MODEL_INFO.get(model_name, {"base_model": None, "supports_image": False})
    base_model_name = model_info["base_model"]

    # If this is a base model + adapter setup
    if base_model_name:
        try:
            loading_status = f"Loading adapter model {model_name} on top of {base_model_name}..."
            yield loading_status

            # Load the tokenizer from the base model
            tokenizer = AutoTokenizer.from_pretrained(base_model_name)
            print(f"Loading adapter model {model_name} on top of {base_model_name}...")

            # First load the adapter config (this also validates that the repo is a PEFT adapter)
            loading_status = "Loading adapter configuration..."
            yield loading_status
            peft_config = PeftConfig.from_pretrained(model_name)

            # Then load the base model
            loading_status = f"Loading base model {base_model_name}..."
            yield loading_status
            base_model = AutoModelForCausalLM.from_pretrained(
                base_model_name,
                device_map="auto",
                torch_dtype="auto"
            )

            # Load the adapter on top of the base model
            loading_status = "Applying adapter to base model..."
            yield loading_status
            model = PeftModel.from_pretrained(base_model, model_name)
            current_model = model

            # Create a generation pipeline with the loaded model and tokenizer
            loading_status = "Creating generation pipeline..."
            yield loading_status
            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
            current_pipe = pipe

            loading_status = f"✅ Model {model_name} loaded successfully!"
            yield loading_status
            return
        except Exception as e:
            loading_status = f"⚠️ PEFT loading failed: {e}"
            yield loading_status
            print(f"PEFT loading failed: {e}")
            # Fall through to try other loading methods

    # For merged models, or if PEFT loading failed
    try:
        # Try loading directly if the model is already merged or in a different format
        loading_status = f"Trying to load model {model_name} directly..."
        yield loading_status
        print(f"Trying to load model {model_name} directly...")

        # Get the tokenizer - if base_model is None, use the model name itself
        loading_status = "Loading tokenizer..."
        yield loading_status
        tokenizer = AutoTokenizer.from_pretrained(base_model_name or model_name)

        loading_status = f"Loading model weights for {model_name}..."
        yield loading_status
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype="auto"
        )
        current_model = model

        # Create a generation pipeline with the loaded model and tokenizer
        loading_status = "Creating generation pipeline..."
        yield loading_status
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        current_pipe = pipe

        loading_status = f"✅ Model {model_name} loaded successfully!"
        yield loading_status
        return
    except Exception as e2:
        loading_status = f"⚠️ Direct loading failed: {e2}"
        yield loading_status
        print(f"Direct loading failed: {e2}")

        # Last resort: let the pipeline resolve the model name itself
        loading_status = "Falling back to using the model name in pipeline..."
        yield loading_status
        print("Falling back to using the model name in pipeline...")
        pipe = pipeline("text-generation", model=model_name)
        current_pipe = pipe

        loading_status = f"✅ Model {model_name} loaded (fallback method)"
        yield loading_status
        return

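# Note: the *_merged_16bit entries in MODEL_INFO load directly because their adapters were
# already folded into the base weights. The sketch below shows how such a merged checkpoint
# could be produced with peft's merge_and_unload(); it is illustrative only, is not called
# anywhere in this app, and `merge_and_save` / `output_dir` are hypothetical names.
def merge_and_save(base_model_name, adapter_name, output_dir):
    """Fold a PEFT adapter into its base model and save a standalone checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype="auto")
    merged = PeftModel.from_pretrained(base, adapter_name).merge_and_unload()
    merged.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
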
def encode_image_to_base64(image_path):
    """Convert an image file on disk to a base64 string for model input."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def generate_response(model_name, prompt, system_prompt, image, max_length,
                      temperature, top_p, top_k, stream=True):
    """Generate text with the selected model, yielding (status, output) updates for the UI."""
    global pipe

    # Look up the selected model's capabilities
    model_info = MODEL_INFO.get(model_name, {})
    supports_image = model_info.get("supports_image", False)
    supports_streaming = model_info.get("supports_streaming", False)

    # Disable streaming for models that don't support it
    if stream and not supports_streaming:
        stream = False

    # Load a model if none is loaded yet, or if a different one was selected
    if pipe is None or model_name != getattr(pipe, "model_name", None):
        for status in load_adapter_model(model_name):
            # The first slot updates the status box; the output box is left unchanged
            yield status, gr.update()

        # Grab the freshly loaded pipeline and remember which model it belongs to
        pipe = current_pipe
        pipe.model_name = model_name

    # Format messages for the model
    if supports_image and image is not None:
        # Encode the image as a base64 JPEG data URI
        if isinstance(image, str):
            # Path to an image on disk
            img = Image.open(image)
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
        else:
            # Already a PIL image object from Gradio
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()

        # Create a message with text and image content
        messages = [
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image", "image": f"data:image/jpeg;base64,{img_str}"}
                    ]
                },
            ],
        ]
    else:
        # Text-only message
        messages = [
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": prompt}]
                },
            ],
        ]

    # Signal that generation is starting
    yield "Generating response...", gr.update()

    # Generate based on whether streaming is supported
    if not stream or not supports_streaming:
        # Non-streaming generation
        try:
            generation_args = {
                "max_new_tokens": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k
            }
            output = pipe(messages, **generation_args)
            # Extract the assistant's reply from the returned conversation
            generated_text = output[0][0]["generated_text"][-1]["content"]
            yield "✅ Response generated.", generated_text
        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            yield error_msg, gr.update()
    else:
        # Streaming mode
        try:
            text = ""
            # Streaming generation parameters; assumes the pipeline accepts a `stream` kwarg
            stream_args = {
                "max_new_tokens": max_length,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "stream": True
            }
            # Generate with streaming
            for response in pipe(messages, **stream_args):
                # Extract the newest chunk of generated text
                if len(response) > 0 and len(response[0]) > 0 and "generated_text" in response[0][0]:
                    last_message = response[0][0]["generated_text"][-1]
                    if "content" in last_message:
                        # Update the text with the new content
                        text = last_message["content"]
                        yield "Streaming response...", text
        except Exception as e:
            error_msg = f"Error during streaming generation: {str(e)}"
            yield error_msg, gr.update()

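# The streaming branch above assumes the pipeline accepts a `stream=True` keyword, which is
# not a documented argument of the stock transformers text-generation pipeline; that is why
# every entry in MODEL_INFO sets supports_streaming to False and the app falls back to
# non-streaming generation. For reference, the sketch below shows one common way to stream
# tokens with transformers' TextIteratorStreamer. It is a hedged illustration only and is
# not wired into the UI; `stream_generate` is a hypothetical helper name.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, prompt_text, max_new_tokens=512):
    """Yield progressively longer text as the model generates, chunk by chunk."""
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Run generation in a background thread so the streamer can be consumed here
    thread = Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": max_new_tokens, "streamer": streamer},
    )
    thread.start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text
    thread.join()
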
# Default model name
default_model = "Chan-Y/gemma-3-12b-finetune-200steps-0128"

# Initialize the pipeline as None - the default model is loaded before launching the interface
pipe = None

# Initial text for the status indicator
initial_status = "Loading default model. Please wait..."

# Default system prompt in Turkish (asks the model to show its reasoning and answer only in Turkish)
default_system_prompt = """Sana bir problem verildi. Problem hakkında düşün ve çalışmanı göster. Çalışmanı ve arasına yerleştir. Sonra, çözümünü ve arasına yerleştir. Lütfen SADECE Türkçe kullan."""

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma 3 Reasoning Model Interface")
    gr.Markdown("Using Gemma 3 with Turkish reasoning adapters")

    # Status indicator at the top
    status_indicator = gr.Textbox(
        value=initial_status,
        label="Status",
        interactive=False
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Enter your prompt here...",
                label="Prompt"
            )

            # Image input (only used by the Gemma models)
            image_input = gr.Image(
                type="pil",
                label="Upload Image (Only works with Gemma models)",
                visible=True  # Initially visible since the default model is Gemma
            )

            # Advanced settings in an accordion
            with gr.Accordion("Advanced Settings", open=False):
                model_selector = gr.Dropdown(
                    choices=[
                        "Chan-Y/gemma-3-12b-finetune-200steps-0128",  # Default first
                        "Chan-Y/gemma-3-reasoning-tr-0.2",
                        "Chan-Y/gemma-3-1b-reasoning-tr-0128",
                        "Chan-Y/llama31-8b-turkish-reasoning-300325_merged_16bit",
                        "Chan-Y/qwen25-3b-turkish-reasoning-300325_merged_16bit"
                    ],
                    value=default_model,
                    label="Select Model",
                    info="Choosing a new model will unload the current one to save memory"
                )
                system_prompt = gr.Textbox(
                    lines=5,
                    value=default_system_prompt,
                    label="System Prompt"
                )
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.75, step=0.1,
                    label="Temperature"
                )
                max_tokens = gr.Slider(
                    minimum=16, maximum=1024, value=512, step=10,
                    label="Max New Tokens"
                )
                top_p_value = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                    label="Top-p"
                )
                top_k_value = gr.Slider(
                    minimum=1, maximum=100, value=64, step=1,
                    label="Top-k"
                )
                use_streaming = gr.Checkbox(
                    value=True,
                    label="Use Streaming Response"
                )

            submit_btn = gr.Button("Generate Response")

        with gr.Column():
            output_text = gr.Textbox(lines=15, label="Generated Response")
    # Update image input visibility and the streaming checkbox label based on the selected model
    def update_image_visibility(model_name):
        model_info = MODEL_INFO.get(model_name, {})
        is_gemma = model_info.get("supports_image", False)
        supports_streaming = model_info.get("supports_streaming", False)
        streaming_status = "" if supports_streaming else " (Not supported by this model)"
        return (
            gr.update(visible=is_gemma),  # For the image input
            gr.update(label=f"Use Streaming Response{streaming_status}")  # For the streaming checkbox
        )

    # Show model info in the status box when a model is selected
    def update_model_info(model_name):
        model_info = MODEL_INFO.get(model_name, {})
        base_model = model_info.get("base_model") or "None (standalone model)"
        img_support = "Yes" if model_info.get("supports_image", False) else "No"
        stream_support = "Yes" if model_info.get("supports_streaming", False) else "No"
        return (
            f"Selected model: {model_name}\n"
            f"Base model: {base_model}\n"
            f"Image support: {img_support}\n"
            f"Streaming support: {stream_support}"
        )

    # Connect both handlers: update image visibility/streaming label and show model info
    model_selector.change(
        fn=update_image_visibility,
        inputs=[model_selector],
        outputs=[image_input, use_streaming]
    )
    model_selector.change(
        fn=update_model_info,
        inputs=[model_selector],
        outputs=[status_indicator]
    )

    # Initialize interface components
    def initialize_interface():
        # Check whether the preloaded model is ready and enable/disable the submit button accordingly
        if pipe is not None:
            model_info = MODEL_INFO.get(default_model, {})
            base_model = model_info.get("base_model") or "None (standalone model)"
            img_support = "Yes" if model_info.get("supports_image", False) else "No"
            status = (
                f"Ready with model: {default_model}\n"
                f"Base model: {base_model}\n"
                f"Image support: {img_support}"
            )
            submit_interactive = True
        else:
            status = "Model not loaded yet. Please wait..."
            submit_interactive = False

        return [
            status,
            gr.update(interactive=submit_interactive)
        ]

    # Run the initializer at launch time
    demo.load(
        fn=initialize_interface,
        outputs=[status_indicator, submit_btn]
    )

    # Connect the generation function to the interface
    submit_btn.click(
        fn=generate_response,
        inputs=[
            model_selector,
            prompt_input,
            system_prompt,
            image_input,
            max_tokens,
            temperature,
            top_p_value,
            top_k_value,
            use_streaming
        ],
        outputs=[
            status_indicator,
            output_text
        ],
        show_progress=True
    )

# Launch the interface
if __name__ == "__main__":
    # Preload the default model before launching Gradio
    print(f"Preloading default model: {default_model}")
    for status in load_adapter_model(default_model):
        print(status)
        loading_status = status  # Keep the global loading status in sync

    # Remember which model the loaded pipeline belongs to
    if current_pipe:
        current_pipe.model_name = default_model
        pipe = current_pipe

    # Update the initial status text (the demo.load handler refreshes the visible status on startup)
    initial_status = f"✅ Default model {default_model} loaded and ready to use!"

    # Launch the Gradio interface
    demo.queue()  # Enable queuing for better handling of multiple requests
    demo.launch(share=False)  # Set share=True to create a public link