Fine-tuning GPT-J: key takeaways
1 point by juliensalinas on Sept 23, 2021
Hello all,

We've spent quite some time benchmarking the best fine-tuning techniques for GPT-J at NLP Cloud (https://nlpcloud.io). Finding the best solution was not straightforward: we had to weigh speed, server costs, ease of development, accuracy of the fine-tuned model... It took time, but we ended up with a nice setup (and we now officially offer GPT-J fine-tuning + automatic deployment on our platform).

Here are our key takeaways:

- The best methodology seems to be the one from the Mesh Transformer Jax team: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md

- Fine-tuning on GPU is not ideal. Even several GPUs used in parallel with DeepSpeed can be very slow. We used 4 Tesla T4 GPUs in parallel, and it took 1h30 just to compute our first checkpoint (plus 80GB of RAM used...), for a training dataset of 20k examples. Maybe an A100 GPU would be worth a try.

- Fine-tuning on TPU is very efficient, but it takes a TPU v3 because a TPU v2 runs out of memory. It takes around 15 minutes for a training dataset of 20k examples, which is really awesome.

- The overall process is not straightforward, as it requires several kinds of conversions (converting the dataset to the right format, making a slim version of the model, converting the weights to Transformers...). See the data preparation sketch right after this list.
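To give an idea of what the dataset conversion step can look like, here is a minimal Python sketch that turns a JSONL dataset into a single plain-text file, with each example terminated by GPT-J's <|endoftext|> token. The file names and the "prompt"/"completion" fields are placeholders, not the exact format we used:

    # Hypothetical pre-processing step: flatten a JSONL dataset into one
    # plain-text file of training documents for GPT-J fine-tuning.
    # Field names ("prompt", "completion") and file paths are assumptions.
    import json

    EOT = "<|endoftext|>"  # GPT-J / GPT-2 end-of-text separator

    with open("dataset.jsonl") as src, open("dataset.txt", "w") as dst:
        for line in src:
            example = json.loads(line)
            # Concatenate prompt and completion into a single training document.
            dst.write(example["prompt"] + example["completion"] + EOT + "\n")

The resulting text file can then be converted to tfrecords with the create_finetune_tfrecords.py script from the Mesh Transformer Jax repo, as described in the howto linked above.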

In the end this is worth the effort, because combining fine-tuning and few-shot learning makes GPT-J very impressive and suited for all sorts of use cases.
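For example, once the fine-tuned weights have been converted to Transformers format, few-shot prompting works the same way as with the base model. Here is a minimal sketch; the model path and the sentiment examples are placeholders, not our actual use case:

    # Minimal sketch: few-shot prompting against a fine-tuned GPT-J checkpoint
    # that has been converted to Hugging Face Transformers format.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "./gpt-j-finetuned"  # local directory with converted weights
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # A couple of in-context examples followed by the new input.
    prompt = (
        "Review: Great service, fast shipping.\nSentiment: positive\n"
        "Review: The product broke after two days.\nSentiment: negative\n"
        "Review: Exactly what I was looking for.\nSentiment:"
    )

    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=5, do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))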

If you guys have different feedback about GPT-J fine-tuning, please don't hesitate to comment; I would love to hear your opinion.

Hope you found the above useful!


