RaushanTurganbay HF staff committed on
Commit
37cd4b3
1 Parent(s): 9b91ee0
Files changed (1) hide show
  1. app.py +33 -36
app.py CHANGED
@@ -32,14 +32,12 @@ bnb_config = BitsAndBytesConfig(
32
  )
33
 
34
  tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
35
- model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto")
36
- if torch.__version__ >= "2":
37
- model = torch.compile(model)
38
- print(f"Successfully loaded the model {model_name} into memory")
39
 
40
 
41
- # Define stopping criteria
42
- stop_tokens = ["###", "Human", "\n###"]
43
  stop_token_ids = tokenizer.convert_tokens_to_ids(stop_tokens)
44
 
45
  class StopOnTokens(StoppingCriteria):
@@ -50,43 +48,42 @@ class StopOnTokens(StoppingCriteria):
50
  return False
51
 
52
 
53
-
54
- #Prompts
55
  instruction_with_q = """
56
- A chat between a curious human and an artificial intelligence assistant.
57
- The assistant's job is to answer the given question using only the information provided in the RDF triplet format. The assistant's answer should be in a human-readable format, with proper sentences and grammar and should be concise and short.
58
  The RDF triplets will be provided in triplets, where triplets are always in the (subject, relation, object) format and are separated by a semicolon. The assistant should understand that if multiple triplets are provided, the answer to the question should use all of the information from triplets and make aggregation. The assistant MUST NOT add any additional information, beside form the one proveded in the triplets.
59
  The assistant should try to reply as short as possible, and perform counting or aggregation operations over triplets by himself when necessary.
60
  """
61
 
62
  instruction_wo_q = """
63
- A chat between a curious human and an artificial intelligence assistant.
64
- The assistant's job is convert the provided input in RDF triplet format into human-readable text format, with proper sentences and grammar. The triplets are always in the (subject, relation, object) format, where each triplet is separated by a semicolon. The assistant should understand that if multiple triplets are provided, the generated human-readable text should use all of the information from input. The assistant MUST NOT add any additional information, beside form the one proveded in the input.
65
  """
66
 
67
 
68
  history_with_q = [
69
- ("Human", "Question: Is Essex the Ceremonial County of West Tilbury? Triplets: ('West Tilbury', 'Ceremonial County', 'Essex');"),
70
- ("Assistant", "Essex is the Ceremonial County of West Tributary"),
71
- ("Human", "Question: What nation is Hornito located in, where Jamie Bateman Cayn died too? Triplets: ('Jaime Bateman Cayón', 'death place', 'Panama'); ('Hornito, Chiriquí', 'country', 'Panama');"),
72
- ("Assistant", "Hornito, Chiriquí is located in Panama, where Jaime Bateman Cayón died."),
73
- ("Human", "Question: Who are the shareholder of the soccer club for whom Steve Holland plays? Triplets: ('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
74
- ("Assistant", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
75
- ("Human", "Question: Who is the chancellor of Falmouth University? Triplets: ('Falmouth University', 'chancellor', 'Dawn French');"),
76
- ("Assistant", "The chancellor of the Falmouth University is Dawn French.")
77
 
78
  ]
79
 
80
 
81
  history_wo_q = [
82
- ("Human", "('West Tilbury', 'Ceremonial County', 'Essex');"),
83
- ("Assistant", "Essex is the Ceremonial County of West Tributary"),
84
- ("Human", "('Jaime Bateman Cayón', 'death place', 'Panama'); ('Hornito, Chiriquí', 'country', 'Panama');"),
85
- ("Assistant", "Hornito, Chiriquí is located in Panama, where Jaime Bateman Cayón died."),
86
- ("Human", "('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
87
- ("Assistant", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
88
- ("Human", "('Falmouth University', 'chancellor', 'Dawn French');"),
89
- ("Assistant", "The chancellor of the Falmouth University is Dawn French.")
90
 
91
  ]
92
 
@@ -95,7 +92,7 @@ history_wo_q = [
95
  def prepare_input(linearized_triplets, question=None) -> str:
96
  if question and "List all" in question:
97
  question = question.replace("List all ", "Which are ")
98
- if "question" in style:
99
  input_text = f"Question: {question.strip()} Triplets: {linearized_triplets}"
100
  else:
101
  input_text = linearized_triplets
@@ -107,10 +104,10 @@ def make_prompt(
107
  instruction: str,
108
  history: List[Tuple[str, str]]=None,
109
  ) -> str:
110
- ret = f"{instruction}\n###"
111
  for i, (role, message) in enumerate(history):
112
- ret += f"{role}: {message}\n###"
113
- ret += f"Human: {curr_input}\n###Assistant: \n"
114
  return ret
115
 
116
 
@@ -128,6 +125,7 @@ def generate_output(
128
  else:
129
  instruction = make_prompt(curr_input, instruction_wo_q, history_wo_q)
130
 
 
131
  input_ids = tokenizer(instruction, return_tensors="pt").input_ids
132
  input_ids = input_ids.to(model.device)
133
 
@@ -139,11 +137,10 @@ def generate_output(
139
  top_p=top_p,
140
  top_k=top_k,
141
  repetition_penalty=repetition_penalty,
142
- stopping_criteria=StoppingCriteriaList([stop]),
143
  )
144
 
145
  with torch.no_grad():
146
- outputs = model.generate(generate_kwargs)
147
 
148
  response = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
149
  for tok in tokenizer.additional_special_tokens+[tokenizer.eos_token]:
@@ -186,7 +183,7 @@ with gr.Blocks(theme='gradio/soft') as demo:
186
  examples = gr.Examples([
187
  ['("Google Videos", "developer", "Google"), ("Google Web Toolkit", "author", "Google")', ""],
188
  ['("Katyayana", "religion", "Buddhism")', "What is the relegious affiliations of Katyayana?"],
189
- ], inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], fn=generate, cache_examples=False if platform.system() == "Windows" or platform.system() == "Darwin" else True, outputs=output_box)
190
 
191
 
192
  #readme_content = requests.get(f"https://huggingface.co/HF_MODEL_PATH/raw/main/README.md").text
@@ -197,7 +194,7 @@ with gr.Blocks(theme='gradio/soft') as demo:
197
  # readme_content,
198
  # )
199
 
200
- run_button.click(fn=generate, inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], outputs=output_box, api_name="rdf2text")
201
  clear_button.add([triplets, question, output_box])
202
 
203
  demo.queue(concurrency_count=1, max_size=10).launch(debug=True)
 
32
  )
33
 
34
  tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
35
+ model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-7b1", quantization_config=bnb_config, trust_remote_code=True, device_map="auto")
36
+ print(f"Successfully loaded the model into memory")
 
 
37
 
38
 
39
+ # Define stopping criteria. We do not use it for bloom model family but it can be used for llama model family
40
+ stop_tokens = ["\n###"]
41
  stop_token_ids = tokenizer.convert_tokens_to_ids(stop_tokens)
42
 
43
  class StopOnTokens(StoppingCriteria):
 
48
  return False
49
 
50
 
51
+ # Prompts
 
52
  instruction_with_q = """
53
+ A chat between a curious USER and an artificial intelligence assistant.
54
+ The assistant's job is to answer the given question using only the information provided in the RDF triplet format. The assistant's answer should be in a USER-readable format, with proper sentences and grammar and should be concise and short.
55
  The RDF triplets will be provided in triplets, where triplets are always in the (subject, relation, object) format and are separated by a semicolon. The assistant should understand that if multiple triplets are provided, the answer to the question should use all of the information from triplets and make aggregation. The assistant MUST NOT add any additional information, beside form the one proveded in the triplets.
56
  The assistant should try to reply as short as possible, and perform counting or aggregation operations over triplets by himself when necessary.
57
  """
58
 
59
  instruction_wo_q = """
60
+ A chat between a curious USER and an artificial intelligence assistant.
61
+ The assistant's job is convert the provided input in RDF triplet format into USER-readable text format, with proper sentences and grammar. The triplets are always in the (subject, relation, object) format, where each triplet is separated by a semicolon. The assistant should understand that if multiple triplets are provided, the generated USER-readable text should use all of the information from input. The assistant MUST NOT add any additional information, beside form the one proveded in the input.
62
  """
63
 
64
 
65
  history_with_q = [
66
+ ("USER", "Question: Is Essex the Ceremonial County of West Tilbury? Triplets: ('West Tilbury', 'Ceremonial County', 'Essex');"),
67
+ ("ASSISTANT", "Essex is the Ceremonial County of West Tributary"),
68
+ ("USER", "Question: What nation is Hornito located in, where Jamie Bateman Cayn died too? Triplets: ('Jaime Bateman Cayón', 'death place', 'Panama'); ('Hornito, Chiriquí', 'country', 'Panama');"),
69
+ ("ASSISTANT", "Hornito, Chiriquí is located in Panama, where Jaime Bateman Cayón died."),
70
+ ("USER", "Question: Who are the shareholder of the soccer club for whom Steve Holland plays? Triplets: ('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
71
+ ("ASSISTANT", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
72
+ ("USER", "Question: Who is the chancellor of Falmouth University? Triplets: ('Falmouth University', 'chancellor', 'Dawn French');"),
73
+ ("ASSISTANT", "The chancellor of the Falmouth University is Dawn French.")
74
 
75
  ]
76
 
77
 
78
  history_wo_q = [
79
+ ("USER", "('West Tilbury', 'Ceremonial County', 'Essex');"),
80
+ ("ASSISTANT", "Essex is the Ceremonial County of West Tributary"),
81
+ ("USER", "('Jaime Bateman Cayón', 'death place', 'Panama'); ('Hornito, Chiriquí', 'country', 'Panama');"),
82
+ ("ASSISTANT", "Hornito, Chiriquí is located in Panama, where Jaime Bateman Cayón died."),
83
+ ("USER", "('Steve Holland', 'current club', 'Chelsea F.C.'); ('Chelsea F.C.', 'owner', 'Roman Abramovich');"),
84
+ ("ASSISTANT", "Roman Abramovich owns Chelsea F.C., where Steve Holland plays."),
85
+ ("USER", "('Falmouth University', 'chancellor', 'Dawn French');"),
86
+ ("ASSISTANT", "The chancellor of the Falmouth University is Dawn French.")
87
 
88
  ]
89
 
 
92
  def prepare_input(linearized_triplets, question=None) -> str:
93
  if question and "List all" in question:
94
  question = question.replace("List all ", "Which are ")
95
+ if question:
96
  input_text = f"Question: {question.strip()} Triplets: {linearized_triplets}"
97
  else:
98
  input_text = linearized_triplets
 
104
  instruction: str,
105
  history: List[Tuple[str, str]]=None,
106
  ) -> str:
107
+ ret = f"{instruction}\n"
108
  for i, (role, message) in enumerate(history):
109
+ ret += f"{role}: {message}\n"
110
+ ret += f"USER: {curr_input}\nASSISTANT: "
111
  return ret
112
 
113
 
 
125
  else:
126
  instruction = make_prompt(curr_input, instruction_wo_q, history_wo_q)
127
 
128
+ stop = StopOnTokens()
129
  input_ids = tokenizer(instruction, return_tensors="pt").input_ids
130
  input_ids = input_ids.to(model.device)
131
 
 
137
  top_p=top_p,
138
  top_k=top_k,
139
  repetition_penalty=repetition_penalty,
 
140
  )
141
 
142
  with torch.no_grad():
143
+ outputs = model.generate(**generate_kwargs, return_dict_in_generate=True, output_scores=True)
144
 
145
  response = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
146
  for tok in tokenizer.additional_special_tokens+[tokenizer.eos_token]:
 
183
  examples = gr.Examples([
184
  ['("Google Videos", "developer", "Google"), ("Google Web Toolkit", "author", "Google")', ""],
185
  ['("Katyayana", "religion", "Buddhism")', "What is the relegious affiliations of Katyayana?"],
186
+ ], inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], fn=generate_output, cache_examples=False if platform.system() == "Windows" or platform.system() == "Darwin" else True, outputs=output_box)
187
 
188
 
189
  #readme_content = requests.get(f"https://huggingface.co/HF_MODEL_PATH/raw/main/README.md").text
 
194
  # readme_content,
195
  # )
196
 
197
+ run_button.click(fn=generate_output, inputs=[triplets, question, temperature, top_p, top_k, repetition_penalty], outputs=output_box, api_name="rdf2text")
198
  clear_button.add([triplets, question, output_box])
199
 
200
  demo.queue(concurrency_count=1, max_size=10).launch(debug=True)