BAAI
/

ryanzhangfan commited on
Commit
6244a14
·
verified ·
1 Parent(s): 3578624

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -4
README.md CHANGED
@@ -100,7 +100,8 @@ GENERATION_CONFIG = GenerationConfig(
100
  top_k=2048,
101
  )
102
 
103
- h, w = pos_inputs.image_size[0]
 
104
  constrained_fn = processor.build_prefix_constrained_fn(h, w)
105
  logits_processor = LogitsProcessorList([
106
  UnbatchedClassifierFreeGuidanceLogitsProcessor(
@@ -118,7 +119,8 @@ logits_processor = LogitsProcessorList([
118
  outputs = model.generate(
119
  pos_inputs.input_ids.to("cuda:0"),
120
  GENERATION_CONFIG,
121
- logits_processor=logits_processor
 
122
  )
123
 
124
  mm_list = processor.decode(outputs[0])
@@ -139,12 +141,16 @@ inputs = processor(
139
  padding="longest",
140
  return_tensors="pt",
141
  )
142
- GENERATION_CONFIG = GenerationConfig(pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
 
 
 
 
 
143
 
144
  outputs = model.generate(
145
  inputs.input_ids.to("cuda:0"),
146
  GENERATION_CONFIG,
147
- max_new_tokens=1024,
148
  attention_mask=inputs.attention_mask.to("cuda:0"),
149
  )
150
  outputs = outputs[:, inputs.input_ids.shape[-1]:]
 
100
  top_k=2048,
101
  )
102
 
103
+ h = pos_inputs.image_size[:, 0]
104
+ w = pos_inputs.image_size[:, 1]
105
  constrained_fn = processor.build_prefix_constrained_fn(h, w)
106
  logits_processor = LogitsProcessorList([
107
  UnbatchedClassifierFreeGuidanceLogitsProcessor(
 
119
  outputs = model.generate(
120
  pos_inputs.input_ids.to("cuda:0"),
121
  GENERATION_CONFIG,
122
+ logits_processor=logits_processor,
123
+ attention_mask=pos_inputs.attention_mask.to("cuda:0"),
124
  )
125
 
126
  mm_list = processor.decode(outputs[0])
 
141
  padding="longest",
142
  return_tensors="pt",
143
  )
144
+ GENERATION_CONFIG = GenerationConfig(
145
+ pad_token_id=tokenizer.pad_token_id,
146
+ bos_token_id=tokenizer.bos_token_id,
147
+ eos_token_id=tokenizer.eos_token_id,
148
+ max_new_tokens=1024,
149
+ )
150
 
151
  outputs = model.generate(
152
  inputs.input_ids.to("cuda:0"),
153
  GENERATION_CONFIG,
 
154
  attention_mask=inputs.attention_mask.to("cuda:0"),
155
  )
156
  outputs = outputs[:, inputs.input_ids.shape[-1]:]