"""Gradio chat UI for the Naseej/AskMe-Large Arabic GPT-2 model on HF ZeroGPU."""

import re

import gradio as gr
import spaces  # ZeroGPU: provides the @spaces.GPU decorator
import torch
from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel
from transformers import AutoTokenizer, pipeline

model_name = "Naseej/AskMe-Large"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    bos_token='<|startoftext|>',
    eos_token='<|endoftext|>',
    pad_token='<|pad|>',
)
model = GPT2LMHeadModel.from_pretrained(model_name)
# The special tokens above grow the vocabulary; resize embeddings to match.
model.resize_token_embeddings(len(tokenizer))

# ZeroGPU constraint: the model must live on CPU at import time and only be
# moved to CUDA inside a @spaces.GPU-decorated function, so the pipeline is
# created without specifying a device.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)


@spaces.GPU(duration=60)  # seconds of GPU time reserved per call
def generate_response(message, history, num_beams=4, temperature=0.99,
                      do_sample=True, top_k=60, top_p=0.9):
    """Generate an Arabic answer for *message* with the AskMe model.

    *history* is accepted for chat-interface compatibility but is unused.
    Returns the text following the final "Answer:" marker in the model
    output, or an Arabic fallback message when no answer was produced.
    """
    generator.model = generator.model.to('cuda')
    try:
        prompt = f'Prompt: {message}\nAnswer:'
        pred_text = generator(
            prompt,
            pad_token_id=tokenizer.eos_token_id,
            num_beams=int(num_beams),
            max_length=1024,
            min_length=0,
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            top_k=int(top_k),
            repetition_penalty=3.0,
            no_repeat_ngram_size=3,
        )[0]['generated_text']
        # Only an empty findall can fail here; the original bare `except:`
        # also swallowed KeyboardInterrupt/SystemExit, so narrow it.
        try:
            pred_sentiment = re.findall("Answer:(.*)", pred_text, re.S)[-1]
        except IndexError:
            pred_sentiment = "لم أستطع توليد إجابة. يرجى إعادة صياغة السؤال."
    finally:
        # Always release GPU memory, even if generation raised, so the
        # ZeroGPU slot is returned clean.
        generator.model = generator.model.to('cpu')
    return pred_sentiment


def respond(message, chat_history, num_beams, temperature, do_sample, top_k, top_p):
    """Chatbot event handler: append the model's reply and clear the textbox."""
    bot_message = generate_response(message, chat_history, num_beams,
                                    temperature, do_sample, top_k, top_p)
    chat_history.append((message, bot_message))
    return "", chat_history


# CSS for RTL (Arabic) layout and message-bubble styling.
css = """
.gradio-container {direction: rtl;}
.message.user {background-color: #2b5797; color: white; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
.message.bot {background-color: #f0f0f0; color: black; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# نظام AskMe - تحدث معي")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="المحادثة", elem_classes=["chatbot"])
            msg = gr.Textbox(label="اكتب رسالتك هنا", placeholder="اكتب هنا...")
            with gr.Row():
                submit_btn = gr.Button("إرسال", variant="primary")
                clear_btn = gr.Button("مسح المحادثة")
        with gr.Column(scale=1):
            with gr.Accordion("إعدادات توليد النص", open=False):
                num_beams = gr.Slider(1, 10, value=4, step=1, label="عدد الشعاعات")
                temperature = gr.Slider(0.1, 2.0, value=0.99, step=0.01, label="درجة الحرارة")
                do_sample = gr.Checkbox(value=True, label="النمط الإبداعي")
                top_k = gr.Slider(1, 100, value=60, step=1, label="Top-K")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-P")

    # Example prompts the user can click to fill the textbox.
    examples = gr.Examples(
        examples=[
            ["اكتب مقال عن الذكاء الصناعي"],
            ["اكتب قصة قصيرة عن النجاح"],
            ["كيف يمكن المحافظة على حياه صحية"],
        ],
        inputs=msg,
    )

    # Wire both the button and Enter-in-textbox to the same handler.
    submit_btn.click(
        respond,
        inputs=[msg, chatbot, num_beams, temperature, do_sample, top_k, top_p],
        outputs=[msg, chatbot],
    )
    msg.submit(
        respond,
        inputs=[msg, chatbot, num_beams, temperature, do_sample, top_k, top_p],
        outputs=[msg, chatbot],
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)

demo.launch()