13. GENERATING NEXT TOKENS
What we've done so far has been a preparatory stage for generating new tokens through text completion methods from the prompt tokens we have on hand, thus creating new outputs in natural language. Now, let the show begin!
13.1. Preliminary Concepts
13.1.1. Context
Go offers robust tools for concurrent programming and for managing concurrency efficiently.
From the context package documentation:
Package context defines the Context type, which carries deadlines, cancellation signals, and other request-scoped values across API boundaries and between processes.
In this project, we use context.WithCancel(...) to create a new context together with a cancel function that sends a cancellation signal. In Go, we run many functions concurrently as goroutines. Most of these goroutines start an infinite loop that waits for input from Go channels or for signals such as a context cancellation.
This design can have side effects: unfinished goroutines may remain active when the main process is terminated, intentionally or not. For instance, while unfinished goroutines are still running, a single CTRL+C keystroke may not be enough to terminate the process, and you may have to press CTRL+C several times.
To prevent these side effects, we create a context named ctx along with a cancellation function named cancel. When cancel is invoked, every goroutine that checks the status of ctx in each iteration of its infinite loop breaks out of the loop and returns. This approach ensures a healthy and controlled termination process.
from cmd/main.go
func main() {
    ...
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    ...
    go listenGenerationChannels(&wg, ctx, generatedPartCh, errorCh)
    ...
}
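The cancellation pattern described above can be sketched in isolation. This is a minimal illustration, not code from the project: the worker and runAndCancel functions and the tick counter are hypothetical names introduced here, and only the standard context, sync, and time packages are used.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// worker is a hypothetical long-running goroutine. On every iteration it
// checks ctx and returns as soon as the context is cancelled.
func worker(ctx context.Context, wg *sync.WaitGroup, ticks *int) {
	defer wg.Done()
	for {
		select {
		case <-ctx.Done():
			return // cancellation signal received: leave the infinite loop
		case <-time.After(time.Millisecond):
			*ticks++ // stand-in for real work
		}
	}
}

// runAndCancel starts the worker, lets it run for the given number of
// milliseconds, cancels the context, and waits for the worker to return.
func runAndCancel(ms int) (ticks int, cancelled bool) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(1)
	go worker(ctx, &wg, &ticks)

	time.Sleep(time.Duration(ms) * time.Millisecond)
	cancel()  // broadcast the cancellation signal
	wg.Wait() // the worker has returned once Wait unblocks
	return ticks, ctx.Err() == context.Canceled
}

func main() {
	ticks, cancelled := runAndCancel(20)
	fmt.Println("ticks:", ticks, "cancelled:", cancelled)
}
```

Because the worker selects on ctx.Done() in every iteration, a single cancel() call is enough to stop it, no matter how many such goroutines share the same context.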
Note that we use the context.WithCancel(...) function in our project, but there are other ways to create a new context; check out the documentation of the context package.
13.1.2. Waiting loop
We start by creating a wait group (sync.WaitGroup), which can be thought of as a counter of the concurrently running goroutines we want to wait for. Each wg.Add(1) call increments the counter, and each wg.Done() call decrements it. The wg.Wait() call blocks until the counter drops to zero, so the main function doesn't return while these goroutines are still running.
from cmd/main.go
func main() {
    ...
    var wg sync.WaitGroup
    ...
    wg.Add(1)
    go listenGenerationChannels(&wg, ctx, generatedPartCh, errorCh)
    ...
    wg.Wait()
    ...
}
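As a small, self-contained illustration of the counter behaviour, the sketch below starts one goroutine per input and waits for all of them. The sumSquares function is a hypothetical example, not part of the project; the mutex is only there because the goroutines share one variable.

```go
package main

import (
	"fmt"
	"sync"
)

// sumSquares starts one goroutine per input value and waits for all of
// them to finish before returning the accumulated result.
func sumSquares(inputs []int) int {
	var (
		wg    sync.WaitGroup
		mu    sync.Mutex
		total int
	)
	for _, n := range inputs {
		wg.Add(1) // increment the counter before starting the goroutine
		go func(n int) {
			defer wg.Done() // decrement the counter when the goroutine returns
			mu.Lock()
			total += n * n
			mu.Unlock()
		}(n)
	}
	wg.Wait() // blocks until the counter is back to zero
	return total
}

func main() {
	fmt.Println(sumSquares([]int{1, 2, 3})) // prints 14
}
```

Without the wg.Wait() call, sumSquares could return before any goroutine had run, which is the same reason main waits on the wait group in the excerpt above.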
13.2. Calling GenerateString(...)
We call the InferenceEngine.GenerateString(...) method to start the generation process and retrieve the channels generatedPartCh and errorCh. Then, we listen to these two channels in the listenGenerationChannels(...) goroutine. The generation process can end in one of four ways: termination by the user (CTRL+C), an unexpected error or panic, reaching an EOS (End of Sentence) token, or reaching the maximum sequence length specified in inferenceArgs.SequenceLength. The application prints the generated text on the console, reports the finish reason, and then exits.
from cmd/main.go
func main() {
    ...
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    ...
    var wg sync.WaitGroup
    generatedPartCh, errorCh := engine.GenerateString(tokens)
    wg.Add(1)
    go listenGenerationChannels(&wg, ctx, generatedPartCh, errorCh)
    wg.Wait()
    finishReason := "unknown"
    switch appState.generationState {
    case inference.GSFinishedByReachingEOS:
        finishReason = "reaching EOS token"
    case inference.GSFinishedByReachingSeqLen:
        finishReason = "reaching sequence length"
    }
    fmt.Printf("\n\nFinished \033[1mby %s\033[0m.\n", finishReason)
}
13.3. Internals of GenerateString(...)
In this project, our objective was to implement streaming, which lets us print the generated text without waiting for the generation process to complete, as ChatGPT does. With this approach, each token is printed immediately after it is generated. However, while streaming brings advantages, it also comes with challenges that need to be tackled.
If you aren't familiar with this approach, please refer to Writing a Stream API in Go.
We have split the process into multiple methods and defined the InferenceEngine.GenerateStringGeneric(...) method as the common entry point: the main application code calls InferenceEngine.GenerateString(...), the unit tests of the main application call InferenceEngine.GenerateStringFromOutputTokens(...), and both of these call InferenceEngine.GenerateStringGeneric(...).
The InferenceEngine.GenerateStringGeneric(...) method creates a channel of GeneratedPart and a channel of error. It then calls InferenceEngine.generateStringInternal(...) with these channels, in a separate goroutine so that it runs concurrently. Finally, the method returns the two channels as receive-only channels.
from src/inference/inference.go
func (ie *InferenceEngine) GenerateString(promptTokens []model.TokenId) (<-chan GeneratedPart, <-chan error) {
    return ie.GenerateStringGeneric(promptTokens, ie.generateTokensInternal)
}

func (ie *InferenceEngine) GenerateStringGeneric(promptTokens []model.TokenId, tokenGeneratorFn TokenGeneratorFn) (<-chan GeneratedPart, <-chan error) {
    // See: https://betterprogramming.pub/writing-a-stream-api-in-go-afbc3c4350e2
    outputCh := make(chan GeneratedPart, 1)
    outputErrorCh := make(chan error)
    go func() {
        defer func() {
            close(outputCh)
            close(outputErrorCh)
        }()
        ie.generateStringInternal(promptTokens, outputCh, outputErrorCh, tokenGeneratorFn)
    }()
    return outputCh, outputErrorCh
}

func (ie *InferenceEngine) generateStringInternal(promptTokens []model.TokenId, outputCh chan<- GeneratedPart, outputErrorCh chan<- error, tokenGeneratorFn TokenGeneratorFn) {
    ...
}
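The producer side and the consumer side of such a stream API can be tried out with plain strings. The sketch below is a simplified, hypothetical example: streamWords and collect are names invented here, and a string stands in for GeneratedPart. It mirrors the structure above: a buffered output channel, an unbuffered error channel, and a goroutine that closes both when it finishes.

```go
package main

import (
	"errors"
	"fmt"
)

// streamWords is a hypothetical producer. It returns receive-only channels
// immediately and sends the words from a background goroutine.
func streamWords(words []string) (<-chan string, <-chan error) {
	outputCh := make(chan string, 1)
	errorCh := make(chan error)
	go func() {
		defer func() {
			close(outputCh)
			close(errorCh)
		}()
		for _, w := range words {
			if w == "" {
				errorCh <- errors.New("empty word")
				return
			}
			outputCh <- w
		}
	}()
	return outputCh, errorCh
}

// collect consumes the stream until the output channel is closed or an
// error arrives.
func collect(outputCh <-chan string, errorCh <-chan error) ([]string, error) {
	var got []string
	for {
		select {
		case w, ok := <-outputCh:
			if !ok {
				return got, nil // producer finished
			}
			got = append(got, w)
		case err, ok := <-errorCh:
			if !ok {
				errorCh = nil // closed: a nil channel is never selected again
				continue
			}
			if err != nil {
				return got, err
			}
		}
	}
}

func main() {
	got, err := collect(streamWords([]string{"Hello", "streaming", "world"}))
	fmt.Println(got, err)
}
```

Setting the closed error channel to nil is a design choice of this sketch; it keeps the select from spinning on a closed channel while the remaining output is drained.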
13.4. Internals of generateStringInternal(...)
This method creates a generationDecodingContext object that carries the waiting bytes and waiting parts. Then it calls the InferenceEngine.GenerateTokensGeneric(...) method, which creates and returns the channels generatedTokensCh and errorCh. InferenceEngine.GenerateTokensGeneric(...) runs concurrently and calls the function given as the tokenGeneratorFn argument. In the normal application flow (as opposed to the unit tests), this tokenGeneratorFn function is InferenceEngine.generateTokensInternal(...).
Then, it starts an infinite loop that waits for a new token from generatedTokensCh or an error from errorCh. If a new token arrives from generatedTokensCh, it processes the token, taking the waiting bytes and waiting parts into account, then sends the result via outputCh. If an error arrives from errorCh, it forwards the error via outputErrorCh and returns.
When the loop ends, it checks whether any items remain in decodingContext.waitingParts; if so, it processes them and sends them via outputCh.
Waiting bytes and waiting parts are temporary storage for tokens that have not been fully processed yet. For example, Llama models can generate emoji characters. An emoji is a Unicode character that takes more than one byte, and Llama uses the UTF-8 standard to encode Unicode characters.
Llama models represent emojis with byte tokens such as "<0xE2><0x9C>", "<0x88>", and "<0xEF><0xB8><0x8F>". If an emoji is encoded in 6 bytes, Llama may emit it as several byte tokens over a few iterations. So, when we get a newly generated byte token, we need to check whether more byte tokens are required before it can be rendered. In addition, an emoji can be a sequence of more than one Unicode character, and we need to handle that case too.
For instance, the ✈️ (Airplane) emoji is formed by 6 bytes, 0xE2 0x9C 0x88 0xEF 0xB8 0x8F, generated as 3 byte tokens: "<0xE2><0x9C>", "<0x88>", and "<0xEF><0xB8><0x8F>", in that order. You must be able to handle the first "<0xE2><0x9C>" byte token when it is generated. The question is whether to send this new token directly to the output channel for rendering, or to add it to the waiting bytes list until "<0x88>" and the other bytes required to form a valid UTF-8 character arrive.
The emoji rendering process will be discussed in a dedicated chapter. If you want to learn more in the meantime, you can check out the unit tests prefixed with TestSimulatedEmojiOutput... in the main unit test.
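To make the "waiting bytes" idea concrete, here is a minimal sketch that buffers incoming bytes until they form a complete UTF-8 sequence. It is not the project's implementation: byteBuffer and push are hypothetical names, and the sketch only covers the byte level, using utf8.FullRune and utf8.DecodeRune from the standard library. It does not attempt the "waiting parts" logic for emojis made of several Unicode characters.

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// byteBuffer is a hypothetical stand-in for the waiting bytes list.
type byteBuffer struct {
	waiting []byte
}

// push appends the bytes of one byte token and returns the text that has
// become decodable. It returns "" while more bytes are still required.
func (b *byteBuffer) push(chunk []byte) string {
	b.waiting = append(b.waiting, chunk...)
	out := ""
	// utf8.FullRune reports whether the buffer starts with a complete
	// encoding of a rune (an invalid byte also counts as complete).
	for len(b.waiting) > 0 && utf8.FullRune(b.waiting) {
		r, size := utf8.DecodeRune(b.waiting)
		out += string(r)
		b.waiting = b.waiting[size:]
	}
	return out
}

func main() {
	buf := &byteBuffer{}
	// The three byte tokens of the Airplane emoji from the text above.
	chunks := [][]byte{{0xE2, 0x9C}, {0x88}, {0xEF, 0xB8, 0x8F}}
	for _, chunk := range chunks {
		fmt.Printf("pushed % X -> %+q\n", chunk, buf.push(chunk))
	}
}
```

In this sketch the first token yields nothing, the second completes U+2708, and the third yields the variation selector U+FE0F on its own, which shows why byte-level buffering alone is not enough to render such an emoji as a single unit.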
from src/inference/inference.go
func (ie *InferenceEngine) generateStringInternal(promptTokens []model.TokenId, outputCh chan<- GeneratedPart, outputErrorCh chan<- error, tokenGeneratorFn TokenGeneratorFn) {
    decodingContext := &generationDecodingContext{
        waitingBytes: make([]byte, 0),
        waitingParts: make([]GeneratedPart, 0),
    }
    lastGenerationState := GSInProgress
    generatedTokensCh, errorCh := ie.GenerateTokensGeneric(promptTokens, tokenGeneratorFn)
    loop := true
    for loop {
        select {
        case generatedTokenIdResult, ok := <-generatedTokensCh:
            if !ok {
                loop = false
                break
            }
            generatedToken, generatedTokenStr, addedToWaiting := ie.TokenToString(generatedTokenIdResult.value, decodingContext)
            ...
            result := GeneratedPart{
                ...
            }
            ...
            outputCh <- result
        case err, ok := <-errorCh:
            if !ok || err == nil {
                continue
            }
            outputErrorCh <- err
            return
        }
    }
    decodingContext.decodingFinished = true
    if len(decodingContext.waitingParts) > 0 {
        for i, waitingPart := range decodingContext.waitingParts {
            result := GeneratedPart{
                ...
            }
            ...
            outputCh <- result
        }
    }
}
13.5. Internals of generateTokensInternal(...)
13.5.1. Preparing the input tokens tensor
We instantiate a new model.InferenceContext object as infContext to keep temporary data about the current generation process, especially the CacheK and CacheV tensors that hold the cached keys and values. These concepts will be discussed in further chapters.
from src/model/inferencecontext.go
type InferenceContext struct {
    SequenceLength int // context size used during inference
    CacheK         []*ml.Tensor
    CacheV         []*ml.Tensor
    logFn          func(format string, v ...any)
}
Diagram: Creating tokens tensor.
A tensor named tokens with the DT_INT32 data type and a shape of {infContext.SequenceLength} (in our default case, {200}) is instantiated. This tensor is then filled with the integer value of ie.model.Vocabulary.PadId, which is -1 by default.
After instantiation, the prompt tokens are written into the tokens tensor.
For the prompt below, the tokens tensor with shape {200} is:
promptString = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are Einstein<|eot_id|><|start_header_id|>user<|end_header_id|>
Describe your theory.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"
[128000, 128006, 9125, 128007, 271, 2675, 527, 55152, 128009, 128006, 882,
128007, 271, 75885, 701, 10334, 13, 128009, 128006, 78191, 128007, 271,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
shape=[200], dtype=Int32
from src/inference/inference.go
func (ie *InferenceEngine) generateTokensInternal(promptTokens []model.TokenId, generatedTokensCh chan<- generationStepResult[model.TokenId], errorCh chan<- error) {
    infContext := ie.CreateInferenceContext()
    ...
    tokens, err := ml.Full([]int{infContext.SequenceLength}, ml.DT_INT32, int32(ie.model.Vocabulary.PadId))
    ...
    for i, token := range promptTokens {
        if err := tokens.SetItem([]int{i}, int32(token)); err != nil {
            errorCh <- err
            return
        }
    }
    common.GLogger.DebugPrintf("Created input tokens tensor: shape(%v)", tokens.Size)
    ...
}
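The same "fill with the pad id, then copy the prompt" step can be sketched with a plain Go slice instead of the project's tensor type. The newPaddedTokens function and the padID constant are hypothetical names for this illustration, and the length check is an assumption added here; it is not taken from the project code.

```go
package main

import "fmt"

// padID mirrors the default pad id of -1 mentioned in the text.
const padID = -1

// newPaddedTokens returns a slice of length seqLen that is filled with
// padID and has the prompt tokens copied to its beginning.
func newPaddedTokens(prompt []int32, seqLen int) ([]int32, error) {
	// Assumption of this sketch: reject prompts that do not fit.
	if len(prompt) > seqLen {
		return nil, fmt.Errorf("prompt has %d tokens, sequence length is %d", len(prompt), seqLen)
	}
	tokens := make([]int32, seqLen)
	for i := range tokens {
		tokens[i] = padID
	}
	copy(tokens, prompt)
	return tokens, nil
}

func main() {
	// The first four prompt token ids from the example above.
	tokens, err := newPaddedTokens([]int32{128000, 128006, 9125, 128007}, 8)
	fmt.Println(tokens, err) // prints [128000 128006 9125 128007 -1 -1 -1 -1] <nil>
}
```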
13.5.2. Looping through sequence length
Now, we have a tokens tensor with shape {200}. Its first 22 items, tokens[0:22] (indices 0 to 21), are filled with our prompt tokens, and the remaining items are -1.
We start a for loop in which the curPos variable runs from 22 to 200 - 1, that is, from promptLength to infContext.SequenceLength - 1.
The first iteration of this loop takes all of the prompt tokens as input, whereas each later iteration takes only the most recently generated token as input, because the later iterations reuse the data stored in the KV cache (key-value cache) in infContext.
Info: The term logits refers to a tensor that contains a raw, unnormalized score for each alternative. In machine learning, particularly in neural networks, classification models provide their results in the form of logits. The output layer has one neuron per output alternative, each holding a float that scores how likely it is that "this item is the prediction". Applying the softmax function to the logits converts them into probabilities.
In our scenario, the vocabulary in the "tokenizer.model" file contains 128,256 tokens. Consequently, our logits tensor has 128,256 items in one of its dimensions. We then perform the Argmax operation over the logits tensor, which returns the index (token id) of the item with the highest score, which is also the item with the highest probability.
Additional info: In the LLM domain, a temperature value can be used to sample randomly among the most likely tokens, so that each run can generate a different output. We haven't implemented this functionality in our project; instead, we always return the token with the highest probability.
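The relationship between logits, probabilities, Argmax, and temperature can be checked on a toy example. The sketch below uses plain float64 slices rather than the project's tensors; argmax and softmax are hypothetical helper names, and the temperature parameter only illustrates the idea mentioned above, since the project itself does not implement sampling. Both helpers assume a non-empty input.

```go
package main

import (
	"fmt"
	"math"
)

// argmax returns the index of the largest value (the first one on ties).
// It assumes a non-empty slice.
func argmax(values []float64) int {
	best := 0
	for i, v := range values {
		if v > values[best] {
			best = i
		}
	}
	return best
}

// softmax converts logits into probabilities. Dividing by a temperature
// above 1 flattens the distribution; below 1 sharpens it.
func softmax(logits []float64, temperature float64) []float64 {
	maxLogit := logits[argmax(logits)] // subtracted for numerical stability
	probs := make([]float64, len(logits))
	sum := 0.0
	for i, v := range logits {
		probs[i] = math.Exp((v - maxLogit) / temperature)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

func main() {
	logits := []float64{1.0, 3.0, 2.0} // toy vocabulary of 3 tokens
	fmt.Println("greedy token id:", argmax(logits))
	fmt.Printf("T=1: %.3f\n", softmax(logits, 1))
	fmt.Printf("T=2: %.3f\n", softmax(logits, 2))
}
```

Because softmax preserves the ordering of its inputs, taking Argmax over the logits selects the same token as taking it over the probabilities, which is why greedy decoding can skip the softmax step.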
Diagram: Looping through sequence length.
The flow is:
- The first iteration:
  - curPos = 22, prevPos = 0; inputTokensSlice has 22 items, which are the prompt tokens,
  - Execute ie.model.Transformer.Forward(...) to run the first inference with our transformer model and retrieve the logits,
  - We have a logits tensor with the DT_F32 data type and a shape of {22, 128256},
  - But we only need the logits of the last position in the sequence, so we take the last row via logits.Slice([]int{21}, []int{22}),
  - Now, the shape of our logits tensor has become {1, 128256},
  - Execute the ml.Argmax function over our logits; it returns a nextToken tensor with the DT_INT32 data type and a shape of {1, 1},
  - Take the single item of the nextToken tensor via nextToken.Item() as int32 into the nextTokenId variable; the value in our case is 7979,
  - Take the item at index 22 (curPos) of the tokens tensor into the existingToken variable; if it is not -1 (ie.model.Vocabulary.PadId), use the existing token as the next token. This step comes from Meta's original Llama 3.1 Python repository, and I have kept it as is,
  - Set the item at index 22 of the tokens tensor to 7979 (the value of nextTokenId),
  - Check whether nextTokenId equals ie.model.Vocabulary.EndOfSentenceId; if so, send an "EOS reached" signal via the generatedTokensCh channel,
  - Check whether curPos has reached the end of the sequence length; if so, send a "maximum sequence length reached" signal via the generatedTokensCh channel,
  - Otherwise, send nextTokenId with an "in progress" signal via the generatedTokensCh channel,
  - Continue with the next iteration.
- Other iterations:
  - curPos = 23, prevPos = 22,
  - Take the item at index 22 of the tokens tensor into inputTokensSlice,
  - inputTokensSlice has 1 item, which is the last generated token,
  - Execute ie.model.Transformer.Forward(...) to run the next inference with our transformer model and retrieve the logits,
  - Perform the same steps defined for "The first iteration" above.
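The way prevPos and curPos move through the sequence can be sketched without a real model. In the simplified example below, forwardFn is a hypothetical stand-in for the transformer's forward pass that directly returns the next token id, eosID is an arbitrary value chosen for the sketch, and the "keep the existing token" step is omitted for brevity.

```go
package main

import "fmt"

const (
	padID = -1
	eosID = 2 // hypothetical end-of-sentence id for this sketch
)

// forwardFn stands in for the model: given the input tokens and their
// start position, it returns the id of the next token.
type forwardFn func(input []int, startPos int) int

// generate runs the loop and returns the generated token ids together
// with the number of input tokens fed to the model in each iteration.
func generate(prompt []int, seqLen int, forward forwardFn) (generated []int, inputLens []int) {
	tokens := make([]int, seqLen)
	for i := range tokens {
		tokens[i] = padID
	}
	copy(tokens, prompt)

	prevPos := 0
	for curPos := len(prompt); curPos < seqLen; curPos++ {
		// First iteration: the whole prompt. Later iterations: one token.
		input := tokens[prevPos:curPos]
		inputLens = append(inputLens, len(input))

		next := forward(input, prevPos)
		tokens[curPos] = next
		generated = append(generated, next)
		prevPos = curPos

		if next == eosID {
			break // end-of-sentence token reached
		}
	}
	return generated, inputLens
}

func main() {
	// A toy "model" whose output depends only on the position.
	toy := func(input []int, startPos int) int { return 100 + startPos + len(input) }
	generated, inputLens := generate([]int{10, 11, 12}, 8, toy)
	fmt.Println(generated, inputLens) // prints [103 104 105 106 107] [3 1 1 1 1]
}
```

The inputLens result makes the role of the KV cache visible: only the first call receives the full prompt, and every later call receives a single token.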
from src/inference/inference.go
func (ie *InferenceEngine) generateTokensInternal(promptTokens []model.TokenId, generatedTokensCh chan<- generationStepResult[model.TokenId], errorCh chan<- error) {
    ...
    prevPos := 0
    for curPos := promptLength; curPos < infContext.SequenceLength; curPos++ {
        inputTokensSlice, err := tokens.Slice([]int{prevPos}, []int{curPos})
        if err != nil {
            errorCh <- err
            return
        }
        common.GLogger.DebugPrintf("=======================================\n\n")
        common.GLogger.DebugPrintf("Running Transformer.Forward for curPos: %d, prevPos: %d, inputTokensSlice: shape(%v)", curPos, prevPos, inputTokensSlice.Size)
        logits, err := ie.model.Transformer.Forward(infContext, inputTokensSlice, prevPos)
        ...
        if logits, err = logits.Slice([]int{logits.Size[0] - 1}, []int{logits.Size[0]}); err != nil {
            errorCh <- err
            return
        }
        nextToken, err := ml.Argmax(logits, len(logits.Size)-1) // shape=[1,1] dtype=DT_INT32
        if err != nil {
            errorCh <- err
            return
        }
        nextTokenId := model.TokenId(nextToken.Item().(int32))
        // Comment in original Python code: only replace token if prompt has already been generated
        existingToken, err := tokens.GetItem([]int{curPos})
        if err != nil {
            errorCh <- err
            return
        }
        existingTokenId := model.TokenId(existingToken.(int32))
        if existingTokenId != ie.model.Vocabulary.PadId {
            nextTokenId = existingTokenId
        }
        if err = tokens.SetItem([]int{curPos}, int32(nextTokenId)); err != nil {
            errorCh <- err
            return
        }
        common.GLogger.DebugPrintf("Generated token for curPos: %d, prevPos: %d, token id: %d", curPos, prevPos, nextTokenId)
        eosReached := nextTokenId == ie.model.Vocabulary.EndOfSentenceId
        prevPos = curPos
        if eosReached {
            generatedTokensCh <- generationStepResult[model.TokenId]{
                state: GSFinishedByReachingEOS,
                value: nextTokenId,
            }
            break
        }
        if curPos+1 == infContext.SequenceLength {
            generatedTokensCh <- generationStepResult[model.TokenId]{
                state: GSFinishedByReachingSeqLen,
                value: nextTokenId,
            }
            break
        }
        generatedTokensCh <- generationStepResult[model.TokenId]{
            state: GSInProgress,
            value: nextTokenId,
        }
    }
}
13.6. Calling listenGenerationChannels(...)
So far, starting from the InferenceEngine.GenerateString(...) method, we've dived into the internals of the functions that generate new tokens and send them via the generatedPartCh channel.
We will discuss the details of the ie.model.Transformer.Forward(...) function in further chapters.
Now, we briefly explain the listenGenerationChannels(...) function.
from cmd/main.go
func main() {
    ...
    generatedPartCh, errorCh := engine.GenerateString(tokens)
    wg.Add(1)
    go listenGenerationChannels(&wg, ctx, generatedPartCh, errorCh)
    wg.Wait()
    ...
}
This function runs concurrently as a goroutine. In an infinite loop, it listens for new generated parts from the generatedPartCh channel, errors from errorCh, and a cancellation signal from ctx.Done().
When it receives an inference.GeneratedPart object from the generatedPartCh channel, it processes the object and updates the console screen via the appState.updateOutput() method.
from cmd/main.go
func listenGenerationChannels(wg *sync.WaitGroup, ctx context.Context, generatedPartCh <-chan inference.GeneratedPart, errorCh <-chan error) {
    defer wg.Done()
    loop := true
    spacesAfterEmoji := ""
    for loop {
        select {
        case generatedPart, ok := <-generatedPartCh:
            if !ok {
                loop = false
                appState.waitingRunesExtraStr = ""
                fmt.Fprintln(appState.consoleOutWriter)
                break
            }
            if !generatedPart.IsResendOfWaiting {
                appState.generatedTokenIds = append(appState.generatedTokenIds, generatedPart.TokenId)
                appState.generatedTokens = append(appState.generatedTokens, generatedPart.Token)
            }
            if len(spacesAfterEmoji) > 0 && len(generatedPart.WaitingRunesExtraStr) == 0 {
                // If space characters should be added between the emoji and generatedPart.DecodedString
                // which generated at previous iteration, add them
                generatedPart.DecodedString = spacesAfterEmoji + generatedPart.DecodedString
                spacesAfterEmoji = ""
            } else {
                // If there is some emoji in the generated string, add space characters between the emoji and waitingRunesExtraStr
                spacesAfterEmoji = generateRequiredSpacesAfterEmoji(generatedPart.WaitingRunesExtraStr)
                generatedPart.WaitingRunesExtraStr = spacesAfterEmoji + generatedPart.WaitingRunesExtraStr
            }
            appState.waitingRunesExtraStr = generatedPart.WaitingRunesExtraStr
            if generatedPart.AddedToWaiting {
                appState.addedToWaitingCount++
            } else {
                appState.addedToWaitingCount = 0
                appState.generatedText += generatedPart.DecodedString
            }
            appState.generationState = generatedPart.GenerationState
            appState.updateOutput()
            appState.startTimeToken = time.Now()
        case err := <-errorCh:
            if err == nil {
                continue
            }
            fmt.Fprintln(appState.consoleOutWriter)
            common.GLogger.ConsoleFatal(err)
        case <-ctx.Done():
            loop = false
        }
    }
    if len(appState.waitingRunesExtraStr) > 0 {
        // If there is some emoji in the generated string, add space characters between the emoji and waitingRunesExtraStr
        appState.generatedText += generateRequiredSpacesAfterEmoji(appState.waitingRunesExtraStr)
        appState.generatedText += appState.waitingRunesExtraStr
        appState.updateOutput()
    }
}