kz209 committed
Commit
68c64e4
1 Parent(s): d546b0e
pages/summarization_playground.py CHANGED
@@ -53,6 +53,7 @@ def generate_answer(sources, model_name, prompt):
 
 def process_input(input_text, model_selection, prompt):
     if input_text:
+        logging.info("Start generation")
         response = generate_answer(input_text, model_selection, prompt)
         return f"## Original Dialogue:\n\n{input_text}\n\n## Summarization:\n\n{response}"
     else:
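
Note: the added logging.info("Start generation") call assumes pages/summarization_playground.py already imports and configures the standard logging module; the import is outside this hunk. A minimal sketch of the assumed setup (the configuration shown is hypothetical; the project may configure logging elsewhere, e.g. in an app entry point):

import logging

# Hypothetical module-level setup so logging.info(...) is visible
# at the default log level.
logging.basicConfig(level=logging.INFO)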
utils/model.py CHANGED
@@ -23,6 +23,7 @@ class Model(torch.nn.Module):
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.name = model_name
+        logging.info(f'start loading model {self.name}')
         self.pipeline = transformers.pipeline(
             "summarization",
             model=model_name,
@@ -31,7 +32,7 @@ class Model(torch.nn.Module):
             device_map="auto",
         )
 
-        logging.info(f'Load model {self.name}')
+        logging.info(f'Loaded model {self.name}')
         self.update()
 
     @classmethod
@@ -48,13 +49,24 @@ class Model(torch.nn.Module):
         return self.pipeline
 
     def gen(self, content, temp=0.1, max_length=500):
-        sequences = self.pipeline(
-            content,
-            max_new_tokens=max_length,
-            do_sample=True,
-            temperature=temp,
-            num_return_sequences=1,
-            eos_token_id=self.tokenizer.eos_token_id,
-        )
+        if self.name == "google-t5/t5-large":
+            sequences = self.pipeline(
+                content,
+                max_new_tokens=max_length,
+                do_sample=True,
+                temperature=temp,
+                num_return_sequences=1,
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+        else:
+            sequences = self.pipeline(
+                content,
+                max_new_tokens=max_length,
+                do_sample=True,
+                temperature=temp,
+                num_return_sequences=1,
+                eos_token_id=self.tokenizer.eos_token_id,
+                return_full_text=False
+            )
 
         return sequences[-1]['summary_text']
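
Note: the two branches in gen differ only in return_full_text=False, presumably because the pipeline wrapping google-t5/t5-large does not accept that keyword (return_full_text is a text-generation parameter rather than one the T5 summarization path takes), while the other models in this project do. An equivalent, less duplicative sketch of the same branching, assuming the surrounding Model class from the diff:

    def gen(self, content, temp=0.1, max_length=500):
        # Generation arguments shared by both branches of the diff.
        kwargs = dict(
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temp,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # Assumption from the diff: only non-T5 models take return_full_text.
        if self.name != "google-t5/t5-large":
            kwargs["return_full_text"] = False
        sequences = self.pipeline(content, **kwargs)
        return sequences[-1]['summary_text']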