Transcription oddities - Whisper model hallucinations

#2
by randomishwalk - opened

@sanchit-gandhi + others - has anyone noticed that when transcribing particularly long YouTube videos, Whisper-JAX occasionally picks up on a phrase and repeats it 20-30 times?

I tried this on a ~120-minute YouTube video with multiple speakers, and the resulting transcript had several instances of this issue. I did not turn on the timestamp feature.

I'll try to reproduce this with videos and audio files of different lengths and with the timestamp feature, and I'll also compare against the PyTorch/vanilla version of Whisper to see whether the same issue shows up there...

Hey @randomishwalk ,

Could you share the link to the video here? :-)

Yes! It's this one: https://www.youtube.com/watch?v=fWRboyGk_lc

I'm planning to try a couple of other videos and different permutations of Whisper / Whisper JAX next week.

Hey @randomishwalk - thanks for flagging this. The Whisper model has an intrinsic propensity to hallucinate by repeating the same word / phrase over and over again (see https://github.com/openai/whisper/discussions?discussions_q=is%3Aopen+hallucinate). This effect is also seen in the PyTorch model (you can try using the space https://huggingface.co/spaces/sanchit-gandhi/whisper-large-v2, but it'll be very slow for a 2-hour video). We could mitigate this by adding a penalty to repeated n-grams. We have a logits processor to do this in PyTorch (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.no_repeat_ngram_size). We'd have to implement something in Flax in a way that would still allow the function to be JIT compiled.
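To make the JIT constraint concrete, here is a minimal sketch of what a no-repeat-n-gram step could look like in JAX. It assumes a Flax-style decoding loop where `input_ids` is a fixed-size buffer padded to `max_len` and `cur_len` counts the valid tokens; the function name `no_repeat_ngram_logits` is hypothetical, not an existing Whisper JAX or Transformers API. Because `jax.jit` needs static shapes, it doesn't loop over previously seen n-grams in Python. Instead it compares every length-(n-1) window against the current prefix and masks out windows that run past `cur_len`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="ngram_size")
def no_repeat_ngram_logits(input_ids, scores, cur_len, ngram_size):
    """Set to -inf the logit of any token that would complete an n-gram
    already present in input_ids[:, :cur_len] (hypothetical helper).

    input_ids: (batch, max_len) int array, padded past cur_len (assumed layout)
    scores:    (batch, vocab) float logits for the next token
    cur_len:   number of valid tokens (may be a traced value under jit)
    """
    if ngram_size < 2:
        raise ValueError("this sketch only handles ngram_size >= 2")
    batch, max_len = input_ids.shape
    n = ngram_size
    num_windows = max_len - n + 1

    # windows[b, i] = input_ids[b, i:i + n]; static shape (batch, num_windows, n)
    offsets = jnp.arange(num_windows)[:, None] + jnp.arange(n)[None, :]
    windows = input_ids[:, offsets]

    # The last n - 1 tokens. dynamic_slice clamps a negative start to 0; that
    # case is harmless because `valid` below is then all False.
    prefix = jax.lax.dynamic_slice_in_dim(input_ids, cur_len - (n - 1), n - 1, axis=1)

    # A window is a hit if its first n - 1 tokens equal the prefix and the
    # whole n-gram lies inside the valid region (start index <= cur_len - n).
    matches = jnp.all(windows[:, :, : n - 1] == prefix[:, None, :], axis=-1)
    valid = jnp.arange(num_windows) <= cur_len - n
    hits = matches & valid[None, :]

    # Scatter hits onto the vocabulary: banned[b, t] > 0 iff token t would
    # complete a repeated n-gram in sequence b.
    rows = jnp.arange(batch)[:, None]
    banned = jnp.zeros(scores.shape, dtype=jnp.int32)
    banned = banned.at[rows, windows[:, :, n - 1]].add(hits.astype(jnp.int32))
    return jnp.where(banned > 0, -jnp.inf, scores)
```

For example, with `input_ids = [[5, 6, 7, 5, 6, 0, 0, 0]]`, `cur_len=5` and `ngram_size=3`, the last two tokens `(5, 6)` already appeared followed by `7`, so only token 7 gets masked to `-inf`. Since `ngram_size` is static and `cur_len` is traced, the function compiles once per n-gram size and can be called at every decoding step without recompiling.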

Alternatively (though this won't fix Whisper JAX itself), we could simply run the raw output from this model through GPT with an explicit prompt to correct hallucinations and abnormally repeated phrases.

Even simpler would be to write a heuristic that removes repeated n-grams from the output! We'd get a nice generation speed-up by implementing this as a logits processor, though.
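A rough sketch of that post-processing heuristic, assuming whitespace-separated words and a hypothetical threshold `min_repeats=3` (so genuine doubles like "very very" survive): the function below collapses any run of three or more back-to-back copies of the same 1- to `max_n`-word phrase to a single copy. Words are compared exactly, punctuation included, and whitespace is normalised to single spaces. It only touches the final text, so it cleans up the symptom but gives none of the decoding speed-up a logits processor would.

```python
def collapse_repeated_ngrams(text: str, max_n: int = 10, min_repeats: int = 3) -> str:
    """Collapse runs of >= min_repeats consecutive copies of the same
    n-gram (1 <= n <= max_n words) down to one copy (hypothetical helper)."""
    words = text.split()
    out = []
    i = 0
    while i < len(words):
        collapsed = False
        # Try the shortest phrase first so "la la la" collapses to "la".
        for n in range(1, max_n + 1):
            if i + n * min_repeats > len(words):
                break  # too few words left for a run this long (or longer)
            gram = words[i:i + n]
            count = 1
            while words[i + count * n:i + (count + 1) * n] == gram:
                count += 1
            if count >= min_repeats:
                out.extend(gram)  # keep a single copy of the repeated phrase
                i += count * n
                collapsed = True
                break
        if not collapsed:
            out.append(words[i])
            i += 1
    return " ".join(out)
```

For example, `collapse_repeated_ngrams("so thank you thank you thank you thank you for watching")` returns `"so thank you for watching"`, while `"very very good"` is left unchanged.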

randomishwalk changed discussion title from Transcription oddities to Transcription oddities - Whisper model hallucinations

Oh, you're right. That is way more elegant. I was about to suggest hacking together something unnatural with Langchain >.<

The original repository mitigates this by predicting timestamps: the timestamps should help reduce the model's propensity to hallucinate. Does enabling the timestamp feature help for you?

Hey @randomishwalk - could you try again with your YouTube video? The hallucinations should be drastically reduced (albeit not removed entirely)

This is a significant improvement, awesome work @sanchit-gandhi! See a Dropbox link here for the results. I compared the raw transcripts produced by the previous version of Whisper JAX (should be the one from mid/late April) with this latest one and tracked the changes.

Very cool! Looks much better now! And inference time should be a bit faster too 🤞

Closing this as complete since we're now limited by the Whisper model rather than the JAX/TPU implementation - feel free to open a new issue if something looks off!

sanchit-gandhi changed discussion status to closed

Hey @sanchit-gandhi, could we make model-level or training-level changes to Whisper (like using an alignment loss or something similar) to reduce hallucination? An n-gram penalty seems like more of a post-processing technique.
