Spaces:
Runtime error
Runtime error
Damien Benveniste
commited on
Commit
·
5d0492c
1
Parent(s):
36118aa
fixed issues
Browse files- app/chains.py +3 -0
- app/main.py +2 -2
- app/prompts.py +1 -0
app/chains.py
CHANGED
@@ -3,6 +3,7 @@ from langchain_huggingface import HuggingFaceEndpoint
|
|
3 |
from langchain_core.runnables import RunnablePassthrough
|
4 |
import schemas
|
5 |
from prompts import (
|
|
|
6 |
raw_prompt_formatted,
|
7 |
history_prompt_formatted,
|
8 |
question_prompt_formatted,
|
@@ -22,6 +23,8 @@ llm = HuggingFaceEndpoint(
|
|
22 |
stop_sequences=[tokenizer.eos_token]
|
23 |
)
|
24 |
|
|
|
|
|
25 |
formatted_chain = (
|
26 |
raw_prompt_formatted
|
27 |
| llm
|
|
|
3 |
from langchain_core.runnables import RunnablePassthrough
|
4 |
import schemas
|
5 |
from prompts import (
|
6 |
+
raw_prompt,
|
7 |
raw_prompt_formatted,
|
8 |
history_prompt_formatted,
|
9 |
question_prompt_formatted,
|
|
|
23 |
stop_sequences=[tokenizer.eos_token]
|
24 |
)
|
25 |
|
26 |
+
simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
|
27 |
+
|
28 |
formatted_chain = (
|
29 |
raw_prompt_formatted
|
30 |
| llm
|
app/main.py
CHANGED
@@ -7,7 +7,7 @@ from langserve.serialization import WellKnownLCSerializer
|
|
7 |
from typing import Any, List
|
8 |
import crud, models, schemas
|
9 |
from database import SessionLocal, engine
|
10 |
-
from chains import llm, formatted_chain, history_chain, rag_chain
|
11 |
from prompts import format_chat_history
|
12 |
from callbacks import LogResponseCallback
|
13 |
|
@@ -40,7 +40,7 @@ def greet_json():
|
|
40 |
async def simple_stream(request: Request):
|
41 |
data = await request.json()
|
42 |
user_question = schemas.UserQuestion(**data['input'])
|
43 |
-
return EventSourceResponse(generate_stream(user_question,
|
44 |
|
45 |
|
46 |
@app.post("/formatted/stream")
|
|
|
7 |
from typing import Any, List
|
8 |
import crud, models, schemas
|
9 |
from database import SessionLocal, engine
|
10 |
+
from chains import llm, simple_chain, formatted_chain, history_chain, rag_chain
|
11 |
from prompts import format_chat_history
|
12 |
from callbacks import LogResponseCallback
|
13 |
|
|
|
40 |
async def simple_stream(request: Request):
|
41 |
data = await request.json()
|
42 |
user_question = schemas.UserQuestion(**data['input'])
|
43 |
+
return EventSourceResponse(generate_stream(user_question, simple_chain))
|
44 |
|
45 |
|
46 |
@app.post("/formatted/stream")
|
app/prompts.py
CHANGED
@@ -77,6 +77,7 @@ def format_context(docs: List[str]):
|
|
77 |
|
78 |
|
79 |
raw_prompt_formatted = format_prompt(raw_prompt)
|
|
|
80 |
history_prompt_formatted = format_prompt(history_prompt)
|
81 |
question_prompt_formatted = format_prompt(question_prompt)
|
82 |
context_prompt_formatted = format_prompt(context_prompt)
|
|
|
77 |
|
78 |
|
79 |
raw_prompt_formatted = format_prompt(raw_prompt)
|
80 |
+
raw_prompt = PromptTemplate.from_template(raw_prompt)
|
81 |
history_prompt_formatted = format_prompt(history_prompt)
|
82 |
question_prompt_formatted = format_prompt(question_prompt)
|
83 |
context_prompt_formatted = format_prompt(context_prompt)
|