AminFaraji committed on
Commit 4b96563 · verified · 1 Parent(s): 7c79220

Update app.py

Files changed (1)
  1. app.py +164 -0
app.py CHANGED
@@ -0,0 +1,164 @@
+ try:
+     from langchain_community.vectorstores import Chroma
+ except ImportError:
+     # Fall back to the older import path used by legacy LangChain releases.
+     from langchain.vectorstores import Chroma
+
+ from langchain.chains import ConversationChain
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
+
+ # Import the prompt template and the Groq chat model.
+ from langchain_core.prompts import PromptTemplate
+ from langchain_groq import ChatGroq
+
+
+ import os
+
+ # Read the Groq API key from the environment (set "my_groq_api_key" before launching).
+ groq_api_key = os.environ.get("my_groq_api_key")
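+ # Illustrative guard (a sketch, not strictly required): fail fast when the key
+ # is missing rather than erroring later inside ChatGroq.
+ if groq_api_key is None:
+     raise RuntimeError("Environment variable 'my_groq_api_key' is not set.")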
+
+ # Initialize a ChatGroq object with a temperature of 0 and the "llama3-70b-8192" model.
+ llm = ChatGroq(temperature=0, model_name="llama3-70b-8192", api_key=groq_api_key)
+
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
+
+ # These embeddings must match the ones used when the Chroma index was built.
+ embeddings = SentenceTransformerEmbeddings(
+     model_name="sentence-transformers/all-MiniLM-L6-v2",
+     model_kwargs={"trust_remote_code": True},
+ )
+
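+ # Quick sanity check (illustrative): all-MiniLM-L6-v2 returns 384-dimensional vectors.
+ #     assert len(embeddings.embed_query("hello")) == 384
+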
+ memory = ConversationBufferWindowMemory(
+     memory_key="history", k=3, return_only_outputs=True
+ )
+
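+ # With k=3, only the three most recent human/AI exchanges are injected into the
+ # {history} variable; older turns are silently dropped from the prompt.
+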
+ query_text = "what did alice say to rabbit"
+
+ # Prepare the DB.
+ # embedding_function = OpenAIEmbeddings()  # main
+
+ CHROMA_PATH = "chroma8"
+ # Load the Chroma index previously persisted to this directory.
+ db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
+
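+ # The app assumes "chroma8" was already built. A minimal sketch of how such an
+ # index could be created (assumptions: the source text lives in a hypothetical
+ # alice.txt, and the splitter settings are illustrative):
+ #
+ #     from langchain_community.document_loaders import TextLoader
+ #     from langchain.text_splitter import RecursiveCharacterTextSplitter
+ #     docs = TextLoader("alice.txt").load()
+ #     chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
+ #     Chroma.from_documents(chunks, embeddings, persist_directory=CHROMA_PATH)
+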
+ # Search the DB for documents similar to the query; each result is a
+ # (Document, relevance_score) pair, with scores normalized to [0, 1].
+ results = db.similarity_search_with_relevance_scores(query_text, k=2)
+ if len(results) == 0 or results[0][1] < 0.5:
+     print("Unable to find matching results.")
+
+ query_text = "when did alice see mad hatter"
+
+ results = db.similarity_search_with_relevance_scores(query_text, k=3)
+ if len(results) == 0 or results[0][1] < 0.5:
+     print("Unable to find matching results.")
+
+ # Join the retrieved chunks into one context block separated by dividers.
+ context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
+
+ template = """
+ The following is a conversation between a human and an AI. Answer the question based only on the conversation.
+
+ Current conversation:
+ {history}
+
+ """
+
+ s = """
+ question: {input}
+ answer:""".strip()
+
+ # The final prompt is the instruction template, the retrieved context, then the
+ # question/answer scaffold; {history} and {input} are filled in by the chain.
+ prompt = PromptTemplate(input_variables=["history", "input"], template=template + context_text + '\n' + s)
+
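+ # Caveat (an assumption about the data, not handled here): if a retrieved chunk
+ # contains a literal "{" or "}", PromptTemplate will try to parse it as a
+ # variable. A defensive variant would escape braces first:
+ #     safe_context = context_text.replace("{", "{{").replace("}", "}}")
+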
+ chain = ConversationChain(
+     llm=llm,
+     prompt=prompt,
+     memory=memory,
+     verbose=True,
+ )
+
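+ # Example single turn (illustrative): the chain fills {history} from memory and
+ # {input} from the keyword argument below.
+ #     print(chain.predict(input=query_text))
+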
+
+ # Generate a response from the Llama model.
+ def get_llama_response(message: str, history: list = None) -> str:
+     """
+     Generates a conversational response from the Llama model.
+
+     Parameters:
+         message (str): User's input message.
+         history (list): Past conversation history (unused; kept for chat-style callers).
+
+     Returns:
+         str: Generated response from the Llama model.
+     """
+     query_text = message
+
+     # Retrieve fresh context for this message.
+     results = db.similarity_search_with_relevance_scores(query_text, k=2)
+     if len(results) == 0 or results[0][1] < 0.5:
+         print("Unable to find matching results.")
+
+     context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
+
+     template = """
+ The following is a conversation between a human and an AI. Answer the question based only on the conversation.
+
+ Current conversation:
+ {history}
+
+ """
+
+     s = """
+ question: {input}
+ answer:""".strip()
+
+     prompt = PromptTemplate(input_variables=["history", "input"], template=template + context_text + '\n' + s)
+
+     # Swap the freshly built prompt into the existing chain, then run it.
+     chain.prompt = prompt
+     res = chain.predict(input=query_text)
+     return res
+
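+ # Quick check from a Python shell (illustrative):
+ #     print(get_llama_response("what did alice say to rabbit", []))
+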
+ import gradio as gr
+
+ # A simple text-in / text-out UI around the response function.
+ iface = gr.Interface(fn=get_llama_response, inputs=gr.Textbox(),
+                      outputs="textbox")
+ iface.launch(share=True)
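+ # Alternative (illustrative): since get_llama_response accepts (message, history),
+ # gr.ChatInterface would also work and renders a running transcript:
+ #     gr.ChatInterface(fn=get_llama_response).launch(share=True)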