# Backend/app/chains.py — LLM chain definitions (simple, formatted, history, RAG)
# Author: Damien Benveniste (commit 275d03f, "corrected")
import os
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.runnables import RunnablePassthrough
import schemas
from prompts import (
raw_prompt,
raw_prompt_formatted,
history_prompt_formatted,
standalone_prompt_formatted,
rag_prompt_formatted,
format_context,
tokenizer
)
from data_indexing import DataIndexer
data_indexer = DataIndexer()
llm = HuggingFaceEndpoint(
repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
huggingfacehub_api_token=os.environ['HF_TOKEN'],
max_new_tokens=512,
stop_sequences=[tokenizer.eos_token],
streaming=True,
)
simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
formatted_chain = (
raw_prompt_formatted
| llm
).with_types(input_type=schemas.UserQuestion)
history_chain = (
history_prompt_formatted
| llm
).with_types(input_type=schemas.HistoryInput)
rag_chain = (
RunnablePassthrough.assign(new_question=standalone_prompt_formatted | llm)
| {
'context': lambda x: format_context(data_indexer.search(x['new_question'], hybrid_search=x['hybrid_search'])),
'standalone_question': lambda x: x['new_question'],
'test': lambda x : print(x)
}
| rag_prompt_formatted
| llm
).with_types(input_type=schemas.RagInput)