"""Chat demo: a Gradio UI backed by a local Flask inference endpoint.

The module loads a causal language model with Transformers, serves it through
a small Flask app (``POST /chat``) running in a background thread, and exposes
a right-to-left Gradio chat interface that talks to that endpoint over HTTP.

Importing/running this module has side effects: it loads the model, starts
the Flask server thread and launches the Gradio app.
"""
import threading
from typing import Dict, List, Tuple

import gradio as gr
import requests
from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define the API URL to use the internal server
API_URL = "http://localhost:5000/chat"

# Upper bound (seconds) for one round trip to the inference server. Generation
# can be slow, so this is generous, but it must be finite so that a stuck
# server cannot block the UI callback forever.
REQUEST_TIMEOUT_SECONDS = 300

# Maximum number of tokens generated per reply. Without an explicit limit the
# library's default generation length applies, which is typically very short.
MAX_NEW_TOKENS = 512

History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

app = Flask(__name__)

# Load the model and tokenizer
model_name = "path/to/your/dictalm2.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


@app.route('/chat', methods=['POST'])
def chat():
    """Generate a model reply for the user messages in the JSON request body.

    Expects a JSON object of the form ``{"messages": [{"role": ..., "content": ...}, ...]}``.
    The contents of all messages whose role is ``'user'`` are joined with
    spaces and used as the prompt.

    Returns:
        A JSON object ``{"response": <text>}``. Status 400 is returned when the
        body is not a JSON object, when ``messages`` is missing/empty, or when
        it contains no user text.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so malformed requests get a clean 400 rather than a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"response": "No messages provided"}), 400

    messages = data.get('messages', [])
    if not isinstance(messages, list) or not messages:
        return jsonify({"response": "No messages provided"}), 400

    # Concatenate all user inputs into a single string, skipping entries that
    # are not well-formed message dicts.
    user_input = " ".join(
        str(msg.get('content', ''))
        for msg in messages
        if isinstance(msg, dict) and msg.get('role') == 'user'
    )
    if not user_input.strip():
        return jsonify({"response": "No user messages provided"}), 400

    encoded = tokenizer(user_input, return_tensors='pt')
    input_ids = encoded['input_ids']
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=encoded['attention_mask'],
        max_new_tokens=MAX_NEW_TOKENS,
    )

    # A causal LM returns prompt + continuation; drop the prompt tokens so the
    # reply does not echo the user's own text back.
    generated_tokens = outputs[0][input_ids.shape[-1]:]
    response_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return jsonify({"response": response_text})


# Function to run the Flask app
def run_flask():
    """Run the Flask development server on port 5000 (blocking)."""
    # NOTE(review): binding to 0.0.0.0 exposes the inference endpoint on every
    # network interface, while the UI only calls it via localhost. Consider
    # 127.0.0.1 if external access is not required.
    app.run(host='0.0.0.0', port=5000)


# Start the Flask app in a separate thread. It is a daemon thread so that it
# does not keep the process alive once the Gradio app (main thread) exits.
threading.Thread(target=run_flask, daemon=True).start()


# Gradio interface functions
def clear_session() -> History:
    """Return a fresh, empty conversation history."""
    return []


def history_to_messages(history: History) -> Messages:
    """Convert ``(user, assistant)`` pairs into a flat list of role/content dicts."""
    messages = []
    for user_text, assistant_text in history:
        messages.append({'role': 'user', 'content': user_text.strip()})
        messages.append({'role': 'assistant', 'content': assistant_text.strip()})
    return messages


def messages_to_history(messages: Messages) -> History:
    """Convert a flat, alternating message list back into ``(user, assistant)`` pairs.

    Messages are paired positionally (even index with the following odd
    index); a trailing unpaired message is dropped.
    """
    return [
        (question['content'], reply['content'])
        for question, reply in zip(messages[0::2], messages[1::2])
    ]


def model_chat(query: str, history: History) -> Tuple[str, History]:
    """Send the conversation plus ``query`` to the inference API.

    Args:
        query: The new user message. Blank/whitespace-only input is ignored.
        history: Previous ``(user, assistant)`` turns. The new turn is appended
            to this list in place.

    Returns:
        A tuple ``(response_text, history)``. For a blank query the result is
        ``('', history)`` with the history untouched. Request failures do not
        raise; the error description is returned as the response text and
        recorded in the history.
    """
    if not query.strip():
        return '', history

    messages = history_to_messages(history)
    messages.append({'role': 'user', 'content': query.strip()})

    try:
        response = requests.post(
            API_URL,
            json={"messages": messages},
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        # This will raise an HTTPError if the HTTP request returned an
        # unsuccessful status code
        response.raise_for_status()
        response_json = response.json()
        response_text = response_json.get("response", "Error: Response format is incorrect")
    except requests.exceptions.HTTPError as e:
        response_text = f"HTTPError: {str(e)}"
        print(f"HTTPError: {e.response.text}")  # Detailed error message
    except requests.exceptions.RequestException as e:
        response_text = f"RequestException: {str(e)}"
        print(f"RequestException: {e}")  # Debug print statement
    except ValueError as e:
        response_text = "ValueError: Invalid JSON response"
        print(f"ValueError: {e}")  # Debug print statement
    except Exception as e:
        # Last-resort boundary so a failed request never crashes the UI callback.
        response_text = f"Exception: {str(e)}"
        print(f"General Exception: {e}")  # Debug print statement

    history.append((query.strip(), response_text.strip()))
    return response_text.strip(), history


# Stylesheet for the UI. It is written as adjacent string literals purely for
# line length; the concatenated value is the single-line CSS string.
CUSTOM_CSS = (
    " .gr-group {direction: rtl;}"
    " .chatbot{text-align:right;}"
    " .dicta-header {"
    " background-color: var(--input-background-fill);"
    " border-radius: 10px;"
    " padding: 20px;"
    " text-align: center;"
    " display: flex;"
    " flex-direction: row;"
    " align-items: center;"
    " box-shadow: var(--block-shadow);"
    " border-color: var(--block-border-color);"
    " border-width: 1px;"
    " }"
    " @media (max-width: 768px) {"
    " .dicta-header {"
    " flex-direction: column;"
    " }"
    " }"
    " .chatbot.prose {"
    " font-size: 1.2em;"
    " }"
    " .dicta-logo {"
    " width: 150px;"
    " height: auto;"
    " margin-bottom: 20px;"
    " }"
    " .dicta-intro-text {"
    " margin-bottom: 20px;"
    " text-align: center;"
    " display: flex;"
    " flex-direction: column;"
    " align-items: center;"
    " width: 100%;"
    " font-size: 1.1em;"
    " }"
    " textarea {"
    " font-size: 1.2em;"
    " }"
    " "
)

# Gradio interface setup
with gr.Blocks(css=CUSTOM_CSS, js=None) as demo:
    # NOTE(review): the header markdown is empty in this copy of the file; the
    # CSS above styles .dicta-header/.dicta-logo elements that are not present.
    gr.Markdown("""
""")
    chatbot = gr.Chatbot(height=600)
    query = gr.Textbox(placeholder="הכנס שאלה בעברית (או באנגלית!)", rtl=True)
    clear_btn = gr.Button("נקה שיחה")

    def respond(query, history):
        """Handle a submitted query.

        Returns one value per declared output of ``query.submit``: the history
        for the chatbot, an update that empties the textbox, and the history
        again for the session state.
        """
        print(f"Query: {query}")  # Debug print statement
        response, history = model_chat(query, history)
        print(f"Response: {response}")  # Debug print statement
        return history, gr.update(value="", interactive=True), history

    def reset_conversation():
        """Clear both the stored session history and the visible chat."""
        return clear_session(), clear_session()

    demo_state = gr.State([])

    query.submit(respond, [query, demo_state], [chatbot, query, demo_state])
    # Both the state and the chatbot must be listed as outputs, otherwise the
    # visible conversation is not cleared.
    clear_btn.click(reset_conversation, [], [demo_state, chatbot])

demo.launch()