Learn how Kimi Delta Attention was distilled into AFM-4.5B using knowledge distillation, long-context training, and Arcee’s open-source DistillKit.
Moonshot AI recently put out a great paper (and an associated model) on an extension of Gated DeltaNet that they have termed Kimi Delta Attention (KDA). The results look super promising, particularly in the now-classic three-to-one interleaved local and global attention hybrid arrangement. The pretrained model they released is great, and they've also open-sourced both training and inference kernels, so of course I had to play with them.
Pretraining a whole model from scratch just for funsies is still a little too rich for my blood. Inspired by the paper RADLADS, I decided to try to convert AFM-4.5B-Base into a hybrid KDA and full-attention transformer through knowledge distillation, then see how far it generalizes in long context land.
A tiny note on terms before we go too far: when I say “full attention” I mean standard global self-attention. When I say “NoPE” in this post, I mean we removed RoPE and did not replace it with any other positional embedding scheme in those layers.
First order of business was to create the student model, meaning both modeling code for the desired architecture and a set of decently-initialized weights.
Modeling code turned out to be super easy thanks to flash-linear-attention. The Moonshot AI folks contributed kernels and there's a complete layer implementation that is more or less a drop-in fit. The only real code changes needed were to plug that in, rip out RoPE, and add configuration for what layers are KDA vs. full attention.
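The "configuration for what layers are KDA vs. full attention" can be pictured as a per-layer type list plus a dispatcher that decides what each layer gets built with. This is a minimal sketch: `HybridConfig`, `layer_types`, `layer_spec`, and the type labels are hypothetical names, not the actual modeling code or the flash-linear-attention API. In the real model, the "kda" entries would be the fla KDA layer and the "full_attention" entries a standard attention module with the rotary embedding call removed.

```python
from __future__ import annotations

from dataclasses import dataclass, field

# Hypothetical layer-type labels, not taken from the actual modeling code.
KDA = "kda"
FULL_NOPE = "full_attention"  # standard global attention with RoPE removed


@dataclass
class HybridConfig:
    num_hidden_layers: int
    layer_types: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Default to all-KDA when no explicit schedule is given.
        if not self.layer_types:
            self.layer_types = [KDA] * self.num_hidden_layers
        if len(self.layer_types) != self.num_hidden_layers:
            raise ValueError("layer_types needs one entry per layer")
        unknown = set(self.layer_types) - {KDA, FULL_NOPE}
        if unknown:
            raise ValueError(f"unknown layer types: {sorted(unknown)}")


def layer_spec(config: HybridConfig, idx: int) -> dict:
    """Describe which token mixer layer `idx` should use (illustrative only)."""
    if config.layer_types[idx] == KDA:
        # KDA layers carry a recurrent state and need no positional embedding.
        return {"mixer": "kda", "use_rope": False}
    # Full-attention layers keep standard attention but skip rotary embeddings (NoPE).
    return {"mixer": "attention", "use_rope": False}
```

For example, `HybridConfig(4, [KDA, KDA, KDA, FULL_NOPE])` describes a single "3 KDA + 1 full attention" block. Note that `use_rope` is `False` everywhere, since RoPE is ripped out of the whole model.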
Initializing weights is a little trickier, but not much.
Most of the weights can be copied straight through from teacher to student: MLP, embeddings, norms, and so forth stay exactly the same. The attention parameters need a bit more attention. Parameters unique to KDA layers (like A_log, dt_bias, and so forth) are initialized from scratch.
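The initialization rule amounts to a name-based plan: anything without a teacher counterpart starts fresh, grouped K/V projections in KDA layers get expanded, and everything else is copied. This is a minimal sketch that assumes Llama-style parameter names (`model.layers.N....`) and assumes the student's shared projections keep the teacher's names. `init_action` and the regexes are hypothetical helpers, not code from the post.

```python
from __future__ import annotations

import re

# Matches key/value projection weights; the exact module path is an assumption.
KV_PROJ = re.compile(r"\.(k_proj|v_proj)\.weight$")
# Extracts the decoder layer index from names like "model.layers.7.mlp.up_proj.weight".
LAYER_IDX = re.compile(r"^model\.layers\.(\d+)\.")


def init_action(name: str, teacher_names: set[str], kda_layers: set[int]) -> str:
    """Classify a student parameter as 'expand', 'copy', or 'fresh'."""
    m = LAYER_IDX.match(name)
    in_kda_layer = m is not None and int(m.group(1)) in kda_layers
    if name not in teacher_names:
        return "fresh"  # KDA-only parameters such as A_log and dt_bias
    if in_kda_layer and KV_PROJ.search(name):
        return "expand"  # grouped K/V heads widened to MHA shape
    return "copy"  # MLP, embeddings, norms, and everything else shared
```

Keying "fresh" off "not in the teacher" avoids having to enumerate every KDA-specific parameter. Full-attention layers are assumed here to keep their GQA projections as-is.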
KDA layers have q_proj, k_proj, and so forth, but as currently implemented there isn't an equivalent of GQA for KDA. (You can set num_v_heads, but that doesn't do anything for your key heads, and I ran into some crashes when training with it enabled anyway.) So I used the very elegant solution of just repeating the grouped head weights out to MHA-shaped projections. This brings the resulting model to roughly 5B parameters, up from the original 4.5B, but it sure beats thinking any harder about it.
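Repeating grouped heads out to MHA shape means each KV head's block of rows is duplicated once per query head in its group. The pairing this produces (query head `q` gets KV head `q // n_rep`) is the same one GQA uses at attention time, so the expanded layer starts out computing the same projections as the teacher's. Here is a minimal sketch on a plain list-of-rows weight (PyTorch `Linear` layout: `out_features x in_features`). `expand_gqa_weight` is a hypothetical helper, not the code actually used.

```python
def expand_gqa_weight(weight, num_kv_heads, num_heads):
    """Expand a (num_kv_heads*head_dim, hidden) weight to (num_heads*head_dim, hidden).

    Rows are grouped per head. Each KV head's block is repeated n_rep times in
    place, so query head q ends up paired with KV head q // n_rep.
    """
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    if len(weight) % num_kv_heads:
        raise ValueError("row count must be divisible by num_kv_heads")
    head_dim = len(weight) // num_kv_heads
    n_rep = num_heads // num_kv_heads
    expanded = []
    for kv in range(num_kv_heads):
        block = weight[kv * head_dim:(kv + 1) * head_dim]
        for _ in range(n_rep):
            # Copy rows so the repeated blocks can be trained independently.
            expanded.extend(list(row) for row in block)
    return expanded
```

With tensors, the same expansion is `w.view(num_kv_heads, head_dim, hidden).repeat_interleave(n_rep, dim=0).reshape(-1, hidden)`, applied to both `k_proj` and `v_proj`.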
The RADLADS paper sets out a clear and effective three-step pipeline for this sort of distillation: first "Attention Hidden State Alignment", a distillation stage that aligns hidden states with only the attention parameters trainable; then a full-parameter distillation; and finally a fine-tune at long context for sequence length generalization.
I found their pipeline highly effective, but after a bunch of experiments I settled on a slightly modified version that gave equivalent results for this specific setup while being much faster to iterate on. I collapsed the first two distillation stages into a single one with frozen MLP parameters, using a cosine loss instead of MSE between hidden states, and then let the long-context fine-tune pick up any slack in adjusting the MLP layers.
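To make the cosine-vs-MSE swap concrete, here is a minimal sketch of both losses for a single layer's hidden states, given as lists of per-token vectors. The exact reduction used in the actual runs isn't spelled out, so per-token `1 - cos` averaged over positions is an assumption. The function names are hypothetical.

```python
import math


def _cosine(a, b, eps=1e-8):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def cosine_hidden_loss(student, teacher):
    """Mean over token positions of 1 - cos(student_h, teacher_h)."""
    if len(student) != len(teacher):
        raise ValueError("student and teacher must cover the same positions")
    return sum(1.0 - _cosine(s, t) for s, t in zip(student, teacher)) / len(student)


def mse_hidden_loss(student, teacher):
    """Mean squared error over every element of the hidden states."""
    diffs = [(x - y) ** 2 for s, t in zip(student, teacher) for x, y in zip(s, t)]
    return sum(diffs) / len(diffs)
```

One visible difference: a student whose hidden states point the right way but are scaled by 2x gets zero cosine loss and a nonzero MSE. The cosine loss only penalizes direction, not magnitude.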
It's a shame there isn't a handy open-source toolkit for doing this kind of distillation. But partake in a flight of fancy with me for a moment, and imagine we live in a world in which this YAML config suffices to run it:
```yaml
project_name: distillkit-afm-kda
model: arcee-train/afm-4p5b-kdanope-untrained
trust_remote_code: true

frozen_res: # regular expressions for parameter names to freeze during training
  - embed_tokens
  - lm_head
  - norm\.weight
  - ^model\.layers\.[0-9]+\.mlp\..*

output_path: /workspace/models/afm-4p5b-kda-hsd
use_flash_attention: true
sequence_length: 2048

dataset:
  train_dataset:
    repo_id: arcee-train/afm-autodistill-mix-v0
    split: train
  seed: 42

loss_functions:
  - function: cross_entropy
    weight: 0.2
  - function: kl
    weight: 0.2
    temperature: 2.0
  - function: hs_cosine # cosine loss between hidden states
    weight: 0.6

layer_mapping: all

teacher:
  kind: hf
  path: arcee-ai/AFM-4.5B-Base
  kwargs:
    attn_implementation: flash_attention_2
    torch_dtype: bfloat16

training_args:
  dataset_text_field: text
  packing: True
  num_train_epochs: 1
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 2

  pad_to_multiple_of: 128 # training kernels got unhappy for very short packed sequences

  save_steps: 200
  save_total_limit: 1
  logging_steps: 1

  learning_rate: 1.0e-3
  weight_decay: 0.00
  warmup_ratio: 0.025
  lr_scheduler_type: cosine_with_min_lr
  lr_scheduler_kwargs:
    min_lr: 1.0e-5

  bf16: true
  max_grad_norm: 0.5
  optim: adamw_torch

  gradient_checkpointing: true
  gradient_checkpointing_kwargs:
    use_reentrant: false
```

What a wonderful world that would be.
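A quick way to sanity-check `frozen_res` is to run the patterns against parameter names and see what stays trainable. This is a minimal sketch. It assumes the patterns are applied with `re.search` (match anywhere in the name), which fits the fact that only the MLP pattern is anchored with `^`, but that is an assumption about how the config is interpreted. The example names are Llama-style and hypothetical.

```python
import re

# Patterns copied from the frozen_res list in the config.
FROZEN_RES = [
    r"embed_tokens",
    r"lm_head",
    r"norm\.weight",
    r"^model\.layers\.[0-9]+\.mlp\..*",
]


def is_frozen(param_name: str, patterns=FROZEN_RES) -> bool:
    """True if any pattern matches anywhere in the name (re.search semantics assumed)."""
    return any(re.search(p, param_name) for p in patterns)


names = [
    "model.embed_tokens.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.12.mlp.down_proj.weight",
    "model.layers.12.self_attn.q_proj.weight",
    "model.layers.0.attn.A_log",
]
for n in names:
    print(f"{'frozen   ' if is_frozen(n) else 'trainable'}  {n}")
```

Note that `norm\.weight` also catches per-layer norms such as `input_layernorm.weight`. It would likewise catch any norm inside a KDA layer whose name ends in `norm.weight`, which is worth checking if that layer's norm is newly initialized.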
This is surprisingly effective. Even with MLP and embedding parameters frozen, and only about 300 million tokens seen, token-averaged KL divergence between student and teacher converged to around 0.2 on a small held-out slice of the same mix. That's pretty good! I expected to need the full-parameter distillation as well to get any sort of good performance.
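For reference, a token-averaged KL metric is typically computed as a per-position KL between the teacher's and student's next-token distributions, averaged over positions. This is a minimal sketch under two assumptions: the direction is forward KL(teacher || student), and logits arrive as plain lists. Setting `temperature=2.0` gives the softened distributions used by the `kl` term in the config. Whether that loss also applies the conventional T² scaling isn't stated, so it is left out here.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def token_averaged_kl(teacher_logits, student_logits, temperature=1.0):
    """Mean over positions of KL(teacher || student); each arg is a list of logit rows."""
    if len(teacher_logits) != len(student_logits):
        raise ValueError("teacher and student must cover the same positions")
    total = 0.0
    for t_row, s_row in zip(teacher_logits, student_logits):
        t_logp = log_softmax(t_row, temperature)
        s_logp = log_softmax(s_row, temperature)
        total += sum(math.exp(tp) * (tp - sp) for tp, sp in zip(t_logp, s_logp))
    return total / len(teacher_logits)
```

For example, a uniform two-token teacher against a student putting 75% on one token gives 0.5·ln(4/3) ≈ 0.144 nats at that position.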
But returns seemed to diminish from this point, and fine-tuning at 32k context did just about as well. So that's what I did: one-phase distillation, plus about a billion tokens at 32k sequence length with a 1e-5 learning rate.
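To put those token budgets in step terms, here is a small arithmetic helper. It assumes fully packed sequences (the config sets `packing: True`). The GPU count isn't given for either run, so `num_gpus=8` below is purely hypothetical. The helper names are illustrative.

```python
def tokens_per_step(seq_len, per_device_batch, grad_accum, num_gpus):
    """Tokens consumed per optimizer step, assuming fully packed sequences."""
    return seq_len * per_device_batch * grad_accum * num_gpus


def steps_for_budget(total_tokens, seq_len, per_device_batch, grad_accum, num_gpus):
    """Optimizer steps needed to see at least total_tokens (ceiling division)."""
    per_step = tokens_per_step(seq_len, per_device_batch, grad_accum, num_gpus)
    return -(-total_tokens // per_step)


# Stage 1: ~300M tokens at 2048 with batch 4 x accum 2 from the config; 8 GPUs is hypothetical.
print(steps_for_budget(300_000_000, 2048, 4, 2, num_gpus=8))  # 2289 optimizer steps
# Stage 2: ~1B tokens at 32k context, counted as packed 32,768-token sequences.
print(steps_for_budget(1_000_000_000, 32_768, 1, 1, num_gpus=1))  # 30518 sequences
```

So the long-context phase amounts to roughly thirty thousand full-length 32k sequences.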
For comparison, I also ran this same pipeline with two other variations of the student model:
A) AFM-4.5B-KDA-NoPE (Hybrid, interleaved)
Nine blocks of “3 KDA layers + 1 full-attn NoPE layer”.
B) AFM-4.5B-KDA-FLP (Front-loaded full attention)
The first four layers are full-attn NoPE, and the rest are mixed. The total number of full-attn NoPE layers stays at nine. FLP stands for First Layer Prefix, and it's a perfectly sensible and legible name for an experiment, thank you very much.
C) AFM-4.5B-KDA-Only (All KDA)
Nothing but KDA, babey. No full-attention layers.
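The three variants can be written down as layer-type schedules. Variant A fixes the depth at 36 layers (nine blocks of four). Where FLP's other five full-attention layers sit isn't specified above, so spacing them evenly over the remaining 32 layers is purely an assumption for illustration. The labels and function names are hypothetical.

```python
KDA = "kda"
FULL = "full_attention"  # global attention with RoPE removed (NoPE)


def interleaved(num_blocks=9, kda_per_block=3):
    """Variant A: each block is 3 KDA layers followed by 1 full-attention NoPE layer."""
    return ([KDA] * kda_per_block + [FULL]) * num_blocks


def front_loaded(num_layers=36, num_full=9, prefix=4):
    """Variant B (FLP): the first `prefix` layers are full attention.

    Placement of the other num_full - prefix full-attention layers is an
    assumption: they are spaced evenly over the remaining layers.
    """
    tail = num_layers - prefix
    remaining = num_full - prefix
    if not 0 <= remaining <= tail:
        raise ValueError("cannot place that many full-attention layers")
    types = [FULL] * prefix + [KDA] * tail
    for i in range(remaining):
        # i-th evenly spaced slot in the tail; the last one lands on the final layer.
        types[prefix + (i + 1) * tail // remaining - 1] = FULL
    return types


def kda_only(num_layers=36):
    """Variant C: nothing but KDA."""
    return [KDA] * num_layers
```

A and B both have nine full-attention layers out of 36 and differ only in placement. That is what makes B's poor results a clean (if surprising) ordering effect.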
We ran these through a grab bag of standard evals using our usual harness setup. (If you want the exact prompts, seeds, and shot settings, I’ll dump the command lines somewhere convenient and link them.)
First, the obvious: the teacher reigns supreme. This is not too surprising. Even though the students hold the weight advantage, we're cramming an entire new attention architecture in there. The teacher has had 8 trillion tokens to get good and cozy in how it interacts with attention, vs. barely 1 billion for the distilled students. The interesting signal here is the relative amount of the teacher's capacity the students manage to retain.
Secondly, the also obvious: wow, the FLP configuration sucks. This was a bit of a surprise to me. It's strictly a reordering of which layers get converted to full attention vs. KDA, so I wasn't expecting much of a difference. If anything, I thought full attention would give a better reconstruction of the teacher's hidden states in the crucial first few layers. Maybe positional biases are particularly important to the first few layers? Huh.
Setting aside that particularly bad configuration, we can see an obvious difference between the hybrid and all-KDA models. In knowledge-based benchmarks the all-KDA and hybrid models are within statistical spitting distance. In math, though, there's a larger drop in performance for the all-KDA setup.
Does this mean that KDA is inherently worse than a hybrid or full-attention model at math? No, we definitely can't conclude that. Remember that very little compute went into this, and on a longer training horizon all configurations would likely recover at least a little more performance. Plus, benchmarks like GSM8K are highly sensitive to the formatting of training data, and I kinda just shoveled what I had sitting around into a pile. But we definitely can say that the KDA and NoPE hybrid recovered dramatically more performance than the pure-KDA model given the same training data and hyperparameters.
Given we're tinkering with attention parameters, let's look at long-context performance. The hybrid and KDA-only models were trained up to 32k sequence length, so we should expect good performance up to that point. Below are RULER NIAH results (exact-match retrieval success, higher is better):
Firstly, every model got 100% at simple single-needle retrieval even up to 65k. That's great! It's not super impressive or anything but it's a decent signal that we got non-zero length generalization, as neither of the students saw that length in training.
The pure-KDA model pretty much falls apart immediately on multikey. Even at 4k things are a little shaky, and it doesn't get better from there. But it is worth noting that the degradation is much smoother than that of the hybrid model. The hybrid model tends to hit a cliff past 32k and plummet more or less immediately to zero, vs. the steady falloff of pure-KDA.
This is more or less exactly what you would expect from a pure state space model, and an inevitable result of the inherent conflict between a fixed-size hidden state and an ever-growing past. Comparing the hybrid and KDA-only results, we can make the interesting inference that the full-attention NoPE layers are likely responsible for the crash past 32k. KDA seems to natively generalize to longer sequence lengths (with the caveat above), but NoPE layers really need to have seen at least a few sequences at that length to not fall apart. Neat! This also roughly matches the folk wisdom that you should train that type of model at 2x the sequence length you plan to run inference at.
Okay, now we have a model that kinda-approximates the performance of AFM 4.5B but with a new architecture. Why go to the trouble?
Because it's sicknasty fast.
Kimi Delta Attention is ridiculous. During training, the hybrid model is hitting 8000 tokens per second per H100. Inference is basically free, and KV cache is a fraction of the size. It's delightful.
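Here is a back-of-the-envelope look at why the cache shrinks: only the nine full-attention layers keep a KV cache that grows with context, while each KDA layer holds a fixed-size state. This is a minimal sketch under stated assumptions. The head counts below are hypothetical, not AFM-4.5B's actual configuration. The KDA state is modeled as one `head_dim x head_dim` matrix per head (MHA-shaped after the expansion), and smaller extras such as any short-convolution cache are ignored.

```python
def kv_cache_bytes(num_attn_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """K and V cache size for full-attention layers; grows linearly with seq_len."""
    return 2 * num_attn_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


def kda_state_bytes(num_kda_layers, num_heads, head_dim, bytes_per_elem=2):
    """Fixed recurrent state: one head_dim x head_dim matrix per head per KDA layer."""
    return num_kda_layers * num_heads * head_dim * head_dim * bytes_per_elem


# Hypothetical head configuration, NOT AFM-4.5B's actual values.
NUM_HEADS, NUM_KV_HEADS, HEAD_DIM, SEQ = 20, 4, 128, 32_768

teacher = kv_cache_bytes(36, NUM_KV_HEADS, HEAD_DIM, SEQ)
hybrid = kv_cache_bytes(9, NUM_KV_HEADS, HEAD_DIM, SEQ) + kda_state_bytes(27, NUM_HEADS, HEAD_DIM)
print(f"hybrid / teacher at {SEQ} tokens: {hybrid / teacher:.3f}")  # 0.257
```

Because the KDA term doesn't grow with sequence length, the ratio approaches 9/36 = 0.25 as context grows, whatever the head configuration. For the all-KDA model, it approaches zero.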
Also, because it's neat.
Okay I lied, we do live in the world where open source software lets you run distillations with YAML configurations like the one above.
DistillKit started as a set of simple scripts that we released as open source in August of 2024. When Arcee started to pursue knowledge distillation more heavily, we built a completely revamped version internally that focused heavily on offline computation of teacher outputs and reuse across multiple student models. For various reasons, ranging from "maintenance burden" to "not giving away all our competitive IP for free", we kept it private until now. But it's nearly 2026, and giving away all your IP for free is cool again! As part of our Commitment to Open Source™, I'm happy to say we're releasing our internal Swiss Army knife for distillation as an update to the DistillKit repo.
It supports both online distillation (teacher runs during training) and offline distillation (train from pre-captured teacher outputs), with a logit compression system built around polynomial approximation plus quantization plus bit-packing. It also does the usual distillation stuff you actually need in practice: composable losses (KL, JSD, TVD, ranking losses, hidden state alignment), sparse or dense distributions, HF integration, and all the sharp edges that show up when you try to do this at scale.
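To give a feel for two of those compression ingredients, here is a generic sketch of uniform quantization plus bit-packing applied to a handful of log-probabilities. It is not DistillKit's actual format. The polynomial approximation step is not reproduced, and the value range, bit width, and function names are all hypothetical choices.

```python
def quantize(values, lo, hi, bits):
    """Map floats in [lo, hi] (hi > lo) to integers in [0, 2**bits - 1] on a uniform grid."""
    levels = (1 << bits) - 1
    out = []
    for v in values:
        v = min(max(v, lo), hi)  # clamp to the representable range
        out.append(round((v - lo) / (hi - lo) * levels))
    return out


def dequantize(codes, lo, hi, bits):
    levels = (1 << bits) - 1
    return [lo + c / levels * (hi - lo) for c in codes]


def pack(codes, bits):
    """Concatenate fixed-width codes into a little-endian bit stream."""
    acc = 0
    for i, c in enumerate(codes):
        acc |= c << (i * bits)
    return acc.to_bytes((len(codes) * bits + 7) // 8, "little")


def unpack(data, bits, count):
    acc = int.from_bytes(data, "little")
    mask = (1 << bits) - 1
    return [(acc >> (i * bits)) & mask for i in range(count)]


logprobs = [-10.0, -5.0, 0.0]
codes = quantize(logprobs, lo=-10.0, hi=0.0, bits=4)
packed = pack(codes, bits=4)  # 3 values x 4 bits -> 2 bytes
restored = dequantize(unpack(packed, 4, len(codes)), -10.0, 0.0, 4)
```

At 4 bits per value, this stores one eighth of what float32 would. The price is a reconstruction error of at most half a quantization step.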
If you want to kick the tires immediately, the happy path looks like:
```bash
git clone https://github.com/arcee-ai/distillkit.git
cd distillkit
pip install -e .
distillkit config.yaml
```

If you want to capture your own teacher outputs for offline distillation, there’s an optional install:

```bash
pip install -e ".[capture]"
```

That includes a vLLM fork that’s set up for high-throughput logit sampling. It’s a bit long in the tooth and uses the older v0 engine, so for most people it’s probably saner to start with either online distillation or the pre-captured teacher datasets we publish and go from there.
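Offline distillation usually stores only a sparse slice of each teacher distribution, for example the top-k tokens per position, and computes the loss on that support. This is a generic minimal sketch of one way to do it, not DistillKit's implementation. Renormalizing the kept teacher mass while normalizing the student over the full vocabulary is one choice among several. Both helper names are hypothetical.

```python
import heapq
import math


def top_k_logprobs(logits, k):
    """Keep the k largest teacher log-probabilities as {token_id: logprob}."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    top = heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)
    return {i: logits[i] - log_z for i in top}


def sparse_kl(teacher_topk, student_logits):
    """KL(teacher || student) restricted to the captured top-k support.

    The teacher's kept mass is renormalized to sum to 1; the student is
    normalized over its full vocabulary, as it would be during training.
    """
    kept = math.log(sum(math.exp(lp) for lp in teacher_topk.values()))
    m = max(student_logits)
    s_log_z = m + math.log(sum(math.exp(x - m) for x in student_logits))
    kl = 0.0
    for tok, lp in teacher_topk.items():
        t_lp = lp - kept  # renormalized teacher log-prob
        kl += math.exp(t_lp) * (t_lp - (student_logits[tok] - s_log_z))
    return kl


logits = [2.0, 1.0, 0.0, -1.0]
captured = top_k_logprobs(logits, k=2)  # what an offline capture might store
print(sparse_kl(captured, logits))
```

With this choice, even a student identical to the teacher pays a small penalty equal to minus the log of the probability mass outside the captured support. That is one reason the choice of k (and how the dropped mass is handled) matters.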
I've probably written my rant about post-hoc weight transformations and low-budget community experimentation a thousand times, so I'll skim over that. Open science good, do more profane alchemical transformations of open models, get weird, and so forth. As takeaways from this post, here are three interesting artifacts for your consumption: