{"slug": "how-to-optimize-transformer-based-models-for-low-precision-training", "title": "How to Optimize Transformer-Based Models for Low-Precision Training", "summary": "NVIDIA released a guide showing how to optimize transformer-based models for low-precision training using Hopper and Blackwell GPUs, focusing on FP8 and NVFP4 formats. The method translates model configurations into GEMM shapes for benchmarking, enabling teams to estimate speedups before committing to expensive training runs. The approach is demonstrated on CodonFM, a 5B-parameter RNA language model.", "body_md": "Transformer architectures are the backbone of many modern large language and generative AI models. As these models grow in size, training runs consume more GPU hours and more engineering iteration time. Accelerating transformers is therefore not just a performance optimization, but directly affects how quickly teams can experiment and how large a model they can afford to train. [NVIDIA Hopper](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/) and [NVIDIA Blackwell](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/) GPUs help solve this problem by introducing low-precision operator support including FP8 and [NVFP4](https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/).\n\nTransformers spend much of their training time in GEMMs, and low-precision formats speed up training mainly by making those matrix multiplications faster and cheaper. However, your transformer config does not tell you which GEMMs are actually running in your model. If you want to understand where training time goes, you need to turn your transformer config and batch size into the exact M×K×N matrix shapes your model executes, then benchmark those shapes across precisions. 
This will help you determine the optimal precision for your architecture before committing to a more expensive training run.\n\n[NVIDIA Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine) can handle quantization and kernel dispatch, unlocking low-precision formats. This post shows you how to move from high-level model settings to concrete GEMM workloads, profile them with a microbenchmark, and estimate where lower precision will actually translate into speedups to help you accelerate your transformer-based models. The use case features CodonFM, a language model for biology focused on RNA.\n\n## Model configuration and training inputs\n\nSuppose you’re working with a 5B-parameter model such as [CodonFM](https://github.com/NVIDIA-BioNeMo/CodonFM) 5B. It will have a config such as:\n\n```\nhidden_size: 4096\nintermediate_size: 16384\nnum_attention_heads: 32\nnum_hidden_layers: 24\n```\n\nYour training configuration is:\n\n```\nmicro_batch_size: 31\nsequence_length: 512\n```\n\nThe benchmark tool takes these hyperparameters directly and, in a single command, derives the GEMM shapes, benchmarks them across precisions, and computes the full speedup analysis:\n\n```\npython benchmark.py \\\n  --hidden_size 4096 \\\n  --intermediate_size 16384 \\\n  --num_attention_heads 32 \\\n  --num_hidden_layers 24 \\\n  --micro_batch_size 31 \\\n  --sequence_length 512 \\\n  -o ./images/b300_model_config_speedup.png\n```\n\nNote: To disable the Blackwell-specific formats, add `--no-fp8 --no-fp4`. This provides BF16 plus the three tensor-wise FP8 recipes that work on Hopper.\n\n- `--no-fp8` disables MXFP8\n- `--no-fp4` disables NVFP4\n\n## Using autocast mode versus prequantizing\n\nBy default, the tool runs in autocast mode, which is what TE does during training: inputs are dynamically quantized to the target precision before each GEMM, so the measured time includes both the quantization cost and the GEMM kernel itself. 
This provides you with the realistic per-GEMM picture during a training step.\n\nThe tool computes M = 31 × 512 = 15,872 tokens, derives all 12 GEMM shapes, benchmarks each across enabled precisions, and prints the full results. Fprop, Dgrad, and Wgrad shapes are all benchmarked separately to capture the impact of different matrix aspect ratios on kernel selection.\n\nTo isolate raw GEMM kernel performance, add `--pre-quantize`. This prequantizes all inputs once before the timed loop, so the measured time reflects only the GEMM kernel execution: no dynamic quantization, no block scaling computation, and no format conversion during the timed region.\n\nNote that FP8 DelayedScaling always runs in autocast mode, even with `--pre-quantize`, because it relies on an amax history that requires dynamic quantization. Its times are therefore not directly comparable to other precisions in prequantized mode.\n\n```\npython benchmark.py \\\n  --hidden_size 4096 \\\n  --intermediate_size 16384 \\\n  --num_attention_heads 32 \\\n  --num_hidden_layers 24 \\\n  --micro_batch_size 31 \\\n  --sequence_length 512 \\\n  --pre-quantize \\\n  -o ./images/b300_model_config_speedup_prequant.png\n```\n\nComparing the autocast and prequantized speedups tells you exactly how much quantization overhead costs: NVFP4 versus BF16 goes from 1.98x (autocast) to 3.48x (kernel-only). 
The gap between these two numbers is the overhead from dynamic quantization, Hadamard transforms, and block scaling that occurs in each training step.\n\nUse autocast results for predicting real training speedups. This is what TE actually does during training. Use prequantized results to understand whether quantization overhead is the bottleneck, or to compare raw tensor core throughput across precisions independent of the quantization implementation.\n\n## Interpreting the results for a real model\n\nThis section walks through how to interpret these results for a real model. Using the same CodonFM 5B config, we ran the full model config benchmark on NVIDIA B300. The per-shape NVFP4 versus MXFP8 speedups from the Fprop results are as follows:\n\n```\nQKV proj:   0.579 / 0.392  =  1.48x\nAttn out:   0.269 / 0.256  =  1.05x  (barely faster — overhead nearly matches GEMM gain)\nMLP up:     0.924 / 0.635  =  1.46x\nMLP down:   1.076 / 0.649  =  1.66x\n```\n\nTake note of the following points:\n\n- The attention output GEMM receives minimal benefit from lower precision. Compared with the MXFP8 baseline, there is only a 1.05x speedup. This is the smallest weight matrix in the layer (4096×4096)—barely large enough for lower precision to overcome the overhead. By contrast, the much larger MLP Down GEMM delivers 1.66x NVFP4 over MXFP8 on the same hardware. The MLP down GEMM is big enough to amortize the quantization overhead, where attention output isn’t.\n- The big GEMMs show real but subtheoretical gains. The FP4 tensor cores deliver 1.46x to 1.66x over MXFP8 on the large GEMMs. This is well short of the theoretical 2x to 3x from the hardware spec. Once you include the attention output GEMM, the blended Fprop speedup drops to 1.47x. 
After adding Wgrad times, non-GEMM overhead, and NVFP4-specific quantization costs, the end-to-end gap between NVFP4 and MXFP8 in training is consistent with these kernel-level numbers.\n- FP8 DelayedScaling is surprisingly competitive on NVIDIA Blackwell. At 7.80 ms/layer in autocast mode, it outperforms both FP8 CurrentScaling (9.15 ms) and MXFP8 (8.98 ms). In prequantized mode FP8 CurrentScaling pulls ahead (6.81 ms versus 8.12 ms), suggesting the DelayedScaling amax-history approach has lower quantization overhead but similar raw kernel throughput. This is a good example of the comparison between autocast and prequantized surfacing different winners depending on whether you measure with or without the quantization tax.\n- The prequantized results reveal the true kernel potential. Running with `--pre-quantize` removes quantization overhead entirely, and NVFP4 versus BF16 jumps from 1.98x (autocast) to 3.48x (kernel-only). This shows the FP4 tensor cores are delivering real speedups. It’s the quantization overhead in autocast mode that narrows the gap.\n- The Fprop versus Dgrad comparison reveals that the 2x approximation (treating Dgrad time as equal to Fprop time) is imprecise for quantized formats. While BF16 Dgrad is within 2% of Fprop, quantized formats show 5–13% slower Dgrad sums. The QKV Proj Dgrad is especially asymmetric, 33–51% slower than Fprop for FP8/FP4, because swapping K (4096) and N (12288) dramatically changes the matrix aspect ratio and kernel selection. This is exactly why the tool benchmarks Fprop and Dgrad separately rather than counting Fprop time twice.\n\nOnce you have the estimated GEMM-only speedup, compare it against your observed end-to-end training speedup:\n\n- **GEMM speedup ≈ training speedup**: GEMMs dominate the step, and everything is working as expected.\n- **GEMM speedup >> training speedup**: Overhead outside of GEMMs is eating the gains. 
For NVFP4 in particular, this overhead includes Random Hadamard transforms on Wgrad inputs, stochastic rounding on gradients, 2D block scaling for weights, and the extra memory pass for per-tensor amax computation. These are all additional ops that MXFP8 doesn’t need, and they can significantly narrow the gap even if the raw FP4 GEMMs are much faster.\n- **GEMM speedup ≈ 1.0** even in the microbenchmark: The FP4 kernels aren’t actually faster at these shapes, or they’re silently falling back to FP8.\n\nThe last case is especially worth checking. Set `NVTE_LOG_LEVEL=1` or inspect with [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to confirm that TE is actually dispatching FP4 kernels. TE can silently fall back to FP8 or BF16 for layers or ops that don’t support FP4 yet, which would explain identical performance with no other symptoms. You can also compare GPU memory usage between MXFP8 and NVFP4 runs. If memory is nearly identical, that’s a strong signal that FP4 weights aren’t actually being stored.\n\n## Get started benchmarking your model for low-precision training\n\nLow-precision training speedups are highly dependent on the actual GEMM shapes your model runs, and running in low precision does not automatically translate into end-to-end training gains, especially when quantization overhead, kernel selection, and non-GEMM operations are included. By turning a transformer config into concrete M×K×N workloads, you can benchmark BF16, MXFP8, and NVFP4 on the shapes that matter for your model before committing to a full training run.\n\nBenchmark your GEMMs to see which precision is right for you. To get started, check out the [benchmark script](https://github.com/NVIDIA/TransformerEngine/blob/main/benchmarks/gemm/benchmark_gemm.py). 
For the full documentation and to understand how these shapes are derived, see the [GEMM profiling tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/gemm_profiling/gemm_profiling.html) in the Transformer Engine documentation.\n\nUse this benchmark as follows:\n\n- Use autocast results to set realistic training-speedup expectations\n- Use prequantized results to know whether you’re bottlenecked on kernels or on quantization\n- Run candidate model configs through the tool before committing to a training run, as the tool is a useful architecture co-design instrument", "url": "https://wpnews.pro/news/how-to-optimize-transformer-based-models-for-low-precision-training", "canonical_source": "https://developer.nvidia.com/blog/how-to-optimize-transformer-based-models-for-low-precision-training/", "published_at": "2026-06-16 16:00:00+00:00", "updated_at": "2026-06-16 16:26:08.983088+00:00", "lang": "en", "topics": ["large-language-models", "generative-ai", "ai-infrastructure", "ai-chips", "ai-tools"], "entities": ["NVIDIA", "Hopper", "Blackwell", "Transformer Engine", "CodonFM", "FP8", "NVFP4", "GEMM"], "alternates": {"html": "https://wpnews.pro/news/how-to-optimize-transformer-based-models-for-low-precision-training", "markdown": "https://wpnews.pro/news/how-to-optimize-transformer-based-models-for-low-precision-training.md", "text": "https://wpnews.pro/news/how-to-optimize-transformer-based-models-for-low-precision-training.txt", "jsonld": "https://wpnews.pro/news/how-to-optimize-transformer-based-models-for-low-precision-training.jsonld"}}