# HealsmindAIPetals / model.py
import os

import chainlit as cl
from dotenv import load_dotenv
from huggingface_hub import login
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
# Load environment variables from the .env file
load_dotenv()

# Retrieve the Hugging Face token from the environment
hugging_face_token = os.getenv("HUGGINGFACE_TOKEN")

DB_FAISS_PATH = 'vectorstore/db_faiss'

# Authenticate with the Hugging Face Hub
login(token=hugging_face_token)
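# Expected .env contents (the value below is a placeholder, not a real token):
#   HUGGINGFACE_TOKEN=hf_your_token_here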
# Load the FAISS vector store, embedding queries with a sentence-transformers model
def load_vector_store():
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    return db
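# Minimal retrieval sketch, assuming the FAISS index at DB_FAISS_PATH exists.
# retrieve_context is a hypothetical helper (not wired into the bot below): it
# fetches the k most similar chunks for a query via FAISS similarity search.
def retrieve_context(db, query, k=3):
    docs = db.similarity_search(query, k=k)
    return "\n".join(doc.page_content for doc in docs)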
# Load the tokenizer locally and the Llama 2 model distributed over the Petals swarm
def load_llm():
    model_name = "meta-llama/Llama-2-70b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
    model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
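# Standalone usage sketch: Petals mirrors the Hugging Face generate() API while
# the model's layers run across the swarm. Kept commented out so nothing runs
# at import time.
#   model, tokenizer = load_llm()
#   inputs = tokenizer("Hello, ", return_tensors="pt")["input_ids"]
#   outputs = model.generate(inputs, max_new_tokens=5)
#   print(tokenizer.decode(outputs[0]))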
# QA model factory (currently a thin wrapper around load_llm)
def qa_bot():
    model, tokenizer = load_llm()
    return model, tokenizer
# Chainlit entry points
@cl.on_chat_start
async def start():
    model, tokenizer = qa_bot()
    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    msg.content = "Hi, welcome to HealsMindAI. What is your query?"
    await msg.update()
    cl.user_session.set("model", model)
    cl.user_session.set("tokenizer", tokenizer)
    # Give each session its own empty history instead of sharing one
    # module-level list across all users
    cl.user_session.set("history", [])
@cl.on_message
async def main(message):
    model = cl.user_session.get("model")
    tokenizer = cl.user_session.get("tokenizer")
    history = cl.user_session.get("history")

    # Use the accumulated history as context; the current question goes in the
    # Question slot only, so it is not duplicated inside the context
    context = " ".join(history)
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {}
Question: {}
Only return the helpful answer below and nothing else.
Helpful answer:
""".format(context, message.content)

    # Generate text with the distributed model; max_new_tokens caps only the
    # generated continuation (max_length would also count the prompt tokens)
    input_ids = tokenizer.encode(custom_prompt_template, return_tensors="pt")
    generated_output = model.generate(input_ids,
                                      max_new_tokens=200,
                                      num_return_sequences=1,
                                      no_repeat_ngram_size=2)

    # Decode only the newly generated tokens, skipping the echoed prompt
    answer = tokenizer.decode(generated_output[0][input_ids.shape[1]:],
                              skip_special_tokens=True)

    # Record both sides of the turn so later prompts carry the full conversation
    history.append(message.content)
    history.append(answer)
    cl.user_session.set("history", history)

    await cl.Message(content=answer).send()
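# To launch the bot locally (assuming Chainlit is installed and this file is
# the app entry point):
#   chainlit run model.py -w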