Hello all,
We've spent quite some time benchmarking fine-tuning techniques for GPT-J at NLP Cloud (https://nlpcloud.io). Finding the best solution was not straightforward: we had to weigh speed, server costs, ease of development, and the accuracy of the fine-tuned model. It took time, but we ended up with a nice setup (and we now officially offer GPT-J fine-tuning + automatic deployment on our platform).
Here are our key takeaways:
- The best methodology seems to be the one from the Mesh Transformer Jax team: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md
- Fine-tuning on GPU is not ideal. Even several GPUs used in parallel with DeepSpeed can be very slow. We used 4 Tesla T4 GPUs in parallel, and it took 1h30 just to compute our first checkpoint (and used 80GB of RAM) for a training dataset of 20k examples. An A100 GPU might be worth a try.
- Fine-tuning on TPU is very efficient, but it requires a TPU v3 because a TPU v2 runs out of memory. It takes around 15 minutes for a training dataset of 20k examples, which is really awesome.
- The overall process is not straightforward, as it requires several kinds of conversion (converting the dataset to the right format, making a slim version of the model, converting the weights to Transformers...)
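A rough back-of-the-envelope estimate may help explain the memory pressure mentioned in the GPU and TPU points above. This is only a sketch under assumptions that do not come from the benchmark itself: GPT-J is taken as roughly 6 billion parameters, one Tesla T4 as 16 GB, and mixed-precision Adam training as about 16 bytes per parameter (fp16 weights and gradients, plus fp32 weights and two optimizer states). Activations are ignored, so real usage would be higher.

```python
# Rough memory estimate for fine-tuning a large language model.
# All constants below are assumptions for illustration, not measured values.

GPT_J_PARAMS = 6_000_000_000  # assumed: GPT-J has roughly 6 billion parameters
T4_MEMORY_GB = 16             # assumed: memory of one Tesla T4
BYTES_FP16_WEIGHTS = 2        # a single fp16 copy of the weights
BYTES_ADAM_MIXED = 16         # assumed: fp16 weights + fp16 grads + fp32 weights + 2 Adam states


def estimate_gb(n_params: int, bytes_per_param: int) -> float:
    """Return the memory needed for n_params parameters, in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


weights_gb = estimate_gb(GPT_J_PARAMS, BYTES_FP16_WEIGHTS)
training_gb = estimate_gb(GPT_J_PARAMS, BYTES_ADAM_MIXED)
four_t4_gb = 4 * T4_MEMORY_GB

print(f"fp16 weights only:       {weights_gb:.0f} GB")
print(f"training state (Adam):   {training_gb:.0f} GB")
print(f"4 x T4 total GPU memory: {four_t4_gb} GB")
print(f"fits on the GPUs alone:  {training_gb <= four_t4_gb}")
```

Under these assumptions the training state alone is larger than the combined memory of four T4s, which would be consistent with heavy use of host RAM and slow checkpoints. It is an estimate, not a profile of the run described above.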
In the end this is worth the effort, because combining fine-tuning and few-shot learning makes GPT-J very impressive and suited for all sorts of use cases.
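For readers less familiar with few-shot learning, here is a minimal sketch of how a few-shot prompt can be assembled before it is sent to a (fine-tuned) model. The function name, the `###` separator, and the `Input:`/`Output:` labels are hypothetical choices for illustration, not a format required by GPT-J or by NLP Cloud.

```python
from typing import List, Tuple


def build_few_shot_prompt(
    examples: List[Tuple[str, str]],
    query: str,
    separator: str = "###",  # hypothetical delimiter between examples
) -> str:
    """Join solved examples and one unsolved query into a single prompt."""
    # Each solved example shows the model the expected input/output pattern.
    blocks = [f"Input: {inp}\nOutput: {out}" for inp, out in examples]
    # The final block is left open so the model completes the output.
    blocks.append(f"Input: {query}\nOutput:")
    return f"\n{separator}\n".join(blocks)


examples = [
    ("I loved this movie", "positive"),
    ("The food was cold", "negative"),
]
prompt = build_few_shot_prompt(examples, "Great service, will come back")
print(prompt)
```

The same prompt structure can be used with or without fine-tuning. A fine-tuned model may need fewer in-prompt examples to follow the pattern, but that depends on the task and the data.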
If you have different feedback about GPT-J fine-tuning, please don't hesitate to comment. I would love to hear your opinion.
Hope you found the above useful!