yugamj commited on
Commit
3d1291e
·
verified ·
1 Parent(s): e4b5394

added app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Import libraries
2
+ import numpy as np
3
+ import pandas as pd
4
+ import transformers
5
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
6
+ import torch
7
+
8
+
9
+ #Load model
10
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
11
+ model = torch.load('finance_chatbot_gpt2_complete_model.pt', map_location=torch.device('cpu'))
12
+ model = model.to(device)
13
+
14
+
15
+ #Get LLM tokenizer
16
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
17
+ tokenizer.add_special_tokens({"pad_token": "<pad>",
18
+ "bos_token": "<startofstring>",
19
+ "eos_token": "<endofstring>"})
20
+ tokenizer.add_tokens(["<bot>:"])
21
+
22
+
23
+ #Inference function
24
+ def infer(inp, history):
25
+ inp = 'What does emotion stand for?'
26
+ inp = "<startofstring>"+inp+"<bot>:"
27
+ inp_tok = tokenizer(inp, return_tensors="pt")
28
+ X = inp_tok["input_ids"].to(device)
29
+ a = inp_tok["attention_mask"].to(device)
30
+ output = model.generate(X, attention_mask=a )
31
+ output = tokenizer.decode(output[0])
32
+ return output[len(inp):]
33
+
34
+
35
+ #Launch with gradio
36
+ gr.ChatInterface(
37
+ infer,
38
+ chatbot=gr.Chatbot(height=300),
39
+ textbox=gr.Textbox(placeholder="Type Here", container=False, scale=10),
40
+ title="Finance Chatbot Based on Rich Dad Poor Dad",
41
+ description="This Chatbot is Based on a fine-tuned version of 'GPT2'. Popular quotes of Robert Kiyosaki from his book, 'Rich Dad Poor Dad' and book summary were used for training this model.",
42
+ theme="soft",
43
+ examples=["What do you want to earn more passive income?", "What is the result of people working all their lives for someone else?", "What tells the story of how a person handles money?"],
44
+ cache_examples=True,
45
+ retry_btn=None,
46
+ undo_btn="Delete Previous",
47
+ clear_btn="Clear",
48
+ ).launch()