umuthopeyildirim commited on
Commit
c8255b3
·
verified ·
1 Parent(s): f27d7e4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a Hugging Face Spaces demo for Fin-RWKV-1B5 attention free finanacial export modal.
3
+ Author: Umut (Hope) YILDIRIM <[email protected]>
4
+ """
5
+
6
+ import gradio as gr
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
8
+ from threading import Thread
9
+ import torch
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained("umuthopeyildirim/fin-rwkv-1b5")
12
+ model = AutoModelForCausalLM.from_pretrained("umuthopeyildirim/fin-rwkv-1b5")
13
+
14
+
15
+ class StopOnTokens(StoppingCriteria):
16
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
+ stop_ids = [29, 0]
18
+ for stop_id in stop_ids:
19
+ if input_ids[0][-1] == stop_id:
20
+ return True
21
+ return False
22
+
23
+
24
+ def predict(message, history):
25
+
26
+ history_transformer_format = history + [[message, ""]]
27
+ stop = StopOnTokens()
28
+
29
+ messages = "".join(["".join(["\nuser :"+item[0], "\nbot:"+item[1]]) # curr_system_message +
30
+ for item in history_transformer_format])
31
+
32
+ print(messages)
33
+
34
+ model_inputs = tokenizer([messages], return_tensors="pt")
35
+ streamer = TextIteratorStreamer(
36
+ tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
37
+ generate_kwargs = dict(
38
+ model_inputs,
39
+ streamer=streamer,
40
+ max_new_tokens=1024,
41
+ do_sample=True,
42
+ top_p=0.95,
43
+ top_k=1000,
44
+ temperature=0.5,
45
+ num_beams=1,
46
+ stopping_criteria=StoppingCriteriaList([stop])
47
+ )
48
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
49
+ t.start()
50
+
51
+ partial_message = ""
52
+ for new_token in streamer:
53
+ if new_token != '<':
54
+ partial_message += new_token
55
+ yield partial_message
56
+
57
+
58
+ def generate_text(prompt, tokenizer, model):
59
+ # Tokenize the input
60
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
61
+
62
+ # Generate a response
63
+ output = model.generate(input_ids, max_length=333, num_return_sequences=1)
64
+
65
+ # Decode the output
66
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
67
+
68
+ return generated_text
69
+
70
+
71
+ title = "# Fin-RWKV: Attention Free Financal Expert (WIP)"
72
+ description = """Demo for **Fin-RWKV: Attention Free Financal Expert (WIP)**.
73
+ To download the model, please visit [Fin-RWKV: Attention Free Financal Expert (WIP)](https://huggingface.co/umutyildirim/fin-rwkv-1b5)."""
74
+
75
+ css = """
76
+ #img-display-container {
77
+ max-height: 100vh;
78
+ }
79
+ #img-display-input {
80
+ max-height: 80vh;
81
+ }
82
+ #img-display-output {
83
+ max-height: 80vh;
84
+ }
85
+ """
86
+
87
+ with gr.Blocks(css=css) as demo:
88
+ gr.Markdown(title)
89
+ gr.Markdown(description)
90
+ with gr.Tab("Chatbot"):
91
+ gr.ChatInterface(predict)
92
+ with gr.Tab("E-Commerce"):
93
+ gr.Markdown("e-commerce")
94
+ with gr.Tab("OpenBB"):
95
+ gr.Markdown("openbb")
96
+
97
+ if __name__ == '__main__':
98
+ demo.queue().launch()