Ngadou committed on
Commit
9ba218c
·
1 Parent(s): c709e83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -39
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
- import time
3
- import torch
4
-
5
  from peft import PeftModel, PeftConfig
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
7
 
8
# LoRA adapter repo on the Hugging Face Hub; its config points at the base model.
peft_model_id = "Ngadou/falcon-7b-scam-buster"
config = PeftConfig.from_pretrained(peft_model_id)
# Base Falcon-7B causal LM, quantized to 4-bit and sharded across available devices.
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, trust_remote_code=True, return_dict=True, load_in_4bit=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Attach the fine-tuned LoRA weights and move the combined model to the GPU.
model = PeftModel.from_pretrained(model, peft_model_id).to("cuda")
15
-
16
-
17
-
18
- # Load the Lora model
19
- model = PeftModel.from_pretrained(model, peft_model_id)
20
-
21
- tokenizer.pad_token = tokenizer.eos_token
22
-
23
def generate(chat):
    """Classify a chat transcript as scam / not-scam with a short rationale.

    Args:
        chat: The conversation text to analyse.

    Returns:
        The model's generated answer as a string (prompt text stripped out).
    """
    # Append the classification question to the raw conversation.
    input_text = chat + "\nIs this conversation a scam or not and why?"

    encoding = tokenizer(input_text, return_tensors="pt").to("cuda")
    output = model.generate(
        input_ids=encoding.input_ids,
        attention_mask=encoding.attention_mask,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.000001,  # effectively greedy decoding despite do_sample=True
        eos_token_id=tokenizer.eos_token_id,
        top_k=0,
    )

    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # BUG FIX: the original stripped `example_text`, a name never defined
    # anywhere in the file (NameError at runtime). The echoed prompt we want
    # to remove from the decoded sequence is `input_text`.
    output_text = output_text.replace(input_text, "").lstrip("\n")

    print("\nAnswer:")
    print(output_text)
    return output_text
45
 
46
 
47
  # def is_scam(instruction):
@@ -83,11 +54,45 @@ def generate(chat):
83
  # return classification #, reason
84
 
85
 
86
- # Define the Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  gr.Interface(
88
- fn=generate,
89
  inputs='text',
90
- outputs=[
91
- gr.outputs.Textbox(label="Classification and rational")
92
- ]
93
- ).launch()
 
1
  import gradio as gr
2
+ from gradio.components import Textbox
 
 
3
  from peft import PeftModel, PeftConfig
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers import GenerationConfig
6
+
7
 
8
# LoRA adapter repo on the Hugging Face Hub; its config points at the base model.
peft_model_id = "Ngadou/falcon-7b-scam-buster"
config = PeftConfig.from_pretrained(peft_model_id)
# Base Falcon-7B causal LM, quantized to 4-bit and sharded across available devices.
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, trust_remote_code=True, return_dict=True, load_in_4bit=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Adapter model: attach the fine-tuned LoRA weights and move to the GPU.
model = PeftModel.from_pretrained(model, peft_model_id).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  # def is_scam(instruction):
 
54
  # return classification #, reason
55
 
56
 
57
def is_scam(instruction):
    """Classify a conversation as scam or not, with the model's reasoning.

    Args:
        instruction: The conversation text to classify.

    Returns:
        A ``(classification, extra)`` tuple of strings; the second element is
        a placeholder so the two-output Gradio interface has a second value.
    """
    # BUG FIX: this commit removed the top-level `import torch`, but the
    # `torch.no_grad()` call below still needs it — without this import the
    # function raised NameError on every request.
    import torch

    # Generation hyper-parameters.
    max_new_tokens = 128
    temperature = 0.1
    top_p = 0.75
    top_k = 40
    num_beams = 4

    # Build the instruction-style prompt the adapter was fine-tuned on.
    instruction = instruction + ".\nIs this conversation a scam or not and why?"
    prompt = instruction + "\n### Solution:\n"
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("cuda")
    attention_mask = inputs["attention_mask"].to("cuda")
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            early_stopping=True,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)

    # ROBUSTNESS FIX: the original `output.split("### Solution:")[1]` raised
    # IndexError whenever the marker was missing from the decoded text; fall
    # back to the full output in that case.
    _, sep, answer = output.partition("### Solution:")
    classification = answer.lstrip("\n") if sep else output
    print(classification)

    return str(classification), "Hello World"
92
+
93
+
94
# Wire the classifier into a simple two-output Gradio UI and serve it with a
# public share link and debug logging enabled.
demo = gr.Interface(
    fn=is_scam,
    inputs='text',
    outputs=['text', 'text'],
)
demo.launch(share=True, debug=True)