juewang commited on
Commit
0d215ca
1 Parent(s): 2bb3e31

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import time
4
+
5
+ def infer(prompt, max_new_tokens=10, temperature=0.0, top_p=1.0):
6
+
7
+ my_post_dict = {
8
+ "type": "general",
9
+ "payload": {
10
+ "max_tokens": max_new_tokens,
11
+ "n": 1,
12
+ "temperature": float(temperature),
13
+ "top_p": float(top_p),
14
+ "model": "Together-gpt-J-6B-ProxAdam-50x",
15
+ "prompt": [prompt],
16
+ "request_type": "language-model-inference",
17
+ "stop": None,
18
+ "best_of": 1,
19
+ "echo": False,
20
+ "seed": 42,
21
+ "prompt_embedding": False,
22
+ },
23
+ "returned_payload": {},
24
+ "status": "submitted",
25
+ "source": "dalle",
26
+ }
27
+
28
+ res = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()
29
+
30
+ job_id = res['id']
31
+
32
+ while True:
33
+
34
+ ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
35
+
36
+ if ret['status'] == 'finished':
37
+ break
38
+
39
+ time.sleep(1)
40
+
41
+ return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
42
+
43
+
44
+ st.title("TOMA Application")
45
+
46
+ s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
47
+ prompt = st.text_area(
48
+ "Prompt",
49
+ value=s_example,
50
+ max_chars=1000,
51
+ height=400,
52
+ )
53
+
54
+
55
+ generated_area = st.empty()
56
+ generated_area.markdown("(Generate here)")
57
+
58
+ button_submit = st.button("Submit")
59
+
60
+ max_new_tokens = st.number_input('Max new tokens', 1, 1024, 10)
61
+ temperature = st.number_input('temperature', 0.0, 10.0, 0.0, step=0.1, format="%.2f")
62
+ top_p = st.number_input('top_p', 0.0, 1.0, 1.0, step=0.1, format="%.2f")
63
+
64
+ if button_submit:
65
+ with st.spinner(text="In progress.."):
66
+ report_text = infer(prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
67
+ generated_area.markdown(report_text)