
16. MAKING PREDICTION with LLAMA MODEL - 3

In the previous chapters, 14. MAKING PREDICTION with LLAMA MODEL - 1 and 15. MAKING PREDICTION with LLAMA MODEL - 2, and in this chapter, we walk through the LlamaTransformer.Forward(...) method and its components.

Now, we have currentTensor, the resulting tensor after running all 32 transformer block layers. These transformer block layers were explained in the previous chapter, 15. MAKING PREDICTION with LLAMA MODEL - 2.

Diagram: Forward Pass Through Output of The Transformer (STAGE 19).

16.1. Performing Forward Pass Through Output Prenormalization - RMSNorm.Forward(...)

The Llama 3.1 models use pre-normalization with RMSNorm (Root Mean Square Layer Normalization). It is called "pre-normalization" because the normalization is applied to a layer's input before that layer runs: here, currentTensor is normalized by RMSNorm (including the multiplication by its normalization weights) before it is passed to the output linear transformation.

As one source puts it:
RMSNorm regularizes the summed inputs to a neuron in one layer according to root mean square (RMS), giving the model re-scaling invariance property and implicit learning rate adaptation ability. RMSNorm is computationally simpler and thus more efficient than LayerNorm.
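For a single row x of d items (here d = 4096), with normalization weights g and a small constant ε, the standard RMSNorm formula is given here. The square, mean, epsilon, and reciprocal square root correspond to RMSNorm.doNormalization(...), and the final factor g_i corresponds to the weights multiplication in RMSNorm.Forward(...):

\[ \text{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} \, g_i \]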

from src/model/llamatransformer.go

func (lt *LlamaTransformer) Forward(infContext *InferenceContext, inputTokens *ml.Tensor, startPos int) (*ml.Tensor, error) {
    ...
    common.GLogger.DebugPrintf("Calling RMSNorm for currentTensor shape(%v) (result of all transformer blocks) and LlamaTransformer.output_norm weights shape(%v) -> tensor currentTensor", currentTensor.Size, lt.output_norm.weights.Size)
    if currentTensor, err = lt.output_norm.Forward(infContext, currentTensor); err != nil {
        return nil, err
    }
    ...
}

We can see output lines in the "debug.log" file if debugging is enabled, as follows:

[DEBUG] ... Calling RMSNorm for currentTensor shape([22 4096]) (result of all transformer blocks) and LlamaTransformer.output_norm weights shape([4096]) -> tensor currentTensor ...

In the RMSNorm.Forward(...) method, we call the RMSNorm.doNormalization(...) method, then perform an element-wise multiplication of the result with the normalization weights tensor (here, LlamaTransformer.output_norm's weights) via ml.MultiplyElementwise.

from src/model/llamatransformer.go

func (rms *RMSNorm) Forward(infContext *InferenceContext, x *ml.Tensor) (*ml.Tensor, error) {
    h, err := rms.doNormalization(x)
    if err != nil {
        return nil, err
    }
    return ml.MultiplyElementwise(h, rms.weights)
}

The RMSNorm.doNormalization(...) method consists of multiple steps:

  • Take currentTensor, the resulting tensor after running all 32 transformer block layers, as the input tensor x with shape of {22, 4096},
  • Calculate the square of each item in the x tensor via h, err = ml.Pow(x, 2) and assign it to the h tensor,
  • Calculate the mean of each row over the last dimension while keeping that dimension, via ml.Mean(h, -1, true): the input shape is {22, 4096}, the output shape is {22, 1},

    For further information, check out: torch.mean documentation.

  • Add the scalar RMSNorm.epsilon (0.00001) to each item in the h tensor via ml.AddScalar, which keeps the next step from dividing by zero. Its value comes from model.ModelArgs.NormEpsilon = 1e-05, as read from the "params.json" configuration file. Output shape is {22, 1},

    The model configuration file "params.json" is parsed into a ModelArgs object via a JSON parser.

  • Calculate the reciprocal of the square root of each item in the h tensor via ml.RSqrt. Output shape is {22, 1},

    For further information, check out: torch.rsqrt documentation.
    The formula is:

\[ out_i = \frac{1}{\sqrt{input_i}} \]
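  • Perform an element-wise multiplication of the input tensor x with shape of {22, 4096} and the h normalization tensor with shape of {22, 1} via ml.MultiplyElementwise. The {22, 1} tensor is broadcast along the last dimension, so every item in a row is scaled by that row's factor. Output shape is {22, 4096}.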

Now, in the RMSNorm.Forward(...) method, we have the h tensor with shape of {22, 4096} and the RMSNorm.weights normalization weights tensor with shape of {4096}.

Then, we perform an element-wise multiplication of the h tensor with shape of {22, 4096} and the RMSNorm.weights tensor with shape of {4096}, broadcasting the weights across all 22 rows. Output shape is {22, 4096}.
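The broadcasting in this last multiplication can be illustrated with a minimal sketch on plain slices. scaleByWeights is a hypothetical helper written for illustration, not the project's ml.MultiplyElementwise:

```go
package main

import "fmt"

// scaleByWeights is a hypothetical sketch of the weights step of RMSNorm.Forward(...).
// h has shape {rows, dim} and weights has shape {dim}. The weights vector is
// broadcast over the rows: every row is multiplied item by item by the same weights.
func scaleByWeights(h [][]float32, weights []float32) ([][]float32, error) {
	out := make([][]float32, len(h))
	for r, row := range h {
		if len(row) != len(weights) {
			return nil, fmt.Errorf("row %d has %d items, weights have %d", r, len(row), len(weights))
		}
		out[r] = make([]float32, len(row))
		for i, v := range row {
			out[r][i] = v * weights[i]
		}
	}
	return out, nil
}

func main() {
	h := [][]float32{{1, 2, 3}, {4, 5, 6}} // shape {2, 3}
	w := []float32{10, 0.5, -1}            // shape {3}
	out, err := scaleByWeights(h, w)
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // prints [[10 1 -3] [40 2.5 -6]]
}
```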

from src/model/llamatransformer.go

func (rms *RMSNorm) doNormalization(x *ml.Tensor) (*ml.Tensor, error) {
    var err error
    var h *ml.Tensor
    if h, err = ml.Pow(x, 2); err != nil {
        return nil, err
    }
    if h, err = ml.Mean(h, -1, true); err != nil {
        return nil, err
    }
    if h, err = ml.AddScalar(h, rms.epsilon); err != nil {
        return nil, err
    }
    if h, err = ml.RSqrt(h); err != nil {
        return nil, err
    }
    if h, err = ml.MultiplyElementwise(x, h); err != nil {
        return nil, err
    }
    return h, nil
}

Now, in the LlamaTransformer.Forward(...) method, we have currentTensor, pre-normalized via RMSNorm, with shape of {22, 4096}.


16.2. Performing Linear Transformation with the Output Weights

We've done pre-normalization over the currentTensor.

The output weights tensor lt.output has already been loaded into the LlamaTransformer struct. In our case, its shape is {128256, 4096}.

Now we multiply currentTensor with shape of {22, 4096} by the transpose of lt.output via ml.LinearTransformation. Since lt.output has shape {128256, 4096}, its transpose has shape {4096, 128256}, so the output shape is {22, 128256}.

In our project, we have implemented two separate functions for matrix multiplication: ml.MatMul(...) performs direct matrix multiplication, and ml.LinearTransformation(...) performs matrix multiplication with the transpose of the second argument (generally, a weights tensor).
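The difference between the two functions can be sketched on plain slices. matMul and linearTransformation are hypothetical lowercase stand-ins for ml.MatMul(...) and ml.LinearTransformation(...), written for illustration and assuming rectangular inputs; they are not the project's code:

```go
package main

import (
	"errors"
	"fmt"
)

// matMul sketches direct matrix multiplication:
// a has shape {m, k}, b has shape {k, n}, the result has shape {m, n}.
func matMul(a, b [][]float32) ([][]float32, error) {
	if len(a) == 0 || len(b) == 0 || len(a[0]) != len(b) {
		return nil, errors.New("matMul: inner dimensions do not match")
	}
	m, k, n := len(a), len(b), len(b[0])
	out := make([][]float32, m)
	for i := 0; i < m; i++ {
		out[i] = make([]float32, n)
		for j := 0; j < n; j++ {
			var sum float32
			for p := 0; p < k; p++ {
				sum += a[i][p] * b[p][j]
			}
			out[i][j] = sum
		}
	}
	return out, nil
}

// linearTransformation sketches x multiplied by the transpose of w:
// x has shape {m, k}, w has shape {n, k} (e.g. {128256, 4096}),
// and the result has shape {m, n}. Each output item is the dot product of
// a row of x with a row of w, so w never has to be transposed in memory.
func linearTransformation(x, w [][]float32) ([][]float32, error) {
	if len(x) == 0 || len(w) == 0 || len(x[0]) != len(w[0]) {
		return nil, errors.New("linearTransformation: last dimensions do not match")
	}
	out := make([][]float32, len(x))
	for i, row := range x {
		out[i] = make([]float32, len(w))
		for j, wRow := range w {
			var sum float32
			for p := range row {
				sum += row[p] * wRow[p]
			}
			out[i][j] = sum
		}
	}
	return out, nil
}

func main() {
	x := [][]float32{{1, 2}}                  // shape {1, 2}
	w := [][]float32{{1, 0}, {0, 1}, {1, 1}} // shape {3, 2}
	out, err := linearTransformation(x, w)
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // prints [[1 2 3]], shape {1, 3}
}
```

Reading w row by row is the design reason for a dedicated function: weights stored as {out_features, in_features} can be used directly, without materializing a {4096, 128256} transposed copy.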

from src/model/llamatransformer.go

func (lt *LlamaTransformer) Forward(infContext *InferenceContext, inputTokens *ml.Tensor, startPos int) (*ml.Tensor, error) {
    ...
    common.GLogger.DebugPrintf("Calling ml.LinearTransformation for currentTensor (normalized result of all transformer blocks) shape(%v) and LlamaTransformer.output weights shape(%v) -> tensor output", currentTensor.Size, lt.output.Size)
    output, err := ml.LinearTransformation(currentTensor, lt.output)
    ...
}

We can see output lines in the "debug.log" file if debugging is enabled, as follows:

[DEBUG] ... Calling ml.LinearTransformation for currentTensor (normalized result of all transformer blocks) shape([22 4096]) and LlamaTransformer.output weights shape([128256 4096]) -> tensor output ...

16.3. Converting the Output Tensor to Float32 Tensor and Returning It

Now, we have the output tensor, which contains a logit (an unnormalized score) for each token in our vocabulary. In our case, at the first iteration, the shape of this tensor is {22, 128256}: 22 is the sequence length and 128,256 is the vocabulary size. These logits are not probabilities yet; the next token is selected by performing the Argmax operation over them. Here, we only convert the tensor items to float32 via the Tensor.ToFloat32(...) method, to make performing argmax easy.

How this output tensor is used was described in the chapter "13.5.2. Looping through sequence length" at 13. GENERATING NEXT TOKENS.

from src/model/llamatransformer.go

func (lt *LlamaTransformer) Forward(infContext *InferenceContext, inputTokens *ml.Tensor, startPos int) (*ml.Tensor, error) {
    ...
    common.GLogger.DebugPrintf("Converting output tensor shape(%v) to Float32 tensor -> tensor output", output.Size)
    if output, err = output.ToFloat32(); err != nil {
        return nil, err
    }
    common.GLogger.DebugPrintf("Returning tensor output: shape(%v)", output.Size)
    return output, nil
}

We can see output lines in the "debug.log" file if debugging is enabled, as follows:

[DEBUG] ... Converting output tensor shape([22 128256]) to Float32 tensor -> tensor output ...
[DEBUG] ... Returning tensor output: shape([22 128256]) ...

With this step, we have finished all of the steps of the LlamaTransformer.Forward(...) method and produced the logits from which the next token is predicted. The journey continues in the loop at the chapter "13.5.2. Looping through sequence length" at 13. GENERATING NEXT TOKENS.