import gradio as gr
import requests
from typing import List, Dict, Tuple
from flask import Flask, request, jsonify, send_from_directory
from transformers import AutoTokenizer, AutoModelForCausalLM
import threading
import torch
import os
import re
# Define the API URL to use the internal server
API_URL = "http://localhost:5000/chat"
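# The /chat endpoint defined below expects a JSON body of the form
#   {"messages": [{"role": "user", "content": "..."}, ...]}
# and answers with {"response": "..."}. A minimal sketch of calling it
# directly once the server is up (illustrative, not part of the app flow):
#   r = requests.post(API_URL, json={"messages": [{"role": "user", "content": "Hello"}]})
#   print(r.json()["response"])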
History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]
app = Flask(__name__)
# Load the model and tokenizer
model_name = "dicta-il/dictalm2.0-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
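# Note: there is no explicit device placement above, so the model stays where
# from_pretrained puts it (CPU by default). A hedged sketch for GPU use
# (an assumption, not in the original code) would be:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)
# The tensors built in /chat below would then also need the same .to(device).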
# Set the pad_token to eos_token if not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Route to serve static files (e.g., images)
@app.route('/static/<path:path>')
def send_static(path):
    return send_from_directory('static', path)
@app.route('/chat', methods=['POST'])
def chat():
    data = request.json
    messages = data.get('messages', [])
    if not messages:
        return jsonify({"response": "No messages provided"}), 400
    # Concatenate all user inputs into a single string
    user_input = " ".join([msg['content'] for msg in messages if msg['role'] == 'user'])
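    # The line above flattens the conversation into raw text. For an
    # instruct-tuned model, a hedged alternative sketch (assuming the tokenizer
    # ships a chat template, as recent transformers tokenizers do) would be:
    #   input_ids = tokenizer.apply_chat_template(messages, return_tensors='pt')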
    inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    # Generate a completion, then strip the echoed prompt from the decoded text
    outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
    return jsonify({"response": response_text})
# Function to run the Flask app
def run_flask():
    app.run(host='0.0.0.0', port=5000)
# Start the Flask app in a separate thread
threading.Thread(target=run_flask, daemon=True).start()
# Gradio interface functions
def clear_session() -> Tuple[History, History]:
    # Reset both the visible chatbot and the stored conversation state,
    # matching the two outputs wired up in clear_btn.click below
    return [], []
def history_to_messages(history: History) -> Messages:
    messages = []
    for h in history:
        messages.append({'role': 'user', 'content': h[0].strip()})
        messages.append({'role': 'assistant', 'content': h[1].strip()})
    return messages
def messages_to_history(messages: Messages) -> History:
    history = []
    for q, r in zip(messages[0::2], messages[1::2]):
        history.append((q['content'], r['content']))
    return history
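# Illustrative round-trip (assuming well-formed alternating user/assistant turns):
#   history_to_messages([("hi", "hello")])
#       -> [{'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'hello'}]
#   messages_to_history of that list inverts it back to [('hi', 'hello')].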
def is_hebrew(text: str) -> bool:
    # True if the text contains at least one character in the Hebrew Unicode block
    return bool(re.search(r'[\u0590-\u05FF]', text))
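# Quick sanity checks (illustrative; cheap enough to run at import time):
assert is_hebrew("שלום") is True
assert is_hebrew("hello") is False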
def model_chat(query: str, history: History) -> Tuple[str, History]:
    if not query.strip():
        return '', history
    messages = history_to_messages(history)
    messages.append({'role': 'user', 'content': query.strip()})
    try:
        response = requests.post(API_URL, json={"messages": messages})
        response.raise_for_status()  # Raises an HTTPError if the request returned an unsuccessful status code
        response_json = response.json()
        response_text = response_json.get("response", "Error: Response format is incorrect")
    except requests.exceptions.HTTPError as e:
        response_text = f"HTTPError: {str(e)}"
        print(f"HTTPError: {e.response.text}")  # Detailed error message
    except requests.exceptions.RequestException as e:
        response_text = f"RequestException: {str(e)}"
        print(f"RequestException: {e}")  # Debug print statement
    except ValueError as e:
        response_text = "ValueError: Invalid JSON response"
        print(f"ValueError: {e}")  # Debug print statement
    except Exception as e:
        response_text = f"Exception: {str(e)}"
        print(f"General Exception: {e}")  # Debug print statement
    history.append((query.strip(), response_text.strip()))
    return response_text.strip(), history
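# Illustrative usage (assumes the Flask thread is already serving on port 5000):
#   reply, history = model_chat("Hello", [])
# `reply` is the model's answer; `history` now holds one (query, reply) pair.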
# Gradio interface setup
with gr.Blocks(css='''
    .gr-group {direction: rtl;}
    .chatbot {text-align: right;}
    .dicta-header {
        background-color: var(--input-background-fill);
        border-radius: 10px;
        padding: 20px;
        text-align: center;
        display: flex;
        flex-direction: row;
        align-items: center;
        box-shadow: var(--block-shadow);
        border-color: var(--block-border-color);
        border-width: 1px;
    }
    @media (max-width: 768px) {
        .dicta-header {
            flex-direction: column;
        }
    }
    .chatbot.prose {
        font-size: 1.2em;
    }
    .dicta-logo {
        width: 150px;
        height: auto;
        margin-bottom: 20px;
    }
    .dicta-intro-text {
        margin-bottom: 20px;
        text-align: center;
        display: flex;
        flex-direction: column;
        align-items: center;
        width: 100%;
        font-size: 1.1em;
    }
    textarea {
        font-size: 1.2em;
    }
''', js=None) as demo:
    # Page header; the Hebrew reads roughly: "Operational Chat - Initial Demo.
    # Welcome to the first interactive demo. Explore the model's capabilities
    # and see how it can assist you with your tasks. The demo was written by
    # Captain Roei Ratem (transliterated) using the Dicta language model
    # developed by MAFAT."
    gr.Markdown("""
    <div class="dicta-header">
        <a href="/static/logo111.png">
            <img src="/static/logo111.png" alt="Logo" class="dicta-logo">
        </a>
        <div class="dicta-intro-text">
            <h1>צ'אט מערכי - הדגמה ראשונית</h1>
            <span dir='rtl'>ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם</span><br/>
            <span dir='rtl'>הדמו נכתב על ידי סרן רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת</span><br/>
        </div>
    </div>
    """)
    chatbot = gr.Chatbot(height=600)
    # Placeholder text: "Enter a question in Hebrew (or English!)"
    query = gr.Textbox(placeholder="הכנס שאלה בעברית (או באנגלית!)", rtl=True)
    # Button label: "Clear chat"
    clear_btn = gr.Button("נקה שיחה")
    def respond(query, history):
        print(f"Query: {query}")  # Debug print statement
        response, history = model_chat(query, history)
        print(f"Response: {response}")  # Debug print statement
        # Match the textbox direction to the language of the model's reply
        if is_hebrew(response):
            return history, gr.update(value="", interactive=True, lines=2, rtl=True), history
        else:
            return history, gr.update(value="", interactive=True, lines=2, rtl=False), history
    demo_state = gr.State([])
    query.submit(respond, [query, demo_state], [chatbot, query, demo_state])
    clear_btn.click(clear_session, [], [chatbot, demo_state])
demo.launch(share=True)