# Backend/app/chains.py
# Author: Damien Benveniste (commit 36118aa, 1.19 kB)
import os
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.runnables import RunnablePassthrough
import schemas
from prompts import (
raw_prompt_formatted,
history_prompt_formatted,
question_prompt_formatted,
context_prompt_formatted,
format_context,
tokenizer
)
from data_indexing import DataIndexer
# Module-level index handle, shared by the RAG chain below to retrieve
# context documents (see data_indexer.search in rag_chain).
data_indexer = DataIndexer()
# Meta-Llama-3-8B-Instruct served through the Hugging Face Inference API.
# Generation is capped at 512 new tokens and halts on the tokenizer's
# end-of-sequence token. Reads the API token from the HF_TOKEN env var
# (raises KeyError at import time if unset).
_endpoint_config = {
    "repo_id": "meta-llama/Meta-Llama-3-8B-Instruct",
    "huggingfacehub_api_token": os.environ['HF_TOKEN'],
    "max_new_tokens": 512,
    "stop_sequences": [tokenizer.eos_token],
}
llm = HuggingFaceEndpoint(**_endpoint_config)
# Minimal chain: render the raw user question through the prompt template,
# then complete it with the LLM. Input schema declared for the API layer.
formatted_chain = (raw_prompt_formatted | llm).with_types(
    input_type=schemas.UserQuestion
)
# Chat-history-aware chain: render the history prompt template, then
# complete it with the LLM. Input schema declared for the API layer.
history_chain = (history_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput
)
# Retrieval-augmented generation chain:
#   1. Condense the incoming question (+ history) into a standalone question.
#   2. Search the index for supporting documents and format them as context.
#   3. Answer the standalone question from that context.
rag_chain = (
    {
        # Standalone question produced by the condensing prompt + LLM.
        'question': question_prompt_formatted | llm,
        # BUGFIX: this was RunnablePassthrough(), which forwards the ENTIRE
        # chain input dict — so the next step's x['hybrid_search'] received a
        # (always-truthy) dict instead of the flag, effectively forcing hybrid
        # search on. Extract the flag itself, matching how the next step and
        # data_indexer.search(hybrid_search=...) consume it.
        # NOTE(review): assumes schemas.RagInput carries a 'hybrid_search'
        # field — the original second-stage lambda already relied on this.
        'hybrid_search': lambda x: x['hybrid_search'],
    }
    | {
        # Retrieve documents for the standalone question and render them
        # into a single context string for the final prompt.
        'context': lambda x: format_context(
            data_indexer.search(x['question'], hybrid_search=x['hybrid_search'])
        ),
        'standalone_question': lambda x: x['question'],
    }
    | context_prompt_formatted
    | llm
).with_types(input_type=schemas.RagInput)