github-actions commited on
Commit
8ca1a6a
·
1 Parent(s): eaf872d

Sync updates from source repository

Browse files
Files changed (3) hide show
  1. app.py +84 -5
  2. query.py +8 -7
  3. requirements.txt +2 -0
app.py CHANGED
@@ -1,13 +1,59 @@
1
  from omegaconf import OmegaConf
2
  from query import VectaraQuery
3
  import os
 
 
 
4
 
5
  import streamlit as st
6
  from streamlit_pills import pills
 
7
 
8
  from PIL import Image
9
 
10
  max_examples = 6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def isTrue(x) -> bool:
13
  if isinstance(x, bool):
@@ -16,11 +62,11 @@ def isTrue(x) -> bool:
16
 
17
  def launch_bot():
18
  def generate_response(question):
19
- response = vq.submit_query(question)
20
  return response
21
 
22
  def generate_streaming_response(question):
23
- response = vq.submit_query_streaming(question)
24
  return response
25
 
26
  def show_example_questions():
@@ -41,11 +87,13 @@ def launch_bot():
41
  'source_data_desc': os.environ['source_data_desc'],
42
  'streaming': isTrue(os.environ.get('streaming', False)),
43
  'prompt_name': os.environ.get('prompt_name', None),
44
- 'examples': os.environ.get('examples', None)
 
45
  })
46
  st.session_state.cfg = cfg
47
  st.session_state.ex_prompt = None
48
- st.session_state.first_turn = True
 
49
  example_messages = [example.strip() for example in cfg.examples.split(",")]
50
  st.session_state.example_messages = [em for em in example_messages if len(em)>0][:max_examples]
51
 
@@ -60,7 +108,13 @@ def launch_bot():
60
  image = Image.open('Vectara-logo.png')
61
  st.image(image, width=175)
62
  st.markdown(f"## About\n\n"
63
- f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n\n")
 
 
 
 
 
 
64
 
65
  st.markdown("---")
66
  st.markdown(
@@ -111,7 +165,32 @@ def launch_bot():
111
  st.write(response)
112
  message = {"role": "assistant", "content": response}
113
  st.session_state.messages.append(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  st.rerun()
 
 
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
  launch_bot()
 
1
  from omegaconf import OmegaConf
2
  from query import VectaraQuery
3
  import os
4
+ import requests
5
+ import json
6
+ import uuid
7
 
8
  import streamlit as st
9
  from streamlit_pills import pills
10
+ from streamlit_feedback import streamlit_feedback
11
 
12
  from PIL import Image
13
 
14
  max_examples = 6
15
+ languages = {'English': 'eng', 'Spanish': 'spa', 'French': 'frs', 'Chinese': 'zho', 'German': 'deu', 'Hindi': 'hin', 'Arabic': 'ara',
16
+ 'Portuguese': 'por', 'Italian': 'ita', 'Japanese': 'jpn', 'Korean': 'kor', 'Russian': 'rus', 'Turkish': 'tur', 'Persian (Farsi)': 'fas',
17
+ 'Vietnamese': 'vie', 'Thai': 'tha', 'Hebrew': 'heb', 'Dutch': 'nld', 'Indonesian': 'ind', 'Polish': 'pol', 'Ukrainian': 'ukr',
18
+ 'Romanian': 'ron', 'Swedish': 'swe', 'Czech': 'ces', 'Greek': 'ell', 'Bengali': 'ben', 'Malay (or Malaysian)': 'msa', 'Urdu': 'urd'}
19
+
20
+ # Setup for HTTP API Calls to Amplitude Analytics
21
+ if 'device_id' not in st.session_state:
22
+ st.session_state.device_id = str(uuid.uuid4())
23
+
24
+ headers = {
25
+ 'Content-Type': 'application/json',
26
+ 'Accept': '*/*'
27
+ }
28
+ amp_api_key = os.getenv('AMPLITUDE_TOKEN')
29
+
30
+ def thumbs_feedback(feedback, **kwargs):
31
+ """
32
+ Sends feedback to Amplitude Analytics
33
+ """
34
+ data = {
35
+ "api_key": amp_api_key,
36
+ "events": [{
37
+ "device_id": st.session_state.device_id,
38
+ "event_type": "provided_feedback",
39
+ "event_properties": {
40
+ "Space Name": kwargs.get("title", "Unknown Space Name"),
41
+ "Demo Type": "chatbot",
42
+ "query": kwargs.get("prompt", "No user input"),
43
+ "response": kwargs.get("response", "No chat response"),
44
+ "feedback": feedback["score"],
45
+ "Response Language": st.session_state.language
46
+ }
47
+ }]
48
+ }
49
+ response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers, data=json.dumps(data))
50
+ if response.status_code != 200:
51
+ print(f"Request failed with status code {response.status_code}. Response Text: {response.text}")
52
+
53
+ st.session_state.feedback_key += 1
54
+
55
+ if "feedback_key" not in st.session_state:
56
+ st.session_state.feedback_key = 0
57
 
58
  def isTrue(x) -> bool:
59
  if isinstance(x, bool):
 
62
 
63
  def launch_bot():
64
  def generate_response(question):
65
+ response = vq.submit_query(question, languages[st.session_state.language])
66
  return response
67
 
68
  def generate_streaming_response(question):
69
+ response = vq.submit_query_streaming(question, languages[st.session_state.language])
70
  return response
71
 
72
  def show_example_questions():
 
87
  'source_data_desc': os.environ['source_data_desc'],
88
  'streaming': isTrue(os.environ.get('streaming', False)),
89
  'prompt_name': os.environ.get('prompt_name', None),
90
+ 'examples': os.environ.get('examples', None),
91
+ 'language': 'English'
92
  })
93
  st.session_state.cfg = cfg
94
  st.session_state.ex_prompt = None
95
+ st.session_state.first_turn = True
96
+ st.session_state.language = cfg.language
97
  example_messages = [example.strip() for example in cfg.examples.split(",")]
98
  st.session_state.example_messages = [em for em in example_messages if len(em)>0][:max_examples]
99
 
 
108
  image = Image.open('Vectara-logo.png')
109
  st.image(image, width=175)
110
  st.markdown(f"## About\n\n"
111
+ f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n")
112
+
113
+ cfg.language = st.selectbox('Language:', languages.keys())
114
+ if st.session_state.language != cfg.language:
115
+ st.session_state.language = cfg.language
116
+ print(f"DEBUG: Language changed to {st.session_state.language}")
117
+ st.rerun()
118
 
119
  st.markdown("---")
120
  st.markdown(
 
165
  st.write(response)
166
  message = {"role": "assistant", "content": response}
167
  st.session_state.messages.append(message)
168
+
169
+ # Send query and response to Amplitude Analytics
170
+ data = {
171
+ "api_key": amp_api_key,
172
+ "events": [{
173
+ "device_id": st.session_state.device_id,
174
+ "event_type": "submitted_query",
175
+ "event_properties": {
176
+ "Space Name": cfg["title"],
177
+ "Demo Type": "chatbot",
178
+ "query": st.session_state.messages[-2]["content"],
179
+ "response": st.session_state.messages[-1]["content"],
180
+ "Response Language": st.session_state.language
181
+ }
182
+ }]
183
+ }
184
+ response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers, data=json.dumps(data))
185
+ if response.status_code != 200:
186
+ print(f"Amplitude request failed with status code {response.status_code}. Response Text: {response.text}")
187
  st.rerun()
188
+
189
+ if (st.session_state.messages[-1]["role"] == "assistant") & (st.session_state.messages[-1]["content"] != "How may I help you?"):
190
+ streamlit_feedback(feedback_type="thumbs", on_submit = thumbs_feedback, key = st.session_state.feedback_key,
191
+ kwargs = {"prompt": st.session_state.messages[-2]["content"],
192
+ "response": st.session_state.messages[-1]["content"],
193
+ "title": cfg["title"]})
194
 
195
  if __name__ == "__main__":
196
  launch_bot()
query.py CHANGED
@@ -10,7 +10,7 @@ class VectaraQuery():
10
  self.conv_id = None
11
 
12
 
13
- def get_body(self, query_str: str, stream: False):
14
  corpora_list = [{
15
  'corpus_key': corpus_key, 'lexical_interpolation': 0.005
16
  } for corpus_key in self.corpus_keys
@@ -40,11 +40,12 @@ class VectaraQuery():
40
  {
41
  'prompt_name': self.prompt_name,
42
  'max_used_search_results': 10,
43
- 'response_language': 'eng',
44
  'citations':
45
  {
46
  'style': 'none'
47
- }
 
48
  },
49
  'chat':
50
  {
@@ -70,14 +71,14 @@ class VectaraQuery():
70
  "grpc-timeout": "60S"
71
  }
72
 
73
- def submit_query(self, query_str: str):
74
 
75
  if self.conv_id:
76
  endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
77
  else:
78
  endpoint = "https://api.vectara.io/v2/chats"
79
 
80
- body = self.get_body(query_str, stream=False)
81
 
82
  response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
83
 
@@ -96,14 +97,14 @@ class VectaraQuery():
96
 
97
  return summary
98
 
99
- def submit_query_streaming(self, query_str: str):
100
 
101
  if self.conv_id:
102
  endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
103
  else:
104
  endpoint = "https://api.vectara.io/v2/chats"
105
 
106
- body = self.get_body(query_str, stream=True)
107
 
108
  response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_stream_headers(), stream=True)
109
 
 
10
  self.conv_id = None
11
 
12
 
13
+ def get_body(self, query_str: str, response_lang: str, stream: False):
14
  corpora_list = [{
15
  'corpus_key': corpus_key, 'lexical_interpolation': 0.005
16
  } for corpus_key in self.corpus_keys
 
40
  {
41
  'prompt_name': self.prompt_name,
42
  'max_used_search_results': 10,
43
+ 'response_language': response_lang,
44
  'citations':
45
  {
46
  'style': 'none'
47
+ },
48
+ 'enable_factual_consistency_score': False
49
  },
50
  'chat':
51
  {
 
71
  "grpc-timeout": "60S"
72
  }
73
 
74
+ def submit_query(self, query_str: str, language: str):
75
 
76
  if self.conv_id:
77
  endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
78
  else:
79
  endpoint = "https://api.vectara.io/v2/chats"
80
 
81
+ body = self.get_body(query_str, language, stream=False)
82
 
83
  response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
84
 
 
97
 
98
  return summary
99
 
100
+ def submit_query_streaming(self, query_str: str, language: str):
101
 
102
  if self.conv_id:
103
  endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
104
  else:
105
  endpoint = "https://api.vectara.io/v2/chats"
106
 
107
+ body = self.get_body(query_str, language, stream=True)
108
 
109
  response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_stream_headers(), stream=True)
110
 
requirements.txt CHANGED
@@ -3,3 +3,5 @@ toml==0.10.2
3
  omegaconf==2.3.0
4
  syrupy==4.0.8
5
  streamlit_pills==0.3.0
 
 
 
3
  omegaconf==2.3.0
4
  syrupy==4.0.8
5
  streamlit_pills==0.3.0
6
+ streamlit-feedback==0.1.3
7
+ uuid==1.30