# app.py
import gradio as gr
from src.agent import Agent
from src.create_database import load_and_process_dataset # Import from create_database.py
import os
import sys
import uuid
import requests
import logging
import subprocess
from llama_cpp import Llama # Import Llama from llama_cpp
import spacy
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Function to install requirements
def install_requirements():
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
        logging.info("Requirements installed successfully.")
    except subprocess.CalledProcessError as e:
        logging.error(f"Failed to install requirements: {e}")
# Function to download the spaCy model
def download_spacy_model(model_name):
    try:
        subprocess.check_call([sys.executable, '-m', 'spacy', 'download', model_name])
        logging.info(f"spaCy model {model_name} downloaded successfully.")
    except subprocess.CalledProcessError as e:
        logging.error(f"Failed to download spaCy model {model_name}: {e}")
# Install requirements
install_requirements()
# Download the spaCy model if it doesn't exist
if not spacy.util.is_package('en_core_web_lg'):
    download_spacy_model('en_core_web_lg')
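# Usage sketch (hypothetical): the spaCy pipeline is presumably consumed inside
# Agent rather than here; if it were loaded in this file, a minimal guarded
# load could look like:
#
#     try:
#         nlp = spacy.load("en_core_web_lg")
#     except OSError:
#         download_spacy_model("en_core_web_lg")
#         nlp = spacy.load("en_core_web_lg")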
# Create the directory if it doesn't exist
local_dir = "models"
os.makedirs(local_dir, exist_ok=True)
# Specify the filename for the model
filename = "unsloth.Q4_K_M.gguf"
model_path = os.path.join(local_dir, filename)
# Function to download the model file
def download_model(repo_id, filename, save_path):
    # Construct the URL for the model file
    url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
    # Stream the download so the multi-GB GGUF file is never held in memory
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        logging.info(f"Model downloaded to {save_path}")
    else:
        logging.error(f"Failed to download model: {response.status_code}")
# Download the model if it doesn't exist
if not os.path.exists(model_path):
download_model("adeptusnull/llama3.2-1b-wizardml-vicuna-uncensored-finetune-test", filename, model_path)
# Function to truncate context to fit within the model's context window
def truncate_context(context, max_tokens):
    # Rough truncation: keep the last `max_tokens` whitespace-separated words.
    # A word is only an approximation of a token, so token-dense text can
    # still overflow the context window.
    words = context.split()
    truncated_context = ' '.join(words[-max_tokens:])
    return truncated_context
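# Token-accurate alternative sketch, assuming an already-initialized
# llama_cpp.Llama instance `llm` (its tokenize/detokenize work on bytes):
#
#     def truncate_context_tokens(llm, context, max_tokens):
#         tokens = llm.tokenize(context.encode("utf-8"))
#         return llm.detokenize(tokens[-max_tokens:]).decode("utf-8", errors="ignore")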
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
):
model_path = "models/unsloth.Q4_K_M.gguf" # Path to the downloaded model
db_path = "agent.db"
system_prompt = system_message
# Check if the database exists, if not, initialize it
if not os.path.exists(db_path):
data_update_path = "data-update.txt"
keyword_dir = "keyword" # Updated keyword directory
load_and_process_dataset(data_update_path, keyword_dir, db_path)
    # Load the model with a fixed context window. Note: in llama-cpp-python the
    # response-length cap (max_tokens) is a generation-time argument, not a
    # constructor argument, so it is not set here.
    llm = Llama(
        model_path=model_path,
        n_ctx=500,  # Maximum context length in tokens
    )
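    # Hedged note: generation happens inside Agent.process_query; if it were
    # done directly here, the response cap would be passed per call, e.g.:
    #
    #     output = llm(truncated_context, max_tokens=500)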
    agent = Agent(llm, db_path, system_prompt)
    user_id = str(uuid.uuid4())  # Generate a unique user ID for each session
    try:
        # Truncate the context to fit within the model's context window
        max_context_tokens = 500  # Adjust this to match the model's n_ctx
        context = f"{system_prompt}\nUser: {message}\nAssistant: "
        truncated_context = truncate_context(context, max_context_tokens)
        response = agent.process_query(user_id, message, truncated_context)
    except ValueError as e:
        logging.error(f"Error during processing: {e}")
        # User-facing error message, kept in French for the app's audience
        # (gloss: "Sorry, there was an error while processing your request.
        # Please try again.")
        response = "Désolé, il y a eu une erreur lors du traitement de votre requête. Veuillez essayer à nouveau."
    return response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # French system prompt, kept verbatim for the app's audience (gloss:
        # "You are the intelligent assistant of Les Chronique MTC; help
        # visitors by explaining Michel Thomas's Chroniques, Flash Infos and
        # Chronique-FAQ, using the provided context for accurate, relevant
        # answers.")
        gr.Textbox(value="Vous êtes l'assistant intelligent de Les Chronique MTC. Votre rôle est d'aider les visiteurs en expliquant le contenu des Chroniques, Flash Infos et Chronique-FAQ de Michel Thomas. Utilisez le contexte fourni pour améliorer vos réponses et veillez à ce qu'elles soient précises et pertinentes.", label="System message"),
    ],
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)