BAAI
/

ryanzhangfan commited on
Commit
40b3bb8
·
verified ·
1 Parent(s): c059b33

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +6 -4
README.md CHANGED
@@ -63,7 +63,7 @@ model = AutoModelForCausalLM.from_pretrained(
63
  trust_remote_code=True,
64
  )
65
 
66
- tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
67
  image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
68
  image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
69
  processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
@@ -81,6 +81,7 @@ kwargs = dict(
81
  ratio="1:1",
82
  image_area=model.config.image_area,
83
  return_tensors="pt",
 
84
  )
85
  pos_inputs = processor(text=prompt, **kwargs)
86
  neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
@@ -95,7 +96,8 @@ GENERATION_CONFIG = GenerationConfig(
95
  top_k=2048,
96
  )
97
 
98
- h, w = pos_inputs.image_size[0]
 
99
  constrained_fn = processor.build_prefix_constrained_fn(h, w)
100
  logits_processor = LogitsProcessorList([
101
  UnbatchedClassifierFreeGuidanceLogitsProcessor(
@@ -113,7 +115,8 @@ logits_processor = LogitsProcessorList([
113
  outputs = model.generate(
114
  pos_inputs.input_ids.to("cuda:0"),
115
  GENERATION_CONFIG,
116
- logits_processor=logits_processor
 
117
  )
118
 
119
  mm_list = processor.decode(outputs[0])
@@ -121,5 +124,4 @@ for idx, im in enumerate(mm_list):
121
  if not isinstance(im, Image.Image):
122
  continue
123
  im.save(f"result_{idx}.png")
124
-
125
  ```
 
63
  trust_remote_code=True,
64
  )
65
 
66
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True, padding_side="left")
67
  image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
68
  image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
69
  processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
 
81
  ratio="1:1",
82
  image_area=model.config.image_area,
83
  return_tensors="pt",
84
+ padding="longest",
85
  )
86
  pos_inputs = processor(text=prompt, **kwargs)
87
  neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
 
96
  top_k=2048,
97
  )
98
 
99
+ h = pos_inputs.image_size[:, 0]
100
+ w = pos_inputs.image_size[:, 1]
101
  constrained_fn = processor.build_prefix_constrained_fn(h, w)
102
  logits_processor = LogitsProcessorList([
103
  UnbatchedClassifierFreeGuidanceLogitsProcessor(
 
115
  outputs = model.generate(
116
  pos_inputs.input_ids.to("cuda:0"),
117
  GENERATION_CONFIG,
118
+ logits_processor=logits_processor,
119
+ attention_mask=pos_inputs.attention_mask.to("cuda:0"),
120
  )
121
 
122
  mm_list = processor.decode(outputs[0])
 
124
  if not isinstance(im, Image.Image):
125
  continue
126
  im.save(f"result_{idx}.png")
 
127
  ```