How do I perform distillation for this model?

#4 opened by pr4nav101

I have been trying to perform distillation from a fine-tuned 8B model to a Llama 3.2 3B model. I am importing both models with transformers' AutoModelForCausalLM, and I have set up all the code for distillation training, but at this line:
teacher_outputs = teacher(input_ids, attention_mask=attention_mask)  # used to extract the logits / probability distribution
logits = teacher_outputs.logits
I am getting:
bsz, q_len, _ = hidden_states.size()
ValueError: not enough values to unpack (expected 3, got 2)

The same error occurs for the student outputs.
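
For reference, my training step looks roughly like this (a simplified sketch; the model paths, the example batch, and the temperature are placeholders, and it assumes the teacher and student share the same tokenizer/vocabulary):

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Teacher: my fine-tuned 8B model; student: Llama 3.2 3B (paths are placeholders)
teacher = AutoModelForCausalLM.from_pretrained("path/to/finetuned-8b").eval()
student = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-8b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers have no pad token by default

texts = ["example training text"]             # placeholder batch
batch = tokenizer(texts, return_tensors="pt", padding=True)
input_ids = batch["input_ids"]                # expected shape: [batch_size, seq_len]
attention_mask = batch["attention_mask"]

with torch.no_grad():                         # teacher is frozen, only used for soft targets
    teacher_outputs = teacher(input_ids, attention_mask=attention_mask)
teacher_logits = teacher_outputs.logits       # [batch_size, seq_len, vocab_size]

student_outputs = student(input_ids, attention_mask=attention_mask)
student_logits = student_outputs.logits

# Soft-label distillation loss: KL divergence between temperature-scaled distributions
T = 2.0
vocab_size = student_logits.size(-1)
loss = F.kl_div(
    F.log_softmax(student_logits.reshape(-1, vocab_size) / T, dim=-1),
    F.softmax(teacher_logits.reshape(-1, vocab_size) / T, dim=-1),
    reduction="batchmean",
) * (T * T)
loss.backward()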
I could only find resources on distillation for DistilBERT and AutoModelForSeq2SeqLM. Can someone explain how I can fix this error, or how I can perform distillation for my use case? Or is it just better to fine-tune a 3B model directly?
