File size: 2,015 Bytes
0d215ca
 
 
 
cc85063
0d215ca
3442116
 
 
 
0d215ca
 
 
bbe538b
0d215ca
 
 
3442116
0d215ca
 
 
 
 
 
 
 
 
 
 
 
 
0b3be54
0d215ca
bbe538b
0b3be54
 
0d215ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3442116
0b3be54
 
 
0d215ca
 
 
cc85063
0d215ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
import requests
import time

def infer(prompt, model_name, max_new_tokens=10, temperature=0.0, top_p=1.0,
          *, max_polls=100, poll_interval=1.0, request_timeout=30.0,
          base_url="https://planetd.shift.ml"):
    """Submit a language-model inference job and wait for its generated text.

    The job is POSTed to ``{base_url}/jobs``. The job is then polled at
    ``{base_url}/job/{id}`` until its status is ``'finished'``, and the text of
    the first choice of the first inference result is returned.

    Args:
        prompt: Prompt text sent as the single entry of the ``prompt`` list.
        model_name: UI-facing model name; must be a key of the local model map.
        max_new_tokens: Maximum tokens to generate (converted with ``int``).
        temperature: Sampling temperature (converted with ``float``).
        top_p: Nucleus-sampling cutoff (converted with ``float``).
        max_polls: How many status polls to make before giving up.
        poll_interval: Seconds to sleep before each status poll.
        request_timeout: Per-request timeout in seconds passed to ``requests``
            (``None`` disables it).
        base_url: Root URL of the job service.

    Returns:
        The generated text of the finished job.

    Raises:
        KeyError: ``model_name`` is not a supported model, or a response lacks
            an expected field.
        ValueError: a numeric argument cannot be converted.
        requests.RequestException: a request failed, timed out, or the service
            answered with an HTTP error status.
        TimeoutError: the job was not finished after ``max_polls`` polls.
    """
    # Maps the name shown in the UI to the identifier the service expects.
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": 1,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": None,
            "best_of": 1,
            "echo": False,
            "seed": 42,
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    response = requests.post(f"{base_url}/jobs", json=my_post_dict, timeout=request_timeout)
    response.raise_for_status()
    job_id = response.json()['id']

    status = None
    for _ in range(max_polls):
        # Give the job time to run before (re)checking its status.
        time.sleep(poll_interval)

        response = requests.get(f"{base_url}/job/{job_id}", json={'id': job_id}, timeout=request_timeout)
        response.raise_for_status()
        ret = response.json()
        status = ret['status']

        if status == 'finished':
            return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']

    # Previously this fell through and raised an opaque KeyError on the empty
    # 'returned_payload'; report the real cause instead.
    raise TimeoutError(
        f"Job {job_id} did not finish after {max_polls} polls (last status: {status!r})"
    )
    
    
st.title("TOMA Application")

s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
prompt = st.text_area(
    "Prompt",
    value=s_example,
    max_chars=1000,
    height=400,
)

# Placeholder that is overwritten with the model output after a successful run.
generated_area = st.empty()
generated_area.markdown("(Generate here)")

button_submit = st.button("Submit")

model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
# Free-text fields: values arrive as strings and are converted inside infer().
max_new_tokens = st.text_input('Max new tokens', "10")
temperature = st.text_input('temperature', "0.0")
top_p = st.text_input('top_p', "1.0")

if button_submit:
    # Check the numeric fields up front so a typo yields a readable message
    # instead of a raw traceback from inside infer(). The converted values are
    # discarded; infer() performs the same int()/float() conversions itself.
    try:
        int(max_new_tokens)
        float(temperature)
        float(top_p)
    except ValueError:
        st.error("Max new tokens must be an integer; temperature and top_p must be numbers.")
    else:
        with st.spinner(text="In progress.."):
            try:
                report_text = infer(prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
            # Network/HTTP failures, a job that never finishes, or an unexpected
            # response shape are reported in the UI rather than as a traceback.
            except (requests.RequestException, TimeoutError, KeyError) as err:
                st.error(f"Generation failed: {err!r}")
            else:
                generated_area.markdown(report_text)