# demo_1 / app.py
# (File-view page header, kept as comments so the module parses:
#  author Royrotem100 · commit 02ac619 "Initial commit" · 5.44 kB)
import os
import threading
from typing import List, Dict, Tuple

import gradio as gr
import requests
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
# URL of the Flask /chat endpoint defined below; model_chat() posts here.
API_URL = "http://localhost:5000/chat"
# Chat history as (user_text, assistant_text) pairs, as kept in the Gradio state.
History = List[Tuple[str, str]]
# Chat transcript as [{'role': ..., 'content': ...}, ...] dicts sent to the API.
Messages = List[Dict[str, str]]
app = Flask(__name__)
# Load the model and tokenizer once at import time. The checkpoint can be
# overridden with the MODEL_NAME environment variable (a local path or a Hub
# id) instead of editing the placeholder default below.
model_name = os.environ.get("MODEL_NAME", "path/to/your/dictalm2.0")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a reply from the local model for a posted chat transcript.

    Request body (JSON): ``{"messages": [{"role": ..., "content": ...}, ...]}``.
    Only ``user`` turns are used: their string contents are joined with single
    spaces into one prompt. Assistant turns are ignored, so earlier replies
    are not fed back to the model.

    Returns:
        ``{"response": <generated text>}`` with status 200, or
        ``{"response": <error message>}`` with status 400 when the body has no
        message list or contains no user text.
    """
    # silent=True: a missing or non-JSON body yields None instead of an error,
    # so it is reported through the same 400 path as an empty message list.
    data = request.get_json(silent=True)
    messages = data.get('messages') if isinstance(data, dict) else None
    if not isinstance(messages, list) or not messages:
        return jsonify({"response": "No messages provided"}), 400

    # Skip malformed entries instead of failing with KeyError/TypeError (a 500).
    user_input = " ".join(
        msg['content']
        for msg in messages
        if isinstance(msg, dict)
        and msg.get('role') == 'user'
        and isinstance(msg.get('content'), str)
    )
    if not user_input.strip():
        return jsonify({"response": "No user messages provided"}), 400

    encoded = tokenizer(user_input, return_tensors='pt')
    input_ids = encoded['input_ids'].to(model.device)
    attention_mask = encoded.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    # Some causal-LM tokenizers define no pad token; fall back to EOS so
    # generate() does not have to guess.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # Without an explicit budget generate() uses the generation config's
        # max_length (20 tokens *including* the prompt unless the checkpoint
        # overrides it), which truncates replies or produces none at all.
        max_new_tokens=256,
        pad_token_id=pad_token_id,
    )
    # A causal LM returns the prompt tokens followed by the continuation;
    # drop the prompt so the reply does not echo the user's text back.
    reply_ids = outputs[0][input_ids.shape[-1]:]
    response_text = tokenizer.decode(reply_ids, skip_special_tokens=True)
    return jsonify({"response": response_text})
# Function to run the Flask app
def run_flask(host: str = '127.0.0.1', port: int = 5000) -> None:
    """Serve the Flask ``app`` (blocks the calling thread).

    Binds to the loopback interface by default. The only client visible in
    this file, ``model_chat``, posts to ``localhost``. The original bound
    ``0.0.0.0``, which exposed the unauthenticated ``/chat`` model endpoint on
    every network interface. Pass ``host='0.0.0.0'`` to opt back in
    deliberately.

    Args:
        host: Interface address to bind.
        port: TCP port; must match the port in ``API_URL``.
    """
    app.run(host=host, port=port)
# Start the Flask app in a background thread. daemon=True lets the interpreter
# exit once the main thread (the Gradio server launched at the end of the file)
# finishes. With a non-daemon thread, the never-returning app.run() kept the
# process alive after Gradio shut down.
flask_thread = threading.Thread(target=run_flask, name="flask-api", daemon=True)
flask_thread.start()
# Gradio interface functions
def clear_session() -> History:
    """Produce a brand-new, empty chat history for resetting the conversation."""
    return list()
def history_to_messages(history: History) -> Messages:
    """Flatten (user, assistant) pairs into an alternating role/content list.

    Each pair contributes a ``user`` message followed by an ``assistant``
    message; surrounding whitespace is stripped from both texts.
    """
    return [
        {'role': role, 'content': turn[slot].strip()}
        for turn in history
        for slot, role in enumerate(('user', 'assistant'))
    ]
def messages_to_history(messages: Messages) -> History:
    """Pair consecutive messages into (first, second) content tuples.

    Even-index messages form the first element of each pair and the following
    odd-index message the second; roles are not checked. A trailing unpaired
    message is dropped because ``zip`` stops at the shorter side.
    """
    askers = messages[::2]
    answerers = messages[1::2]
    return [(asked['content'], answered['content']) for asked, answered in zip(askers, answerers)]
def _request_reply(messages: Messages, timeout: Tuple[float, float]) -> str:
    """POST *messages* to ``API_URL`` and return the reply text.

    Never raises. Transport, HTTP-status and JSON problems are printed and
    turned into an error string that the chat shows in place of a reply.
    """
    try:
        response = requests.post(API_URL, json={"messages": messages}, timeout=timeout)
        response.raise_for_status()  # 4xx/5xx status -> HTTPError
    except requests.exceptions.HTTPError as e:
        body = e.response.text if e.response is not None else ""
        print(f"HTTPError: {body}")  # server's error body, for debugging
        return f"HTTPError: {e}"
    except requests.exceptions.RequestException as e:
        print(f"RequestException: {e}")
        return f"RequestException: {e}"
    except Exception as e:  # best effort: surface anything else in the chat
        print(f"General Exception: {e}")
        return f"Exception: {e}"

    # JSON is decoded in its own try. Since requests 2.27, its JSONDecodeError
    # subclasses RequestException as well as ValueError, so inside the try
    # above a bad body was reported as a RequestException and the ValueError
    # handler never ran.
    try:
        payload = response.json()
    except ValueError as e:
        print(f"ValueError: {e}")
        return "ValueError: Invalid JSON response"

    reply = payload.get("response") if isinstance(payload, dict) else None
    if not isinstance(reply, str):
        # Missing key, null, a non-string value, or a non-object JSON body.
        return "Error: Response format is incorrect"
    return reply


def model_chat(query: str, history: History, *,
               timeout: Tuple[float, float] = (10.0, 600.0)) -> Tuple[str, History]:
    """Send *query* plus the prior *history* to the chat API and record the reply.

    Args:
        query: The new user message. Surrounding whitespace is stripped; a
            blank query is ignored.
        history: Prior (user, assistant) turns. Not modified.
        timeout: ``(connect, read)`` timeout in seconds for the HTTP call.
            Without one, ``requests`` waits indefinitely if the server hangs.
            The read timeout is generous because local generation may be slow.

    Returns:
        ``('', history)`` for a blank query. Otherwise ``(reply, new_history)``:
        *reply* is the stripped model reply, or an error string if the request
        failed. *new_history* is a new list with ``(query, reply)`` appended.
    """
    query = query.strip()
    if not query:
        return '', history
    messages = history_to_messages(history)
    messages.append({'role': 'user', 'content': query})
    response_text = _request_reply(messages, timeout).strip()
    # Build a new list instead of appending in place, so the caller's history
    # object is left untouched.
    return response_text, history + [(query, response_text)]
# Gradio interface setup
with gr.Blocks(css='''
.gr-group {direction: rtl;}
.chatbot{text-align:right;}
.dicta-header {
background-color: var(--input-background-fill);
border-radius: 10px;
padding: 20px;
text-align: center;
display: flex;
flex-direction: row;
align-items: center;
box-shadow: var(--block-shadow);
border-color: var(--block-border-color);
border-width: 1px;
}
@media (max-width: 768px) {
.dicta-header {
flex-direction: column;
}
}
.chatbot.prose {
font-size: 1.2em;
}
.dicta-logo {
width: 150px;
height: auto;
margin-bottom: 20px;
}
.dicta-intro-text {
margin-bottom: 20px;
text-align: center;
display: flex;
flex-direction: column;
align-items: center;
width: 100%;
font-size: 1.1em;
}
textarea {
font-size: 1.2em;
}
''', js=None) as demo:
    # Header: logo plus Hebrew title and intro text (right-to-left spans).
    # NOTE(review): "\\logo111.png" reaches the browser as src="\logo111.png",
    # which browsers resolve like "/logo111.png" (site root); confirm the
    # Gradio server actually serves the logo at that path.
    gr.Markdown("""
<div class="dicta-header">
<a href="">
<img src="\\logo111.png" alt="Logo" class="dicta-logo">
</a>
<div class="dicta-intro-text">
<h1>צ'אט מערכי - הדגמה ראשונית</h1>
<span dir='rtl'>ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם</span><br/>
<span dir='rtl'>הדמו נכתב על ידי סרן רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת</span><br/>
</div>
</div>
""")
    chatbot = gr.Chatbot(height=600)
    query = gr.Textbox(placeholder="הכנס שאלה בעברית (או באנגלית!)", rtl=True)
    clear_btn = gr.Button("נקה שיחה")

    def respond(query, history):
        """Run one chat turn and return values for (chatbot, query, demo_state)."""
        print(f"Query: {query}")  # Debug print statement
        response, history = model_chat(query, history)
        print(f"Response: {response}")  # Debug print statement
        # One value per output wired in query.submit below: the transcript for
        # the chatbot, an emptied textbox, and the transcript again for the
        # state. The original returned only two values for three outputs,
        # which Gradio rejects.
        return history, gr.update(value="", interactive=True), history

    def clear_chat():
        """Reset both the stored history and the visible chat to empty."""
        return clear_session(), clear_session()

    demo_state = gr.State([])
    query.submit(respond, [query, demo_state], [chatbot, query, demo_state])
    # Outputs must be passed as one list. The original passed `chatbot` as the
    # 4th positional argument, which is `api_name`, so the chat display was
    # never cleared.
    clear_btn.click(clear_chat, [], [demo_state, chatbot])
demo.launch()