reshinthadith committed
Commit 1019a35
1 Parent(s): 4244b29

Adding app.py

Files changed (1)
  1. app.py +97 -103
app.py CHANGED
@@ -4,106 +4,100 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Stopping
  import time
  import numpy as np
  from torch.nn import functional as F
- import os
- token_key = os.environ.get("HF_ACCESS_TOKEN")
-
- # if torch.cuda.is_available():
- #     m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=token_key, torch_dtype=torch.float16).cuda()
- #     tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=token_key)
- # else:
- #     m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=token_key, torch_dtype=torch.float16)
- #     tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=token_key)
- # generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
-
-
- # start_message = """<|SYSTEM|># StableAssistant
- # - StableAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
- # - StableAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- # - StableAssistant is more than just an information source, StableAssistant is also able to write poetry, short stories, and make jokes.
- # - StableAssistant will refuse to participate in anything that could harm a human."""
-
-
- # class StopOnTokens(StoppingCriteria):
- #     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
- #         stop_ids = [50278, 50279, 50277, 1, 0]
- #         for stop_id in stop_ids:
- #             if input_ids[0][-1] == stop_id:
- #                 return True
- #         return False
-
-
- # def contrastive_generate(text, bad_text):
- #     with torch.no_grad():
- #         if torch.cuda_is_available():
- #             tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
- #             bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
- #         else:
- #             tokens = tok(text, return_tensors="pt")['input_ids'][:,:4096-1024]
- #             bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'][:,:4096-1024]
- #         history = None
- #         bad_history = None
- #         curr_output = list()
- #         for i in range(1024):
- #             out = m(tokens, past_key_values=history, use_cache=True)
- #             logits = out.logits
- #             history = out.past_key_values
- #             bad_out = m(bad_tokens, past_key_values=bad_history, use_cache=True)
- #             bad_logits = bad_out.logits
- #             bad_history = bad_out.past_key_values
- #             probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
- #             bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
- #             logits = torch.log(probs)
- #             bad_logits = torch.log(bad_probs)
- #             logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
- #             probs = F.softmax(logits)
- #             out = int(torch.multinomial(probs, 1))
- #             if out in [50278, 50279, 50277, 1, 0]:
- #                 break
- #             else:
- #                 curr_output.append(out)
- #             out = np.array([out])
- #             tokens = torch.from_numpy(np.array([out])).to(
- #                 tokens.device)
- #             bad_tokens = torch.from_numpy(np.array([out])).to(
- #                 tokens.device)
- #         return tok.decode(curr_output)
-
- # def generate(text, bad_text=None):
- #     stop = StopOnTokens()
- #     result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True, temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
- #     return result[0]["generated_text"].replace(text, "")
-
-
- # def user(user_message, history):
- #     return "", history + [[user_message, ""]]
-
-
- # def bot(history, curr_system_message):
- #     messages = curr_system_message + "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]]) for item in history])
- #     output = generate(messages)
- #     history[-1][1] = output
- #     time.sleep(1)
- #     return history
-
-
- # def system_update(msg):
- #     global curr_system_message
- #     curr_system_message = msg
-
-
- # with gr.Blocks() as demo:
- #     gr.Markdown("###StableLM-tuned-Alpha-7B Chat")
- #     with gr.Row():
- #         with gr.Column():
- #             chatbot = gr.Chatbot([])
- #             clear = gr.Button("Clear")
- #         with gr.Column():
- #             system_msg = start_message#gr.Textbox(start_message, label="System Message", interactive=True)
- #             msg = gr.Textbox(label="Chat Message")
-
- #     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
- #         bot, [chatbot, system_msg], chatbot
- #     )
- #     system_msg.change(system_update, system_msg, None, queue=False)
- #     clear.click(lambda: None, None, chatbot, queue=False)
- # demo.launch(share=True)
+ import os
+ auth_key = os.environ["HF_ACCESS_TOKEN"]
+ print("Starting to load the model to memory")
+ m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=auth_key, torch_dtype=torch.float16).cuda()
+ tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b",use_auth_token=auth_key)
+ generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
+ print("Successfully loaded the model to memory")
+
+ start_message = """<|SYSTEM|># StableAssistant
+ - StableAssistant is A helpful and harmless Open Source AI Language Model developed by Stability and CarperAI.
+ - StableAssistant is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+ - StableAssistant is more than just an information source, StableAssistant is also able to write poetry, short stories, and make jokes.
+ - StableAssistant will refuse to participate in anything that could harm a human."""
+
+
+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [50278, 50279, 50277, 1, 0]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+
+ def contrastive_generate(text, bad_text):
+     with torch.no_grad():
+         tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+         bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+         history = None
+         bad_history = None
+         curr_output = list()
+         for i in range(1024):
+             out = m(tokens, past_key_values=history, use_cache=True)
+             logits = out.logits
+             history = out.past_key_values
+             bad_out = m(bad_tokens, past_key_values=bad_history, use_cache=True)
+             bad_logits = bad_out.logits
+             bad_history = bad_out.past_key_values
+             probs = F.softmax(logits.float(), dim=-1)[0][-1].cpu()
+             bad_probs = F.softmax(bad_logits.float(), dim=-1)[0][-1].cpu()
+             logits = torch.log(probs)
+             bad_logits = torch.log(bad_probs)
+             logits[probs > 0.1] = logits[probs > 0.1] - bad_logits[probs > 0.1]
+             probs = F.softmax(logits, dim=-1)
+             out = int(torch.multinomial(probs, 1))
+             if out in [50278, 50279, 50277, 1, 0]:
+                 break
+             else:
+                 curr_output.append(out)
+             out = np.array([out])
+             tokens = torch.from_numpy(np.array([out])).to(
+                 tokens.device)
+             bad_tokens = torch.from_numpy(np.array([out])).to(
+                 tokens.device)
+         return tok.decode(curr_output)
+
+ def generate(text, bad_text=None):
+     stop = StopOnTokens()
+     result = generator(text, max_new_tokens=1024, num_return_sequences=1, num_beams=1, do_sample=True, temperature=1.0, top_p=0.95, top_k=1000, stopping_criteria=StoppingCriteriaList([stop]))
+     return result[0]["generated_text"].replace(text, "")
+
+
+ def user(user_message, history):
+     return "", history + [[user_message, ""]]
+
+
+ def bot(history, curr_system_message):
+     messages = curr_system_message + "".join(["".join(["<|USER|>"+item[0], "<|ASSISTANT|>"+item[1]]) for item in history])
+     output = generate(messages)
+     history[-1][1] = output
+     time.sleep(1)
+     return history
+
+
+ def system_update(msg):
+     global curr_system_message
+     curr_system_message = msg
+
+ updated_system_message = ""
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## StableLM-Tuned-Alpha-7b Chat")
+     with gr.Row():
+         chatbot = gr.Chatbot([])
+         clear = gr.Button("Clear Chat History")
+     system_msg = gr.Textbox(start_message, label="System Message", interactive=False, visible=False)
+     #system_msg = start_message
+     msg = gr.Textbox(label="Chat Message Box")
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, system_msg], chatbot
+     )
+     system_update(system_msg)
+     system_msg.change(system_update, system_msg, None, queue=False)
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.launch()
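
Notes on the new app.py. The hard-coded stop ids in StopOnTokens (50278, 50279, 50277, plus 1 and 0) appear to be the tuned model's special turn tokens and the eos/pad ids; if that is right, they can be looked up rather than hard-coded. A minimal sketch, assuming the tok object loaded above:

    # Hedged sketch: derive the stop ids from the tokenizer instead of
    # hard-coding them. Tokens missing from the vocab come back as the unk
    # id, so the result should be sanity-checked before use.
    special_turns = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
    stop_ids = tok.convert_tokens_to_ids(special_turns)
    print(dict(zip(special_turns, stop_ids)))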
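contrastive_generate steers sampling away from a "bad" prompt: at each step it converts both next-token distributions to log space and, for every token the main prompt rates above 0.1 probability, subtracts the bad prompt's log-probability before renormalizing. A self-contained toy version of that single adjustment step (the distributions below are made up for illustration):

    import torch
    from torch.nn import functional as F

    # Toy next-token distributions over a 4-token vocabulary.
    probs = torch.tensor([0.50, 0.30, 0.15, 0.05])      # main prompt
    bad_probs = torch.tensor([0.45, 0.05, 0.30, 0.20])  # "bad" prompt

    logits = torch.log(probs)
    bad_logits = torch.log(bad_probs)

    # Only likely tokens (p > 0.1) are penalized by the bad prompt.
    mask = probs > 0.1
    logits[mask] = logits[mask] - bad_logits[mask]

    adjusted = F.softmax(logits, dim=-1)
    # Token 0 is likely under both prompts, so its mass shrinks relative to
    # token 1, which the bad prompt considers unlikely.
    print(adjusted)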
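generate strips the prompt from the pipeline output with str.replace, which would also delete any later verbatim repetition of the prompt inside the reply. The text-generation pipeline accepts return_full_text=False to return only the continuation; a sketch of generate under that assumption, reusing the generator, StopOnTokens, and StoppingCriteriaList objects from the file:

    # Sketch only: same sampling settings as the commit, but the pipeline
    # is asked for the continuation alone instead of stripping the prompt
    # by hand.
    def generate(text, bad_text=None):
        stop = StopOnTokens()
        result = generator(text, max_new_tokens=1024, num_return_sequences=1,
                           num_beams=1, do_sample=True, temperature=1.0,
                           top_p=0.95, top_k=1000, return_full_text=False,
                           stopping_criteria=StoppingCriteriaList([stop]))
        return result[0]["generated_text"]  # continuation only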
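Finally, the prompt that bot assembles is the system message followed by alternating turn tokens, and the trailing empty assistant slot is what invites the model to continue. A worked example with an invented two-turn history, assuming the start_message defined in the file:

    # Illustrative history in Gradio Chatbot format: [user, assistant] pairs;
    # bot() fills the empty final slot with the generated reply.
    history = [["Hi there!", "Hello! How can I help?"],
               ["Tell me a joke.", ""]]

    messages = start_message + "".join(
        "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
        for item in history
    )
    # ...<|USER|>Hi there!<|ASSISTANT|>Hello! How can I help?<|USER|>Tell me a joke.<|ASSISTANT|>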