Damien Benveniste commited on
Commit
5d0492c
·
1 Parent(s): 36118aa

fixed issues

Browse files
Files changed (3) hide show
  1. app/chains.py +3 -0
  2. app/main.py +2 -2
  3. app/prompts.py +1 -0
app/chains.py CHANGED
@@ -3,6 +3,7 @@ from langchain_huggingface import HuggingFaceEndpoint
3
  from langchain_core.runnables import RunnablePassthrough
4
  import schemas
5
  from prompts import (
 
6
  raw_prompt_formatted,
7
  history_prompt_formatted,
8
  question_prompt_formatted,
@@ -22,6 +23,8 @@ llm = HuggingFaceEndpoint(
22
  stop_sequences=[tokenizer.eos_token]
23
  )
24
 
 
 
25
  formatted_chain = (
26
  raw_prompt_formatted
27
  | llm
 
3
  from langchain_core.runnables import RunnablePassthrough
4
  import schemas
5
  from prompts import (
6
+ raw_prompt,
7
  raw_prompt_formatted,
8
  history_prompt_formatted,
9
  question_prompt_formatted,
 
23
  stop_sequences=[tokenizer.eos_token]
24
  )
25
 
26
+ simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
27
+
28
  formatted_chain = (
29
  raw_prompt_formatted
30
  | llm
app/main.py CHANGED
@@ -7,7 +7,7 @@ from langserve.serialization import WellKnownLCSerializer
7
  from typing import Any, List
8
  import crud, models, schemas
9
  from database import SessionLocal, engine
10
- from chains import llm, formatted_chain, history_chain, rag_chain
11
  from prompts import format_chat_history
12
  from callbacks import LogResponseCallback
13
 
@@ -40,7 +40,7 @@ def greet_json():
40
  async def simple_stream(request: Request):
41
  data = await request.json()
42
  user_question = schemas.UserQuestion(**data['input'])
43
- return EventSourceResponse(generate_stream(user_question, llm))
44
 
45
 
46
  @app.post("/formatted/stream")
 
7
  from typing import Any, List
8
  import crud, models, schemas
9
  from database import SessionLocal, engine
10
+ from chains import llm, simple_chain, formatted_chain, history_chain, rag_chain
11
  from prompts import format_chat_history
12
  from callbacks import LogResponseCallback
13
 
 
40
  async def simple_stream(request: Request):
41
  data = await request.json()
42
  user_question = schemas.UserQuestion(**data['input'])
43
+ return EventSourceResponse(generate_stream(user_question, simple_chain))
44
 
45
 
46
  @app.post("/formatted/stream")
app/prompts.py CHANGED
@@ -77,6 +77,7 @@ def format_context(docs: List[str]):
77
 
78
 
79
  raw_prompt_formatted = format_prompt(raw_prompt)
 
80
  history_prompt_formatted = format_prompt(history_prompt)
81
  question_prompt_formatted = format_prompt(question_prompt)
82
  context_prompt_formatted = format_prompt(context_prompt)
 
77
 
78
 
79
  raw_prompt_formatted = format_prompt(raw_prompt)
80
+ raw_prompt = PromptTemplate.from_template(raw_prompt)
81
  history_prompt_formatted = format_prompt(history_prompt)
82
  question_prompt_formatted = format_prompt(question_prompt)
83
  context_prompt_formatted = format_prompt(context_prompt)