mister-g committed on
Commit 9cc5c95 · 1 Parent(s): d8e70c0

direct access to hf pipeline

Files changed (1): app.py +13 -6
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 # gpt_model = 'gpt-4-1106-preview'
 # embedding_model = 'text-embedding-3-small'
 default_model_id = "bigcode/starcoder2-3b"
+#default_model_id = "tiiuae/falcon-7b-instruct"
 
 def init():
     if "conversation" not in st.session_state:
@@ -38,7 +39,7 @@ def init_llm_pipeline(model_id):
         task="text-generation",
         max_new_tokens=1024
     )
-    st.session_state.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
+    st.session_state.llm = text_generation_pipeline
 
 def get_retriever(files):
     documents = [doc.getvalue().decode("utf-8") for doc in files]
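
Note on this hunk: st.session_state.llm now holds the raw transformers pipeline rather than the LangChain HuggingFacePipeline wrapper, so callers get the pipeline's native return format: a list of generation dicts, not a plain string. A minimal sketch of that calling convention (the prompt is illustrative):

    from transformers import pipeline

    # Same construction as init_llm_pipeline above, using default_model_id.
    generator = pipeline(
        task="text-generation",
        model="bigcode/starcoder2-3b",
        max_new_tokens=1024,
    )

    # A direct call yields [{"generated_text": "<prompt + completion>"}],
    # so callers must index into the list and dict to get the text.
    result = generator("def fibonacci(n):")
    print(result[0]["generated_text"])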
@@ -58,17 +59,23 @@ def get_retriever(files):
     return retriever
 
 def get_conversation(retriever):
-    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+    #memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
+
     conversation_chain = ConversationalRetrievalChain.from_llm(
+        prompt=prompt,
         llm=st.session_state.llm,
-        retriever=retriever,
-        memory = memory
+        retriever=retriever
     )
     return conversation_chain
 
+def getprompt(user_input):
+    prompt = f"You are a helpful assistant. Please answer the user question. USER: {user_input} ASSISTANT:"
+    return prompt
+
 def handle_user_input(question):
-    response = st.session_state.conversation({'question':question})
-    st.session_state.chat_history = response['chat_history']
+    st.session_state.chat_history += {"role":"user","content":question}
+    response = st.session_state.llm(getprompt(question))
+    st.session_state.chat_history += {"role":"assistant","content":response}
     for i, message in enumerate(st.session_state.chat_history):
         if i % 2 == 0:
             with st.chat_message("user"):
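
Review note on get_conversation: prompt is never defined in this scope, so prompt=prompt would raise a NameError, and ConversationalRetrievalChain.from_llm has no prompt parameter. In classic LangChain the combine-docs prompt goes through combine_docs_chain_kwargs, and from_llm expects a LangChain LLM rather than a raw transformers pipeline, so the wrapper removed above would still be needed if the chain is kept. A sketch of that pattern, assuming classic LangChain and an illustrative template:

    import streamlit as st
    from langchain.prompts import PromptTemplate
    from langchain.chains import ConversationalRetrievalChain
    from langchain.llms import HuggingFacePipeline  # the wrapper this commit stops using

    # A "stuff" combine-docs chain expects {context} and {question} variables.
    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="You are a helpful assistant. Use the context to answer.\n"
                 "Context: {context}\nQuestion: {question}\nAnswer:",
    )

    def get_conversation(retriever):
        llm = HuggingFacePipeline(pipeline=st.session_state.llm)
        return ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": qa_prompt},
        )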
 
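Review note on handle_user_input: list += dict extends the list with the dict's keys ("role", "content") rather than appending the message, and the raw pipeline's response is a list of dicts, not a string, so the render loop would display the wrong content. Also, the USER:/ASSISTANT: template in getprompt suits instruction-tuned models (such as the commented-out falcon-7b-instruct) better than the base starcoder2-3b. A sketch of an append-based handler under those assumptions (rendering by role instead of by index is likewise an assumption):

    def handle_user_input(question):
        st.session_state.chat_history.append({"role": "user", "content": question})
        prompt = getprompt(question)
        # The pipeline returns [{"generated_text": "<prompt + completion>"}];
        # strip the echoed prompt so only the assistant's answer is stored.
        generated = st.session_state.llm(prompt)[0]["generated_text"]
        answer = generated[len(prompt):].strip()
        st.session_state.chat_history.append({"role": "assistant", "content": answer})
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])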