Spaces:
Sleeping
Sleeping
# src/main.py | |
from src.agent import Agent | |
from src.create_database import load_and_process_dataset # Import from create_database.py | |
import os | |
import uuid | |
import requests | |
import logging | |
from llama_cpp import Llama | |
# Application-wide logging: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Downloaded model weights live under ./models; the directory is created
# eagerly (a no-op when it already exists) so the weights have somewhere to land.
local_dir = "models"
filename = "unsloth.Q4_K_M.gguf"  # GGUF file name inside the Hugging Face repo
os.makedirs(local_dir, exist_ok=True)
model_path = os.path.join(local_dir, filename)
def download_model(repo_id, filename, save_path, *, timeout=60, chunk_size=1024 * 1024):
    """Download ``filename`` from the Hugging Face repo ``repo_id`` to ``save_path``.

    The response is streamed to disk in chunks rather than buffered in memory,
    because model checkpoints can be very large. Data is written to
    ``save_path + ".part"`` first and only moved to ``save_path`` once the
    transfer has completed. An interrupted download therefore never leaves a
    truncated file at ``save_path``, which an ``os.path.exists`` check would
    otherwise mistake for a complete model.

    Args:
        repo_id: Hugging Face repository id, e.g. ``"user/model"``.
        filename: Name of the file in the repository's ``main`` revision.
        save_path: Local destination path; its directory must already exist.
        timeout: Seconds passed to ``requests`` as the connect/read timeout
            (time between received bytes, not total download time).
        chunk_size: Number of bytes requested per chunk while streaming.

    Returns:
        True if the file was downloaded and moved into place. False if the
        server answered with a non-200 status; nothing is written in that case.

    Raises:
        requests.RequestException: on network failures (connection errors,
            timeouts, a stream cut off mid-transfer).
        OSError: if the temporary or final file cannot be written.
    """
    url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
    tmp_path = f"{save_path}.part"

    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            logging.error("Failed to download model: %s", response.status_code)
            return False
        try:
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
            # Atomic on the same filesystem: save_path is either absent or complete.
            os.replace(tmp_path, save_path)
        except BaseException:
            # Also covers Ctrl-C: never leave a half-written temp file behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    logging.info("Model downloaded to %s", save_path)
    return True
# Fetch the model weights on first run only; later runs reuse the local copy.
if not os.path.exists(model_path):
    download_model(
        repo_id="adeptusnull/llama3.2-1b-wizardml-vicuna-uncensored-finetune-test",
        filename=filename,
        save_path=model_path,
    )
def main():
    """Run the interactive console assistant for Les Chroniques MTC.

    Startup sequence:

    1. On first run (when ``agent.db`` does not exist), build the agent
       database from ``data-update.txt`` and the ``keyword`` directory via
       ``load_and_process_dataset``.
    2. Load the local GGUF model with llama-cpp.
    3. Wrap the model in an ``Agent``.

    The function then reads queries from stdin until the user types ``exit``
    or input ends (EOF / Ctrl-C). Errors raised while answering a query are
    printed and the loop continues.
    """
    db_path = "agent.db"
    system_prompt = "Vous êtes l'assistant intelligent de Les Chronique MTC. Votre rôle est d'aider les visiteurs en expliquant le contenu des Chroniques, Flash Infos et Chronique-FAQ de Michel Thomas. Utilisez le contexte fourni pour améliorer vos réponses et veillez à ce qu'elles soient précises et pertinentes."
    max_tokens = 500
    temperature = 0.7
    top_p = 0.95

    # Build the database only on first run; later runs reuse the existing file.
    if not os.path.exists(db_path):
        data_update_path = "data-update.txt"
        keyword_dir = "keyword"  # Updated keyword directory
        load_and_process_dataset(data_update_path, keyword_dir, db_path)

    # ``model_path`` is the module-level path the download step writes to.
    # Reusing it (rather than re-spelling the path) keeps load and download
    # locations in sync.
    # ``max_tokens`` is a per-completion setting, not a ``Llama`` constructor
    # option, so it is handed to the Agent together with temperature/top_p.
    llm = Llama(
        model_path=model_path,
        # NOTE(review): n_ctx is the whole context window (prompt + generated
        # tokens). If the Agent requests up to max_tokens=500 per completion,
        # only ~72 tokens remain for the system prompt, retrieved context and
        # the query. Confirm 572 is intended.
        n_ctx=572,
    )
    agent = Agent(llm, db_path, system_prompt, max_tokens, temperature, top_p)

    # One ID per console session (not per query), so every query typed during
    # this run reaches the Agent under the same user_id.
    user_id = str(uuid.uuid4())
    while True:
        try:
            user_query = input("Entrez votre requête: ")
        except (EOFError, KeyboardInterrupt):
            # stdin closed (e.g. non-interactive run) or Ctrl-C: end cleanly.
            print()
            break
        if user_query.strip().lower() == 'exit':
            break
        if not user_query.strip():
            continue  # nothing to ask; don't spend a model call on blank input
        try:
            response = agent.process_query(user_id, user_query)
            print("Réponse:", response)
        except Exception as e:
            print(f"Erreur lors du traitement de la requête: {e}")
        # Clean up expired interactions.
        # NOTE(review): the source's indentation was lost. This call is assumed
        # to run after every query; confirm it was not meant to run once on exit.
        agent.memory.cleanup_expired_interactions()


if __name__ == "__main__":
    main()