arampacha commited on
Commit
c6b7aba
·
1 Parent(s): 7579cb6
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -9,7 +9,6 @@ API_TOKEN = st.secrets["hf_api_token"]
9
  headers = {"Authorization": f"Bearer {API_TOKEN}"}
10
  API_URL = "https://api-inference.huggingface.co/models/arampacha/DialoGPT-medium-simpsons"
11
 
12
-
13
  def query(payload):
14
  data = json.dumps(payload)
15
  response = requests.request("POST", API_URL, headers=headers, data=data)
@@ -29,7 +28,7 @@ def fake_query(payload):
29
 
30
  parameters = {
31
  "min_length":None,
32
- "max_length":50,
33
  "top_p":0.92,
34
  "repetition_penalty":None,
35
  }
@@ -54,8 +53,12 @@ def on_input():
54
  }
55
  # result = fake_query(payload)
56
  result = query(payload)
57
- st.session_state.update(result["conversation"])
58
- st.session_state.full_text += f'_Chatbot_ > {result["generated_text"]}\n\n'
 
 
 
 
59
  st.session_state.count += 1
60
 
61
 
@@ -100,4 +103,4 @@ dialog_output.markdown(dialog_text)
100
  def restart():
101
  st.session_state.clear()
102
 
103
- st.button("Restart", on_click=st.session_state.clear)
 
9
  headers = {"Authorization": f"Bearer {API_TOKEN}"}
10
  API_URL = "https://api-inference.huggingface.co/models/arampacha/DialoGPT-medium-simpsons"
11
 
 
12
  def query(payload):
13
  data = json.dumps(payload)
14
  response = requests.request("POST", API_URL, headers=headers, data=data)
 
28
 
29
  parameters = {
30
  "min_length":None,
31
+ "max_length":100,
32
  "top_p":0.92,
33
  "repetition_penalty":None,
34
  }
 
53
  }
54
  # result = fake_query(payload)
55
  result = query(payload)
56
+ try:
57
+ st.session_state.update(result["conversation"])
58
+ st.session_state.full_text += f'_Chatbot_ > {result["generated_text"]}\n\n'
59
+ except:
60
+ st.write("D'oh! Something went wrong. Try to rerun the app.")
61
+ st.write(result)
62
  st.session_state.count += 1
63
 
64
 
 
103
  def restart():
104
  st.session_state.clear()
105
 
106
+ st.button("Restart", on_click=st.session_state.clear)