Module 05
Learn the transformer as a computation graph, implement attention, and connect the pieces to ASR, TTS, multimodal alignment, and inference hosting.
Core Idea
A transformer layer lets each token ask: "Which other tokens should I read from, and how strongly?" The model projects each token into three vectors:
Query: what this position is looking for.
Key: what each position offers for matching.
Value: the information copied after matching.
Scores are dot products between queries and keys. Softmax turns those scores into weights. The output is a weighted sum of values.
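```
scores = Q @ K.T / sqrt(d_head)     # [query_time, key_time]
weights = softmax(scores, axis=-1)  # each row sums to 1 over the keys
output = weights @ V                # [query_time, d_head]
```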
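The three lines above can be traced by hand with a minimal pure-Python sketch. It uses nested lists instead of tensors so every intermediate value is visible. The helpers `softmax`, `matmul`, `transpose`, and `attention` are illustrative names, not a library API, and the toy Q, K, V values are arbitrary.

```python
import math

def softmax(row):
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    # a: [n, k], b: [k, m] -> [n, m], using plain nested lists.
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def attention(Q, K, V):
    d_head = len(K[0])
    scores = [[s / math.sqrt(d_head) for s in row] for row in matmul(Q, transpose(K))]
    weights = [softmax(row) for row in scores]  # one row per query, normalized over keys
    return matmul(weights, V), weights

# Toy example: 3 positions, d_head = 2.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
output, weights = attention(Q, K, V)
for row in weights:
    print([round(w, 3) for w in row])
```

Query 2 (`[1, 1]`) matches key 2 best, so most of its weight goes there. Query 0 scores keys 0 and 2 equally, so those two keys get the same weight.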
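Dot products get larger as vector dimension grows. If query and key components are roughly independent with zero mean and unit variance, their dot product has variance d_head. Without scaling, softmax can become too sharp early in training, so it saturates and passes back tiny gradients. Dividing by sqrt(d_head) brings the score variance back to about 1, regardless of head width.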
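The variance claim can be checked empirically. The sketch below assumes query and key entries are independent standard normal values, which is an idealization of real activations. It estimates the variance of raw and scaled dot products for several head widths.

```python
import math
import random
import statistics

def dot_product_variance(d_head, n_samples=2000, scale=False, seed=0):
    """Estimate the variance of q·k for random unit-variance vectors of width d_head."""
    rng = random.Random(seed)
    values = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_head)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_head)]
        s = sum(qi * ki for qi, ki in zip(q, k))
        values.append(s / math.sqrt(d_head) if scale else s)
    return statistics.pvariance(values)

for d in (4, 16, 64):
    raw = dot_product_variance(d)
    scaled = dot_product_variance(d, scale=True)
    print(f"d_head={d:3d}  raw variance ~ {raw:6.1f}  scaled variance ~ {scaled:.2f}")
```

The raw variance tracks d_head, while the scaled variance stays near 1 for every width.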
Implementation
This is not optimized. It is deliberately explicit so every tensor shape is visible.
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.q = nn.Linear(d_model, d_head, bias=False)
        self.k = nn.Linear(d_model, d_head, bias=False)
        self.v = nn.Linear(d_model, d_head, bias=False)
        self.out = nn.Linear(d_head, d_model, bias=False)

    def forward(self, x: torch.Tensor, causal: bool = False):
        # x: [batch, time, d_model]
        q = self.q(x)  # [batch, time, d_head]
        k = self.k(x)  # [batch, time, d_head]
        v = self.v(x)  # [batch, time, d_head]

        scores = q @ k.transpose(-2, -1)  # [batch, time, time]
        scores = scores / math.sqrt(k.shape[-1])

        if causal:
            time = x.shape[1]
            # Ones above the diagonal mark future source positions (s > t).
            mask = torch.triu(torch.ones(time, time, device=x.device), diagonal=1)
            scores = scores.masked_fill(mask.bool(), float("-inf"))

        weights = F.softmax(scores, dim=-1)  # normalize over source positions
        mixed = weights @ v                  # [batch, time, d_head]
        return self.out(mixed), weights


x = torch.randn(2, 5, 16)
layer = SingleHeadSelfAttention(d_model=16, d_head=8)
y, attn = layer(x, causal=True)
print(y.shape)     # torch.Size([2, 5, 16])
print(attn.shape)  # torch.Size([2, 5, 5])
```
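attn[b, t, s] is how much query position t in batch item b reads from source position s. With a causal mask, every source position s greater than t gets zero weight because future tokens are hidden.
Masked scores are set to negative infinity rather than zero because a zero score would not hide anything: exp(0) = 1, so softmax would still assign the position probability mass. Since exp(-inf) = 0, including in IEEE floating point, masked positions get exactly zero weight. This holds as long as each row keeps at least one unmasked position. A fully masked row produces NaN, which cannot happen with the causal mask because the diagonal is never masked.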
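Here is a small sketch of the difference for a single query row. The scores are arbitrary, and position 2 is treated as a future position to hide.

```python
import math

def softmax(row):
    m = max(row)  # the row must contain at least one finite value
    exps = [math.exp(x - m) for x in row]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 2.0, 3.0]      # query at t = 1; source s = 2 is in the future
hidden = [False, False, True]

zero_filled = softmax([0.0 if h else s for s, h in zip(scores, hidden)])
inf_filled = softmax([float("-inf") if h else s for s, h in zip(scores, hidden)])

print(zero_filled)  # the "hidden" position still receives weight
print(inf_filled)   # the hidden position receives exactly 0.0
```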
Multi-Head Attention
A single attention head computes one similarity space. Multiple heads, each with its own query, key, and value projections, let the layer learn several kinds of lookup in parallel, for example local syntax, long-range dependencies, speaker cues, acoustic events, punctuation patterns, or alignment hints.
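```python
# Project once for all heads. qkv_proj: nn.Linear(d_model, 3 * n_heads * d_head); out_proj: nn.Linear(n_heads * d_head, d_model)
qkv = qkv_proj(x)                                 # [batch, time, 3 * n_heads * d_head]
qkv = qkv.view(batch, time, 3, n_heads, d_head)   # then reshape
q, k, v = qkv.unbind(dim=2)                       # each [batch, time, n_heads, d_head]
q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # [batch, n_heads, time, d_head]
# Attention runs independently per head (heads act like an extra batch dimension).
weights = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_head), dim=-1)
mixed = weights @ v                               # [batch, n_heads, time, d_head]
# Concatenate heads, then mix them with an output projection.
mixed = mixed.transpose(1, 2).reshape(batch, time, n_heads * d_head)
y = out_proj(mixed)                               # [batch, time, d_model]
```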
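If d_model is fixed and d_head = d_model / n_heads (the usual choice), more heads means each head is narrower. Very narrow heads may have weak representational capacity, while extra heads add overhead, such as one more time-by-time score matrix per head. Modern efficient models often use grouped-query attention (GQA) or multi-query attention (MQA), in which several query heads share one key/value head, to reduce KV-cache cost.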
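The sketch below puts numbers on both trade-offs. It assumes d_head = d_model / n_heads, and it uses one common grouped-query layout in which consecutive query heads share a key/value head. The function name `kv_head_for_query_head` and the sizes are hypothetical.

```python
def kv_head_for_query_head(h, n_heads, n_kv_heads):
    """Grouped-query attention: consecutive query heads share one key/value head."""
    assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    group_size = n_heads // n_kv_heads
    return h // group_size

# Head width shrinks as head count grows when d_model is fixed.
d_model = 512
for n in (4, 8, 16, 64):
    print(f"n_heads={n:2d} -> d_head={d_model // n}")

# KV heads: n_kv_heads == n_heads is standard multi-head; n_kv_heads == 1 is multi-query.
n_heads = 8
for n_kv_heads, name in ((8, "multi-head"), (2, "grouped-query"), (1, "multi-query")):
    mapping = [kv_head_for_query_head(h, n_heads, n_kv_heads) for h in range(n_heads)]
    print(f"{name:13s} kv_heads={n_kv_heads}  query->kv {mapping}  "
          f"KV cache relative to MHA: {n_kv_heads / n_heads:.3f}")
```

Because the KV cache stores only the key/value heads, cutting 8 KV heads to 2 shrinks the cache by 4x, while all 8 query heads still run.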
Audio Connection
ASR: audio frames or spectrogram patches become a sequence. Encoder transformers mix information across time, so a phoneme can be recognized using its context. Conformer adds convolution modules to better capture local acoustic structure.
TTS: text tokens, phonemes, or semantic tokens condition a model that predicts mel frames, durations, or discrete audio tokens. Attention helps align text positions with the generated speech.
Speech-to-speech: cascaded systems chain ASR -> text LLM -> TTS. Direct systems may encode speech into semantic or acoustic tokens and decode speech tokens, so text is not the only bottleneck between input and output.
Multimodal alignment: audio-text systems learn shared embedding spaces in which spoken content, written content, and acoustic events can retrieve or condition one another.
Efficiency
During generation, a decoder only needs to compute attention for the new token against past keys and values. Keeping those past keys and values in memory is the KV cache. This makes decoding much faster than recomputing keys and values for the whole prefix at every step, but the cache's memory grows in proportion to:
2 * batch_size * sequence_length * layers * kv_heads * head_dim * bytes_per_element   (the 2 counts keys and values)
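To see why the cache saves work, the sketch below runs single-head, single-layer causal decoding in pure Python, once without a cache and once with one. The width D, the random weights, and the token embeddings are hypothetical toy values. Only key/value projection counts are tracked, as a rough proxy for the recomputation a real model would do across all layers.

```python
import math
import random

D = 4  # hypothetical model/head width for this toy example
rng = random.Random(0)

def rand_matrix():
    return [[rng.gauss(0, 1) for _ in range(D)] for _ in range(D)]

W_q, W_k, W_v = rand_matrix(), rand_matrix(), rand_matrix()

def project(W, x):
    # Matrix-vector product: one projection of one token embedding.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # Scaled dot-product attention for a single query over the given keys/values.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]

tokens = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(6)]

# Without a cache: recompute K and V for the whole prefix at every step.
no_cache_outputs, no_cache_projections = [], 0
for t in range(len(tokens)):
    prefix = tokens[: t + 1]
    keys = [project(W_k, x) for x in prefix]
    values = [project(W_v, x) for x in prefix]
    no_cache_projections += 2 * len(prefix)
    no_cache_outputs.append(attend(project(W_q, tokens[t]), keys, values))

# With a KV cache: project only the new token, append, and reuse earlier entries.
k_cache, v_cache, cached_outputs, cached_projections = [], [], [], 0
for x in tokens:
    k_cache.append(project(W_k, x))
    v_cache.append(project(W_v, x))
    cached_projections += 2
    cached_outputs.append(attend(project(W_q, x), k_cache, v_cache))

print("same outputs:", no_cache_outputs == cached_outputs)
print("K/V projections without cache:", no_cache_projections)  # 2 * (1 + 2 + ... + 6) = 42
print("K/V projections with cache:   ", cached_projections)    # 2 * 6 = 12
```

The outputs match because the cached keys and values are exactly the ones the no-cache loop recomputes. Without a cache, work grows quadratically with length; with one it grows linearly, and the price is that the cache lists keep growing.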
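For local audio assistants, this means long prompts and large memory windows can slow the system before TTS even starts. The whole prompt must be processed before the first answer token appears, and every prompt token occupies cache memory. Efficient systems manage prompt length, cache memory, batching, quantization, and streaming carefully.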
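The cache-size expression can be turned into a small calculator to see how prompt length drives memory. The configuration below (32 layers, 8 KV heads of width 128, a 2-byte fp16 cache) is a hypothetical example, not a specific model.

```python
def kv_cache_bytes(batch_size, sequence_length, layers, kv_heads, head_dim, bytes_per_element=2):
    # Factor 2: one key tensor and one value tensor per layer.
    return 2 * batch_size * sequence_length * layers * kv_heads * head_dim * bytes_per_element

# Hypothetical decoder configuration; fp16 cache = 2 bytes per element.
config = dict(batch_size=1, layers=32, kv_heads=8, head_dim=128, bytes_per_element=2)
for tokens in (512, 4096, 32768):
    mib = kv_cache_bytes(sequence_length=tokens, **config) / 2**20
    print(f"{tokens:6d} tokens -> {mib:8.1f} MiB of KV cache")
```

Under these assumptions, a 4,096-token context costs 512 MiB per sequence. Memory scales linearly with tokens, batch size, and kv_heads, which is why grouped-query attention and shorter prompts both help.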
Case study: adding more SQLite memory records to the prompt increased the number of prompt tokens. The model had to evaluate a longer prefix before generating the spoken answer, which delayed speech. The fix was to lower the default memory window, shorten context items, and make larger windows opt-in.
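A minimal sketch of that fix uses Python's built-in `sqlite3` module. The `memories` table, its columns, and the `build_memory_context` helper are hypothetical. Character counts stand in as a rough proxy for tokens.

```python
import sqlite3

def build_memory_context(conn, max_items=5, max_chars_per_item=200):
    """Return the most recent memory records (oldest first), each truncated to a size cap."""
    rows = conn.execute(
        "SELECT content FROM memories ORDER BY id DESC LIMIT ?", (max_items,)
    ).fetchall()
    items = [content[:max_chars_per_item] for (content,) in reversed(rows)]
    return "\n".join(f"- {item}" for item in items)

# Hypothetical schema, populated with 100 long records.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE memories (id INTEGER PRIMARY KEY, content TEXT)")
conn.executemany(
    "INSERT INTO memories (content) VALUES (?)",
    [(f"note {i}: " + "x" * 500,) for i in range(100)],
)

small = build_memory_context(conn)  # lean default window
large = build_memory_context(conn, max_items=50, max_chars_per_item=1000)  # opt-in
print(len(small), len(large))
```

The lean default keeps only the five newest records and caps each one, so the prompt stays small no matter how large the table grows. Callers that really need more history can still ask for it explicitly.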
Checkpoint
Each token builds a query describing what it wants to know, compares it to keys from every token, turns those comparisons into weights, and uses the weights to blend value vectors from the sequence.
Full self-attention compares every time step with every other time step, so each head's score matrix is time-by-time and its cost grows quadratically with sequence length. Audio usually has many more frames than the matching text has tokens, which makes quadratic attention expensive.
Encoder attention can usually read the whole input sequence, such as all audio frames. Decoder causal attention can read only the current and previous positions, so it cannot cheat by seeing future tokens.
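A rough sense of scale comes from counting frames. The sketch assumes 16 kHz audio framed with a 25 ms window (400 samples) and a 10 ms hop (160 samples), with no padding. This framing is common but not universal, and many ASR encoders also downsample frames before attention.

```python
SAMPLE_RATE = 16_000
WINDOW = 400  # 25 ms at 16 kHz
HOP = 160     # 10 ms at 16 kHz

def num_frames(n_samples, window=WINDOW, hop=HOP):
    """Number of full analysis windows in a signal, with no padding."""
    if n_samples < window:
        return 0
    return 1 + (n_samples - window) // hop

for seconds in (1, 10, 30):
    frames = num_frames(seconds * SAMPLE_RATE)
    print(f"{seconds:2d} s audio -> {frames:5d} frames -> "
          f"{frames * frames:>12,} scores per head per layer")

# For comparison, a 100-token text prompt needs 100 * 100 = 10,000 scores.
```

Under these assumptions, 30 seconds of audio becomes about 3,000 frames and roughly nine million attention scores per head per layer, compared with ten thousand for a 100-token sentence.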
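The difference between the two masks fits in a few lines. In this sketch, `attention_mask` is a hypothetical helper that marks which key positions each query may read. Its causal pattern is the complement of the `torch.triu(..., diagonal=1)` mask used in the single-head implementation.

```python
def attention_mask(time, causal):
    """allowed[t][s] is True when query position t may read key position s."""
    return [[(not causal) or s <= t for s in range(time)] for t in range(time)]

for name, causal in (("encoder (full)", False), ("decoder (causal)", True)):
    print(name)
    for row in attention_mask(4, causal):
        print("  " + " ".join("1" if allowed else "." for allowed in row))
```

Each causal row includes its own diagonal entry, so every position can always read itself and no row is ever fully masked.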
The user experiences the whole loop: microphone capture, VAD, ASR, prompt construction, LLM inference, TTS, and playback. A better model is not enough if prompt context is too large, ASR sends noise, or TTS overlaps with an older reply.