Llama from scratch, or how to implement a paper without crying (briankitano.com)
513 points by bkitano19 on Aug 9, 2023 | 52 comments



There is a bug: while in the blog's SwiGLU implementation beta is a learnable parameter, in the reference paper the feed-forward network sets beta to a constant: FFN_SwiGLU(x, W, V, W2) = (Swish_1(xW) ⊗ xV) W2, i.e. Swish with beta fixed at 1. See https://arxiv.org/pdf/2002.05202.pdf (Eq. 6).

In the official LLaMA implementation, the constant beta has been removed entirely: https://github.com/facebookresearch/llama/blob/main/llama/mo...

In the blog's training logs we see lines like " feedforward.1.beta', 0.0 ", which means that during training beta has degenerated to 0, whereas it should be a constant 1.
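For reference, a minimal sketch of a SwiGLU feed-forward block with beta fixed at 1 (Swish_1 is just SiLU); the class and layer names here are illustrative, not the blog's exact code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SwiGLUFeedForward(nn.Module):
        # Swish with beta = 1 is exactly SiLU, so no learnable beta is needed.
        def __init__(self, dim: int, hidden_dim: int):
            super().__init__()
            self.w = nn.Linear(dim, hidden_dim, bias=False)   # gate projection
            self.v = nn.Linear(dim, hidden_dim, bias=False)   # value projection
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w(x)) * self.v(x))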


I guess this goes to show how challenging it can be to implement transformer neural networks correctly. There are so many ways to make mistakes at the various steps, and there is no surefire way of knowing; you'll just get slightly worse performance than you would have otherwise. And in many cases, if you make a change to the network, intentionally or not, the network adapts to it, and there are many examples of different variants of the architecture performing similarly once trained. (Though in these cases, one might ask whether it really matters if you match the original or not.)

One method I've seen people use to identify these kinds of mistakes is precisely matching model outputs against a reference implementation. HuggingFace does this with tiny-random models: these models have randomized weights, but the output is expected to match exactly; if it doesn't, that's an indicator of a bug. This approach only works for bugs that show up during inference, though; detecting issues in data processing, optimizers, or anything that only happens during training is more challenging.
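As a rough sketch of that idea (MyLlamaBlock and ReferenceLlamaBlock are hypothetical stand-ins for your code and a trusted reference with compatible state dicts):

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)                        # (batch, seq_len, d_model)

    mine = MyLlamaBlock(d_model=128, n_heads=4)        # your implementation (hypothetical)
    ref = ReferenceLlamaBlock(d_model=128, n_heads=4)  # reference implementation (hypothetical)
    ref.load_state_dict(mine.state_dict())             # force identical weights

    with torch.no_grad():
        assert torch.allclose(mine(x), ref(x), atol=1e-5), "outputs diverge: likely a bug"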


And since there is Huggingface transformers, you can also test against that, which is what we do in Curated Transformers (transformers is only a test-time dependency).


The model really wants to learn, but it will use any shortcut to do it.


Wow, great catch. I will update this in the morning!


Cool, there are also some additional issues with the RoPEAttention that you might want to fix:

The reference paper for rotary embeddings is RoFormer: https://arxiv.org/pdf/2104.09864v4.pdf

First, you shouldn't rotate the values, only the keys and queries. This is wrong: v_out = (torch.bmm(v.transpose(0,1), self.R[:m, ...])).transpose(0,1)

Second, you shouldn't apply multihead attention, which has additional inner weights that will mess with the rotations you have just applied. This is wrong: activations, attn_weights = self.multihead(q_out, k_out, v_out)

Instead, you should use scaled_dot_product_attention(q_out, k_out, v_out).

Third, each attention head should be treated the same way: every head should use the same rotation frequencies.
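A rough sketch of what the fixed attention call might look like (simplified channel layout, hypothetical helper names; q, k, v have shape (batch, n_heads, seq_len, head_dim)):

    import torch
    import torch.nn.functional as F

    def apply_rope(x, cos, sin):
        # Rotate channel pairs; cos/sin have shape (seq_len, head_dim // 2) and
        # broadcast over batch and heads, so every head gets the same frequencies.
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def rope_attention(q, k, v, cos, sin):
        # Rotate only queries and keys, leave values untouched, and use the plain
        # attention kernel instead of nn.MultiheadAttention (which would add its
        # own projections on top of the rotations).
        q_out = apply_rope(q, cos, sin)
        k_out = apply_rope(k, cos, sin)
        return F.scaled_dot_product_attention(q_out, k_out, v, is_causal=True)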


> Second, you shouldn't apply multihead attention, which has additional inner weights that will mess with the rotations you have just applied

Wait, does that mean rotary embeddings don't work with multi-head attention? This is the first I've heard of this. Wouldn't this be an issue with position embeddings as well (for example, sinusoidal position embeddings are a special case of rotary embeddings)?


Afaiu, the whole idea behind rotary embeddings is kind of a hack to switch the similarity metric (the one that compares queries to keys) inside scaled_dot_product_attention without having to rewrite its optimized code.

This custom similarity metric has some properties engineered into it, mainly some invariance with respect to relative positioning and a learnable decay with increasing distance (key-query similarity decreases with increasing distance in position space, and the network can learn how important positional distance is compared to feature-space distance). It's a strong prior that works well when relative positioning is important.

It's a refinement of traditional attention, with a different and more ambitious aim than what sinusoidal position embeddings are trying to do, which is just to provide some position information to the neural network so that it can distinguish keys and learn whatever it sees fit.

Sinusoidal position embeddings can learn some relative positioning quite easily thanks to trigonometry, but they have to learn it. Rotary embeddings have relative positioning baked in: everything is relative to the query position (quite a similar point of view to a convolutional network), and the only thing they learn is how important a small positional distance should be compared to a large one.
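A quick numeric illustration of that "baked in" property (a toy 2-D check, not the full RoPE machinery): rotating q by m·theta and k by n·theta leaves their dot product depending only on the offset n − m.

    import torch

    def rot(x, angle):
        # 2-D rotation of a vector x = (x0, x1) by `angle`.
        c, s = torch.cos(angle), torch.sin(angle)
        return torch.stack([x[0] * c - x[1] * s, x[0] * s + x[1] * c])

    q, k = torch.randn(2), torch.randn(2)
    theta = torch.tensor(0.1)
    a = torch.dot(rot(q, 3 * theta), rot(k, 7 * theta))    # positions 3 and 7
    b = torch.dot(rot(q, 10 * theta), rot(k, 14 * theta))  # positions 10 and 14, same offset
    assert torch.allclose(a, b, atol=1e-5)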


Generally biases in transformers don't work so well.

Personally I think it's because of the autoregressive, ODE-like nature of them, but who am I to say anything on that. ;PPPP


Kudos for the work! Stupid comment (not really on the main topic of the blog post, but might be useful anyway for future "toy example" models): in the initial SimpleBrokenModel class [EDIT: and also in SimpleModel], there is actually quite a bit of wasted computation (something like > 66% of all the model's computations!). You are applying, in sequence, the following layers:

- embedding 65 -> 128

- linear 128 -> 128

- ReLU

- linear 128 -> 65

But since there's no non-linearity at all between the first two layers, and they are both linear... the second one is totally useless. This model is effectively a "classical" single-hidden-layer MLP. And in terms of FLOPS, it's wasting 128*128 ≈ 16k operations out of a total of 128*128 + 65*128 ≈ 24k operations.
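A small sketch of why the second layer is redundant, with sizes mirroring the post's toy model: an Embedding followed directly by a Linear (with no non-linearity in between) can be folded into a single Embedding.

    import torch
    import torch.nn as nn

    emb = nn.Embedding(65, 128)
    lin = nn.Linear(128, 128)

    # Fold the linear layer into the embedding table: same function, fewer FLOPs.
    folded = nn.Embedding(65, 128)
    with torch.no_grad():
        folded.weight.copy_(emb.weight @ lin.weight.T + lin.bias)

    tokens = torch.randint(0, 65, (4, 16))
    assert torch.allclose(lin(emb(tokens)), folded(tokens), atol=1e-5)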


Seems I'm not the only one still getting to grips with non-linearity, lol (see discussion down-thread).

So what's the best fix here? Adding a ReLU or SwiGLU between the embedding and the first linear layer, or just deleting the linear? Presumably the embedding layer is required to convert token indices into embedding vectors and you can't get rid of it, since it has a special structure.


Well, it depends what you mean by "best" :-) Removing the linear layer is the easiest solution (indeed you can't remove the embedding one; in theory you could replace embedding + linear with one-hot encoding + linear, adapting the input dimension of the linear layer to match your vocabulary size, but that would just be identical to an embedding layer, only much slower and more memory-hungry).

Alternatively, you could indeed put a ReLU or another non-linearity between the embedding and the linear layer; you'd get a different model with more layers and more parameters. As the given dataset is pretty large, I'm quite sure this would bring an improvement in accuracy, but without testing it's rather impossible to know. Normalisation also acts as some kind of non-linearity, but when the author adds it, it barely helps accuracy at all, so who knows; sometimes (often) neural networks are counter-intuitive…


Why does adding a ReLU create more layers and parameters? Isn't the total number of neurons the same?


The representational capacity of two consecutive linear layers is the same as that of one slightly different linear layer. The capacity when you introduce a ReLU into the mix is (up to a complexity defined by the number of parameters) any "nice" function -- including things like e^sin(x) -- not just linear functions. With two consecutive linear layers, many of the weights and computations are redundant.


Right, I get that: it increases learning capacity, but doesn't introduce more parameters? Like the GPU requirements would be the same beyond the extra cost of the ReLU operation itself, yes?


Yes of course, sorry my write-up was confusing: I meant that "adding a ReLU between the two linear layers" (the second option) would result in more parameters than "directly removing the second linear layer" (the first option). And my message just meant "I don't know which of the two options achieves the best trade-off between speed and quality". I didn't consider the option "leave it as it is in the blog post" because it is essentially equivalent to the first option (removing the linear layer) but slower (as you say, with exactly the same number of parameters as the second option), so it definitely shouldn't be a "best" option.


Thank you!


Overall, a good sense of fundamental principles demonstrated.

Particularly:

"Use .shape religiously. assert and plt.imshow are your friends." Thank you. You should always assert pre and post conditions of shape. (Do bear or typeguard allow you to do this using decorators?)

Some nits:

"Before you even look at the paper, pick a small, simple, and fast model that you've done in the past. Then make a helper function to evaluate the model qualitatively." Don't you mean quantitatively? So that you establish a numerical baseline against which you can compare the more advanced method.

"Start by picking apart different components of the paper, and then implementing them one-by-one, training and evaluating as you go." Can you be precise what you mean here? A lot of work is like: "Okay we tried 10 changes things [for unspecified reasons], some major and some minor, to get our final thing, and here's an ablation study to show how much we lose if we remove each piece." If you would say: "Implement the meat first (the major architectural change fundamental to the work, i.e. the ablation study line-item all the way at the bottom with no seasoning or spices on it)" then yeah, that's a good place to start. But you can't start with a broccoli recipe, switch to a meat recipe, and taste it halfway before it's done cooking and you haven't flipped it, you're not going to learn much. This sort of advance is better framed as: "Evaluate each time you make an atomic change to the approach, prioritizing changes in the order that had the most impact in the ablation study from easiest to hardest, respecting the DAG in which certain changes can be made."


> (Do bear or typeguard allow you to do this using decorators?)

You can push some of this directly into Python type annotations thanks to https://peps.python.org/pep-0646/.

e.g.

  @overload
  def mean(a: ndarray[float, Dim1, *Shape], axis: Literal[0]) -> ndarray[float, *Shape]: ...
  @overload
  def mean(a: ndarray[float, Dim1, Dim2, *Shape], axis: Literal[1]) -> ndarray[float, Dim1, *Shape]: ...


I'm not sure about PyTorch (last I checked, no, but it's been a while), but JAX offers rudimentary runtime checking support for matrix shapes via beartype / typeguard.

Ultimately, though, I don’t think Python will be nearly as good at that as Julia, whose type system can easily ensure matrix sizes make sense.


What is the guiding principle behind using SwiGLU instead of ReLU? Did the authors decide by simply trying all available non-linearities, or is there a deeper reason?


Like a lot of research, unless there's a clear explanation supported by rigorous study, they probably randomly hill-climbed a bunch of cool new one-liner changes and stopped when it was time to start writing the paper and doing ablation studies.


To be less glib: just wait until there are a bunch of papers picking SwiGLU over ReLU, and then you can stop hand-wringing. It doesn't really matter whether there was a super-specific, concrete, well-articulated reason that SwiGLU worked well for their particular approach. You're still going to use ReLU by default and quickly try SwiGLU regardless, for now.

It's fine, I waited a bit before adopting ReLU over tanh by default for all hidden, non-final (not outputting a probability) layers.


Thanks a lot for your explanations :)


edit: bearblog getting ddos'd, here's the repo https://github.com/bkitano/llama-from-scratch


For AI learners like me, here's an attempt to briefly explain some of the terms and concepts in this blog post, in the rough order they appear.

A token is a unique integer identifier for a piece of text. The simplest tokenization scheme is just Unicode, where one character gets one integer; however, LLMs have a limited number of token IDs available (the vocabulary), so a more common approach is to glue characters together into common fragments. This post just uses the subset of ASCII needed by TinyShakespeare.
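As an illustration (a hedged sketch assuming a local copy of the TinyShakespeare text, not the post's exact code), a character-level tokenizer is just a lookup table:

    # Character-level tokenization: every distinct character gets its own ID.
    text = open("tinyshakespeare.txt").read()       # assumed local copy of the dataset
    vocab = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(vocab)}    # char -> token ID
    itos = {i: ch for ch, i in stoi.items()}        # token ID -> char

    encode = lambda s: [stoi[c] for c in s]
    decode = lambda ids: "".join(itos[i] for i in ids)
    assert decode(encode("ROMEO:")) == "ROMEO:"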

The "loss function" is just a measure of how similar the model's prediction is to the ground truth. Lower loss = better predictions. Different tasks have different loss functions, e.g. edit distance might be one (but not a good one). During training you compute the loss and will generally visualize it on a chart. Whilst the line is heading downwards your NN is getting better, so you can keep training.

PyTorch is a library for working with neural networks and tensors. A tensor is either a single number (0 dimensions, a scalar), an array of numbers (1 dimension, a vector), or a multi-dimensional array of numbers where the 2-dimensional case is called a matrix. But a tensor can have any number of dimensions. PyTorch has a relatively large amount of magic going on in it via reflection and other things, so don't expect the code to make much intuitive sense. It's building a computation graph that can be later executed on the GPU (or CPU). The tutorial is easy to read!

A neural network is a set of neurons, each of which has a number called the bias, and connections between them, each of which has an associated weight. Numbers (activations) flow from an input neuron through the connections whilst being scaled by the weights to arrive at an output neuron; those numbers are then summed and the bias is added before the result is emitted again to the next layer. The weights and biases are the network's parameters and encode its knowledge.

A linear layer is a set of input neurons connected to a set of output neurons, where every input is connected to every output. It's one of the simplest kinds of neural network structure. If you ever saw a diagram of a neural network pre-2010 it probably looked like that. The size of the input and output layers can be different.

ReLU is an activation function. It's just Math.max(0, x), i.e. it sets all negative numbers to zero. These are placed on the outputs of a neuron and are one of those weird mathematical hacks where I can't really explain why they're needed, but introducing "kinks" into the function helps the network learn. Exactly which "kinks" work best is an open area of exploration, and later the author will replace ReLU with a newer, more complicated function.

Gradients are kind of numeric diffs computed during training that are used to update the model and make it more accurate.

Batch normalization is a way to process the numbers as they flow through the network, which helps the network learn better.

Positional encodings help the network understand the positions of tokens relative to each other, expressed in the form of a vector.

The `@` infix operator in Python is an alias for the __matmul__ method and is used as a shorthand for matrix multiplication (there are linear algebra courses on YouTube that are quite good if you want to learn this in more detail).
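e.g.

    import torch

    A = torch.randn(3, 4)
    B = torch.randn(4, 2)
    # `A @ B` calls A.__matmul__(B); both expressions below give a (3, 2) result.
    assert torch.equal(A @ B, torch.matmul(A, B))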

An epoch is a complete training run of the dataset. NNs need to be shown the data many times to fully learn, so you repeat the dataset. A batch is how many of the items in the dataset are fed to the network before updating the parameters. These sorts of numbers are called hyperparameters, because they're things you can fiddle with but the word parameters was already used for weights/biases.

Attention is the magic that makes LLMs work. There are good explanations elsewhere, but briefly it processes all the input tokens in parallel to compute some intermediate tensors, and those are then used in a second stage to emit a series of output tokens.


One more for the list: a lot of people don't know who "Karpathy" is unless they are in the field and have been reading papers.

It might be good to include context like "the science communicator/researcher Andrej Karpathy" so that it's clearer that it refers to a person whose posts are worth looking at.


Another learner here, one clarification that I think is useful even for beginners:

> A token is a unique integer identifier for a piece of text.

A token is a word fragment that's common enough to be useful on its own. For example, "writing", "written", and "writer" all contain "writ", so "writ" would be an individual token, and "writer" might be tokenized as "writ" and "er".

An embedding is where the tokens get turned into unique numeric identifiers.


Tokens are also numbers in practice, but they're indexes into a lookup table of character sequences, so yes, there's very little difference between the two definitions. Embeddings are in turn the result of looking up that index in a table, and the result is a vector. So:

character sequence (string) -> token (small integer) -> embedding (vector of floats)
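In PyTorch terms (a small illustration, with sizes mirroring the post's toy model):

    import torch
    import torch.nn as nn

    embedding = nn.Embedding(num_embeddings=65, embedding_dim=128)  # lookup table
    token_ids = torch.tensor([7, 42, 3])    # tokens: small integers
    vectors = embedding(token_ids)          # embeddings: shape (3, 128) of floats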


The tokens are in this case actually the individual characters:

    vocab = sorted(list(set(lines)))


>These are placed on the outputs of a neuron and are one of those weird mathematical hacks where I can't really explain why it's needed,

Because when you compose linear functions you get linear functions. So having linear everything is a waste of all layers but one.

In order for this not to happen, you need nonlinearity.


Thanks!


This is fantastic, thanks!

Any pointers / references / books that you’ve found particularly helpful in your learning journey?

I know about Karpathy’s video series (and accompanying repos). Anything else come to mind? Thanks!


I've been using a pretty random mix of things including the PyTorch tutorial, some of the tutorials on how transformers work that got posted here months ago, reading papers, and (of course) asking GPT4. It probably isn't the most efficient way to learn.

I would say that learning how to actually build NNs is likely not that important. What's far more important is to know how to use LLMs as an API or library, which is of course 1% coding because the API is so easy and 99% figuring out what their limits are, how best to integrate them into workflows, how to design textual "protocols" to communicate with the AI, how to test non-deterministic systems and so on. Learning how to train a model from scratch is fun but to get competitive results is too expensive, so pragmatism requires focus on being a user for now.


Use perplexity.ai. Why not use AI to learn AI! The thing I like about this tool is that it gives citations, so you can go further than the summary it produces.


Thank you! What is batch normalization doing, and how does it help?


There are other mechanisms for dealing with vanishing and exploding gradients. I (maybe wrongly?) think of batch normalization as being most distinctively about fighting internal covariate shift: https://machinelearning.wtf/terms/internal-covariate-shift/


Karpathy covers this in Makemore, but the tl;dr is that if you don’t normalize the batch (essentially center and scale your activations down to be normally distributed), then at gradient/backprop time, you may get values that are significantly smaller or greater than 1. This is a problem, because as you stack layers in sequence (passing outputs to inputs), the gradient compounds (because of the Chain Rule), and so what may have been a well behaved gradient at the end layers has either vanished (the upstream gradients were 0<x<1 at each layer) or exploded (the gradients were x>>1 upstream). Batch normalization helps control the vanishing/exploding gradient problem in deep neural nets by normalizing the values passed between layers.
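A stripped-down illustration of the "center and scale" part (ignoring batch norm's learnable scale/shift and running statistics):

    import torch

    x = torch.randn(32, 128) * 5 + 3                   # activations with bad scale/offset
    x_hat = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-5)
    print(x_hat.mean().item(), x_hat.std().item())     # roughly 0 and 1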


Got it, thanks.


It's another one of those mathematical hacks that NNs love so much, which stops the numbers spiralling out of control in big networks.


Folks, thanks for the explanation.


Seriously great post. One of those that I read and immediately started wishing I had read something like this a few years ago, when it was all still a bit alien to me and explained in less digestible bits. Regardless, I got a ton out of this; very well done.


Whenever there is some working existing implementation of a model (and maybe even checkpoint), the most effective way to be sure your model implementation is correct is to import such an existing checkpoint and compare the model output. If it does not match (which is almost always the case, as you likely got some details wrong), you can systematically go through each of the layers. You will figure out the real differences and learn. Maybe you will even find some oddities in the existing implementation.
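One way to do that systematic layer-by-layer check (a sketch, assuming hypothetical `mine` and `ref` modules loaded from the same checkpoint, a fixed input `x`, and module names that happen to line up between the two implementations) is to capture intermediate activations with forward hooks and see where they first diverge:

    import torch

    def capture(store, name):
        def hook(module, inputs, output):
            if torch.is_tensor(output):
                store[name] = output.detach()
        return hook

    acts_mine, acts_ref = {}, {}
    for name, module in mine.named_modules():
        module.register_forward_hook(capture(acts_mine, name))
    for name, module in ref.named_modules():
        module.register_forward_hook(capture(acts_ref, name))

    with torch.no_grad():
        mine(x); ref(x)

    for name in sorted(set(acts_mine) & set(acts_ref)):
        ok = torch.allclose(acts_mine[name], acts_ref[name], atol=1e-5)
        print(f"{name}: {'match' if ok else 'MISMATCH'}")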

This is about the model itself. Training is another aspect, but usually, once the hyperparameters are more or less similar, it should be fine if the model is correct.


Love it! Great content, both on how to read a paper and, of course, the content of this specific paper. I also recommend Karpathy's Makemore series!


The TL;DR pointers are really great, and the note about asserting the shape of tensors applies to any common linear algebra library out there, as far as I know. When working on complex LA code, it's extremely important to take small steps and code defensively. In my opinion, programming linear algebra in any mainstream language is absolutely atrocious due to the lack of compile-time checking of tensor shapes, which should properly be part of a tensor's type and would make it impossible to compile if you're trying to multiply a 3x4 by a 3x4 without transposing first. It really, really sucks to run a long calculation only to fail on an operation due to mismatched dimensions.

IMO PyTorch tensors should also have their device statically typed; right now you get a run-time error if you try to multiply a tensor in CPU memory by one in GPU memory.


Llama is one of the nicer papers to read IMO.


Looks like we DDoS'd the server...


working on it hehe


found a typo. search the body for "isntead" ...



This is amazing. Thanks for sharing!


I will ravage the world's network security system





