
Junyi's Lab

LLM Inference on TPU v6e-4: Small Dense, Large MoE, and Large Dense Models


# Abstract

We benchmark LLM inference on one Google TPU v6e-4 host (four chips in one VM). We use vLLM 0.20.0 with the tpu-inference backend and an fp8 KV cache.

We test three Qwen models:

| Model | Type | Params | Parallelism | Chips |
|---|---|---|---|---|
| Qwen3.5-4B | dense | 4B active | tp=1 | 1 |
| Qwen3-30B-A3B | MoE | 30B total / 3B active | tp=4 | 4 |
| Qwen3-32B | dense | 32B active | tp=4 | 4 |

We measure three parts of inference: prefill, decode, and end-to-end online serving.

The MoE model is fastest in all three tests. It reaches 26,063 tok/s in prefill, compared with 19,714 tok/s for the dense 32B model. At batch size 1, its decode latency is 7.0 ms per token, compared with 17.9 ms for the dense 32B model. In online serving, it reaches 1.27 req/s and 1,303 output tok/s, compared with 0.88 req/s and 901 output tok/s for the dense 32B model.

The 4B model has good latency at low load, but it does not handle high concurrency well on one chip. Its stable serving capacity is about 0.40 req/s, reached at an offered rate of 0.45 req/s.

This is not a TPU-versus-GPU comparison. We did not run a GPU baseline. The goal is to report clear TPU v6e numbers using the same vLLM tools that users would run in practice.

# 1. Motivation

Public TPU data for LLM inference is limited. The only TPU submission Google has made to MLPerf Inference is SDXL image generation, not an LLM, and published TPU LLM numbers are measured on Google's JAX/JetStream stack rather than on the vLLM path that most practitioners deploy.

We measured the performance ourselves on one TPU v6e-4 host. The goal is to give simple, reproducible numbers for three common model types:

  1. a small dense model,
  2. a large MoE model,
  3. a large dense model.

All comparisons are between these three models on the same TPU host. We do not compare against GPUs, prices, or total cost.

# 2. Experimental Setup

All measurements use one TPU v6e-4 host with four chips in one VM. The 4B model uses one chip. The 30B-A3B and 32B models use all four chips. For per-chip numbers, we divide the four-chip throughput by four.

We use vLLM 0.20.0, JAX 0.10.0, libtpu 0.0.40, and Python 3.11, with the tpu-inference backend. The KV cache uses fp8_e5m2 in all runs. We report exact versions because TPU performance can change quickly across releases.

| Model | Type | Total / active params | Parallelism | Chips | Config |
|---|---|---|---|---|---|
| Qwen3.5-4B | dense | 4B / 4B | tp=1 | 1 | GDN, MBT=2048 |
| Qwen3-30B-A3B | MoE | 30B / 3B | tp=4 | 4 | GMU=0.8, MBT=8192 |
| Qwen3-32B | dense | 32B / 32B | tp=4 | 4 | GMU=0.8, MBT=8192 |

Table 1. Model configurations used in all experiments. GDN = guided decoding off, GMU = GPU memory utilization, MBT = max-num-batched-tokens.
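The Table 1 settings map onto standard `vllm serve` options. The sketch below assembles one command line per model. It assumes the Hugging Face repo IDs `Qwen/Qwen3-30B-A3B` and `Qwen/Qwen3-32B`, uses a hypothetical ID for the 4B model, and leaves out the GDN setting and any tpu-inference backend selection, because the post does not show how those were configured.

```python
import shlex
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ServeConfig:
    model: str                             # Hugging Face repo ID (4B ID is hypothetical)
    tp: int                                # tensor-parallel size (= chips used)
    max_batched_tokens: int                # MBT in Table 1
    gpu_mem_util: Optional[float] = None   # GMU in Table 1; None keeps vLLM's default


CONFIGS = [
    ServeConfig("Qwen/Qwen3.5-4B", tp=1, max_batched_tokens=2048),  # hypothetical repo ID
    ServeConfig("Qwen/Qwen3-30B-A3B", tp=4, max_batched_tokens=8192, gpu_mem_util=0.8),
    ServeConfig("Qwen/Qwen3-32B", tp=4, max_batched_tokens=8192, gpu_mem_util=0.8),
]


def serve_args(cfg: ServeConfig) -> list[str]:
    """Build a `vllm serve` argument list for one Table 1 row."""
    args = [
        "vllm", "serve", cfg.model,
        "--tensor-parallel-size", str(cfg.tp),
        "--max-num-batched-tokens", str(cfg.max_batched_tokens),
        "--kv-cache-dtype", "fp8_e5m2",  # fp8 KV cache used in every run
    ]
    if cfg.gpu_mem_util is not None:
        args += ["--gpu-memory-utilization", str(cfg.gpu_mem_util)]
    return args


if __name__ == "__main__":
    for cfg in CONFIGS:
        print(shlex.join(serve_args(cfg)))
```

Building argument lists instead of shell strings keeps quoting safe; `shlex.join` is used only for display.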

The key comparison is between Qwen3-30B-A3B and Qwen3-32B. They have similar total size, but the MoE model only uses 3B parameters per token, while the dense 32B model uses all 32B.

We run three tests.

For prefill, we use concurrency 1 and output length 1. We vary context length over 512, 1024, 2048, 4096, and 8192 tokens. This measures how fast the model processes the input prompt.

For decode, we fix the output length at 128 tokens and vary batch size over 1, 4, 16, and 64. We mainly report context length 1024, with context length 4096 as a longer-KV case. This measures the cost of generating tokens after prefill.

For online serving, we use vllm bench serve against the OpenAI-compatible endpoint. Input and output lengths are both 1024 tokens. We increase the request rate until the server saturates; the last point in each sweep uses an unlimited request rate (∞).
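A request-rate sweep like this can be scripted by generating one `vllm bench serve` invocation per rate. This sketch assumes vLLM's random dataset with fixed 1024/1024 lengths, `--ignore-eos` so every request produces the full 1024 output tokens, and a hypothetical prompt count of 200. The post does not say which dataset, prompt count, or EOS handling was used, and the server is assumed to be running on the default host and port.

```python
import math
import shlex

RATES = [0.2, 0.4, 0.6, 0.8, math.inf]  # Table 6 rates for Qwen3-30B-A3B; inf = unlimited


def bench_args(model: str, rate: float, num_prompts: int = 200) -> list[str]:
    """Build one `vllm bench serve` call; num_prompts=200 is a hypothetical default."""
    return [
        "vllm", "bench", "serve",
        "--model", model,
        "--dataset-name", "random",
        "--random-input-len", "1024",
        "--random-output-len", "1024",
        "--ignore-eos",                       # force the full 1024 output tokens
        "--num-prompts", str(num_prompts),
        "--request-rate", "inf" if math.isinf(rate) else str(rate),
        "--percentile-metrics", "ttft,tpot,e2el",
        "--metric-percentiles", "50,99",      # p50 and p99, as in Tables 6-8
    ]


if __name__ == "__main__":
    for r in RATES:
        print(shlex.join(bench_args("Qwen/Qwen3-30B-A3B", r)))
```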

Metrics:

  • TTFT: time to first token. This mainly measures prefill.
  • TPOT: time per output token. This mainly measures decode.
  • E2E latency: full request time.
  • Prefill throughput: context length divided by p50 TTFT.
  • Decode throughput: batch size × 1000 divided by p50 TPOT.
  • Output throughput: generated tokens per second.
  • Request throughput: completed requests per second.

We report p50 unless stated otherwise.
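The two throughput formulas can be checked directly against the tables. This sketch takes TTFT and TPOT in milliseconds, as reported, and uses p50 values from Tables 3 and 4.

```python
def prefill_throughput(context_len: int, ttft_p50_ms: float) -> float:
    """Prompt tokens per second: context length / p50 TTFT (converted to seconds)."""
    return context_len / (ttft_p50_ms / 1000.0)


def decode_throughput(batch_size: int, tpot_p50_ms: float) -> float:
    """Aggregate generated tokens per second: batch size x 1000 / p50 TPOT (ms)."""
    return batch_size * 1000.0 / tpot_p50_ms


# Qwen3-30B-A3B, ctx 4096, p50 TTFT 157 ms (Table 3)
print(round(prefill_throughput(4096, 157)))  # 26089; Table 2's 26,063 likely uses the unrounded TTFT
# Qwen3-30B-A3B, ctx 1024, BS 64, p50 TPOT 48.1 ms (Table 4)
print(round(decode_throughput(64, 48.1)))    # 1331
```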

# 3. Results

## 3.1 Prefill

Figure 1. Prefill throughput versus context length on TPU v6e-4. Higher is better. The 30B-A3B and 32B results are total throughput across four chips. The 4B result uses one chip.

The MoE model has the highest prefill throughput.

| Model | Best prefill throughput | Per-chip peak |
|---|---|---|
| Qwen3-30B-A3B | 26,063 tok/s | ~6.5k tok/s |
| Qwen3-32B | 19,714 tok/s | ~4.9k tok/s |
| Qwen3.5-4B | 4,177 tok/s | ~4.2k tok/s |

Table 2. Peak prefill throughput and per-chip peak for each model.

| Model | ctx 512 | ctx 1024 | ctx 2048 | ctx 4096 | ctx 8192 |
|---|---|---|---|---|---|
| Qwen3-30B-A3B | 35 ms | 51 ms | 84 ms | 157 ms | 318 ms |
| Qwen3-32B | 38 ms | 58 ms | 104 ms | 208 ms | 461 ms |
| Qwen3.5-4B | 432 ms | 452 ms | 495 ms | 983 ms | 1,961 ms |

Table 3. p50 TTFT at each context length (ms).

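The 4B result is limited by its single-chip setup and by max-num-batched-tokens=2048. Up to 2048 tokens, its TTFT stays between 432 ms and 495 ms, so throughput rises with context length. Above 2048, prefill is split into 2048-token chunks and TTFT roughly doubles with each doubling of context (983 ms at 4096, 1,961 ms at 8192), so the curve flattens at about 4.2k tok/s.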
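A simple chunking model reproduces the 4B numbers above 2048 tokens. This sketch assumes each 2048-token chunk costs about the measured ctx-2048 TTFT (495 ms) and that chunks run one after another. It ignores that later chunks attend to a longer KV prefix, so it is a rough approximation, not the scheduler's actual cost model.

```python
import math

MBT = 2048              # max-num-batched-tokens for Qwen3.5-4B (Table 1)
CHUNK_TTFT_MS = 495.0   # measured p50 TTFT at ctx 2048 (Table 3), used as per-chunk cost


def num_chunks(context_len: int, mbt: int = MBT) -> int:
    """Number of prefill chunks when a prompt is split at `mbt` tokens."""
    return math.ceil(context_len / mbt)


def estimated_ttft_ms(context_len: int) -> float:
    """Rough TTFT: one chunk cost per chunk, run sequentially (assumption)."""
    return num_chunks(context_len) * CHUNK_TTFT_MS


for ctx, measured in [(4096, 983), (8192, 1961)]:
    print(ctx, num_chunks(ctx), estimated_ttft_ms(ctx), measured)
# 4096 2 990.0 983
# 8192 4 1980.0 1961
```

Both estimates land within about 1% of the measured TTFT, which is consistent with the chunking explanation but does not prove it.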

## 3.2 Decode

Figure 2. Decode p50 TPOT versus batch size. Solid lines: context 1024. Dashed lines: context 4096. Lower is better.

Figure 3. Aggregate decode throughput versus batch size. Solid lines: context 1024. Dashed lines: context 4096. Higher is better.

| Model | BS | p50 TPOT | p99 TPOT | Throughput |
|---|---|---|---|---|
| Qwen3-30B-A3B | 1 | 7.0 ms | 7.1 ms | 143 tok/s |
| Qwen3-30B-A3B | 4 | 8.8 ms | 9.9 ms | 456 tok/s |
| Qwen3-30B-A3B | 16 | 14.9 ms | 17.5 ms | 1,071 tok/s |
| Qwen3-30B-A3B | 64 | 48.1 ms | 152.7 ms | 1,331 tok/s |
| Qwen3-32B | 1 | 17.9 ms | 17.9 ms | 56 tok/s |
| Qwen3-32B | 4 | 18.2 ms | 19.4 ms | 220 tok/s |
| Qwen3-32B | 16 | 22.4 ms | 27.4 ms | 715 tok/s |
| Qwen3-32B | 64 | 58.4 ms | 65.3 ms | 1,096 tok/s |
| Qwen3.5-4B | 1 | 10.7 ms | 10.7 ms | 93 tok/s |
| Qwen3.5-4B | 4 | 16.2 ms | 18.3 ms | 247 tok/s |
| Qwen3.5-4B | 16 | 41.4 ms | 647.3 ms | 387 tok/s |
| Qwen3.5-4B | 64 | 80.5 ms | 97.2 ms | 795 tok/s |

Table 4. Decode TPOT at context length 1024 (p50, p99, and aggregate throughput).

| Model | BS | p50 TPOT | p99 TPOT | Throughput |
|---|---|---|---|---|
| Qwen3-30B-A3B | 1 | 7.0 ms | 7.1 ms | 143 tok/s |
| Qwen3-30B-A3B | 4 | 9.7 ms | 12.9 ms | 411 tok/s |
| Qwen3-30B-A3B | 16 | 22.3 ms | 28.5 ms | 719 tok/s |
| Qwen3-30B-A3B | 64 | 96.8 ms | 213.9 ms | 661 tok/s |
| Qwen3-32B | 1 | 19.1 ms | 19.3 ms | 53 tok/s |
| Qwen3-32B | 4 | 21.4 ms | 25.2 ms | 187 tok/s |
| Qwen3-32B | 16 | 37.7 ms | 130.6 ms | 424 tok/s |
| Qwen3-32B | 64 | 82.5 ms | 92.8 ms | 775 tok/s |
| Qwen3.5-4B | 1 | 10.8 ms | 10.8 ms | 93 tok/s |
| Qwen3.5-4B | 4 | 25.3 ms | 34.1 ms | 158 tok/s |
| Qwen3.5-4B | 16 | 85.0 ms | 89.2 ms | 188 tok/s |
| Qwen3.5-4B | 64 | 85.0 ms | 89.2 ms | 753 tok/s |

Table 5. Decode TPOT at context length 4096 (p50, p99, and aggregate throughput).

Figure 4. Decode tail latency: p99/p50 TPOT ratio at each batch size, at context 1024 and 4096. A ratio of 1.0 means p99 equals p50 (no tail). Higher bars indicate worse scheduling irregularity.

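At batch size 1, the MoE model is about 2.5× faster than the dense 32B model: 7.0 ms versus 17.9 ms at context 1024. The ordering follows active parameter count: single-stream decode is memory-bandwidth-bound, and the MoE model reads about 3B active parameters per token while the dense 32B model reads all 32B. The gap is much smaller than the roughly 10× ratio in active parameters, however, so weight reads are not the only per-token cost.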
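A back-of-envelope weight-read floor puts the batch-size-1 numbers in context. This sketch assumes bf16 weights (2 bytes per parameter), about 1,640 GB/s of HBM bandwidth per v6e chip, perfect sharding across the four chips, and rounded active sizes of 3B and 32B. It ignores KV-cache reads, activations, and inter-chip communication. These are our assumptions, not measured values.

```python
HBM_BYTES_PER_S_PER_CHIP = 1640e9   # assumed v6e HBM bandwidth (~1,640 GB/s)
BYTES_PER_PARAM = 2                 # assumed bf16 weights
CHIPS = 4                           # tp=4 for both large models


def weight_read_floor_ms(active_params: float, chips: int = CHIPS) -> float:
    """Lower bound on BS1 decode step time if reading weights were the only cost."""
    bytes_per_token = active_params * BYTES_PER_PARAM
    return bytes_per_token / (HBM_BYTES_PER_S_PER_CHIP * chips) * 1000.0


moe_floor = weight_read_floor_ms(3e9)     # ~0.91 ms vs 7.0 ms measured
dense_floor = weight_read_floor_ms(32e9)  # ~9.76 ms vs 17.9 ms measured
print(f"MoE floor {moe_floor:.2f} ms, dense floor {dense_floor:.2f} ms")
print(f"floor ratio {dense_floor / moe_floor:.1f}x, measured ratio {17.9 / 7.0:.1f}x")
```

Under these assumptions, the dense model's 17.9 ms is less than 2× its weight-read floor, while the MoE model's 7.0 ms is more than 7× its floor. That suggests fixed per-step costs matter much more for the MoE model at batch size 1.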

The p99/p50 ratio shows scheduling pressure at high batch sizes. For the 30B-A3B at BS64 and ctx 1024, p99 TPOT is 152.7 ms against a p50 of 48.1 ms (3.2×). For the 4B at BS16 and ctx 1024, p99 TPOT is 647.3 ms against a p50 of 41.4 ms (15.6×). At this operating point the single-chip KV cache is likely under pressure, so requests queue unevenly.
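The Figure 4 metric is p99 TPOT divided by p50 TPOT. This sketch recomputes it for the ctx-1024 rows of Table 4 and flags rows above a hypothetical 2× threshold.

```python
# (model, batch size) -> (p50 TPOT ms, p99 TPOT ms) at ctx 1024, from Table 4
TPOT_CTX1024 = {
    ("Qwen3-30B-A3B", 1): (7.0, 7.1),
    ("Qwen3-30B-A3B", 4): (8.8, 9.9),
    ("Qwen3-30B-A3B", 16): (14.9, 17.5),
    ("Qwen3-30B-A3B", 64): (48.1, 152.7),
    ("Qwen3-32B", 1): (17.9, 17.9),
    ("Qwen3-32B", 4): (18.2, 19.4),
    ("Qwen3-32B", 16): (22.4, 27.4),
    ("Qwen3-32B", 64): (58.4, 65.3),
    ("Qwen3.5-4B", 1): (10.7, 10.7),
    ("Qwen3.5-4B", 4): (16.2, 18.3),
    ("Qwen3.5-4B", 16): (41.4, 647.3),
    ("Qwen3.5-4B", 64): (80.5, 97.2),
}

TAIL_THRESHOLD = 2.0  # hypothetical cut-off for a "heavy" tail


def tail_ratios(table):
    """Return {(model, bs): p99 / p50} for every row."""
    return {key: p99 / p50 for key, (p50, p99) in table.items()}


heavy = {k: round(r, 1) for k, r in tail_ratios(TPOT_CTX1024).items() if r > TAIL_THRESHOLD}
print(heavy)
# {('Qwen3-30B-A3B', 64): 3.2, ('Qwen3.5-4B', 16): 15.6}
```

Every other ctx-1024 row stays below 1.25×, so the two flagged points stand out clearly.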

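At context 4096, longer KV sequences reduce throughput at batch size 4 and above for all three models. The 30B-A3B at BS64 drops from 1,331 tok/s (ctx 1024) to 661 tok/s (ctx 4096), about 50%. The dense 32B drops less in relative terms (1,096 to 775 tok/s, about 29%), likely because compute already dominates its decode step at ctx 1024, so the extra KV-cache reads add a smaller fraction. For the 4B model, p50 TPOT jumps from 25.3 ms at BS4 to 85.0 ms at BS16 and then stays at 85.0 ms at BS64 (p99 is 89.2 ms at both). Because the per-token time does not grow past BS16, aggregate throughput still rises about 4× from BS16 to BS64 (188 to 753 tok/s).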
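The relative drops at BS64 come straight from Tables 4 and 5. This sketch computes the fraction of aggregate decode throughput lost when moving from ctx 1024 to ctx 4096.

```python
# Aggregate decode throughput at BS64 (tok/s): (ctx 1024, ctx 4096), from Tables 4 and 5
BS64_THROUGHPUT = {
    "Qwen3-30B-A3B": (1331, 661),
    "Qwen3-32B": (1096, 775),
    "Qwen3.5-4B": (795, 753),
}


def relative_drop(short_ctx: float, long_ctx: float) -> float:
    """Fraction of throughput lost when moving to the longer context."""
    return 1.0 - long_ctx / short_ctx


for model, (t1024, t4096) in BS64_THROUGHPUT.items():
    print(f"{model}: {relative_drop(t1024, t4096):.0%} lower at ctx 4096")
# Qwen3-30B-A3B: 50% lower at ctx 4096
# Qwen3-32B: 29% lower at ctx 4096
# Qwen3.5-4B: 5% lower at ctx 4096
```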

## 3.3 Online serving

Figure 5. Online serving throughput and latency. The x-axis is output throughput. The y-axis is p50 TPOT. Lower-right is better.

In online serving, a higher request rate gives higher throughput but also higher latency. Each model reaches a point where adding more requests no longer helps.

Figure 6. End-to-end latency sweep. Left: p50 TTFT versus output throughput. Right: p50 TPOT versus output throughput. The × marker on the 4B line marks the concurrency-collapse point under unlimited load.
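
| Req rate (req/s) | Actual req/s | Output tok/s | TTFT p50 | TTFT p99 | TPOT p50 | TPOT p99 | E2E p50 |
|---|---|---|---|---|---|---|---|
| 0.2 | 0.19 | 197 | 90 ms | 2,682 ms | 7.9 ms | 9.0 ms | 8.2 s |
| 0.4 | 0.36 | 366 | 92 ms | 190 ms | 9.2 ms | 10.1 ms | 9.5 s |
| 0.6 | 0.52 | 537 | 94 ms | 1,331 ms | 11.2 ms | 13.8 ms | 11.6 s |
| 0.8 | 0.62 | 635 | 116 ms | 6,335 ms | 18.8 ms | 28.7 ms | 24.6 s |
| ∞ | 1.27 | 1,303 | 712 ms | 1,092 ms | 22.3 ms | 22.8 ms | 23.6 s |

Table 6. Qwen3-30B-A3B online serving results (ISL = OSL = 1024 tokens; ∞ = unlimited request rate).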
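
| Req rate (req/s) | Actual req/s | Output tok/s | TTFT p50 | TTFT p99 | TPOT p50 | TPOT p99 | E2E p50 |
|---|---|---|---|---|---|---|---|
| 0.15 | 0.14 | 144 | 115 ms | 2,243 ms | 18.0 ms | 18.3 ms | 18.5 s |
| 0.3 | 0.26 | 271 | 115 ms | 1,042 ms | 18.9 ms | 20.1 ms | 19.5 s |
| 0.45 | 0.37 | 380 | 116 ms | 295 ms | 19.5 ms | 20.9 ms | 20.2 s |
| 0.6 | 0.45 | 466 | 123 ms | 3,983 ms | 21.0 ms | 28.1 ms | 21.7 s |
| ∞ | 0.88 | 901 | 1,190 ms | 9,249 ms | 43.2 ms | 44.0 ms | 45.4 s |

Table 7. Qwen3-32B online serving results (ISL = OSL = 1024 tokens; ∞ = unlimited request rate).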
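
| Req rate (req/s) | Actual req/s | Output tok/s | TTFT p50 | TTFT p99 | TPOT p50 | TPOT p99 | E2E p50 |
|---|---|---|---|---|---|---|---|
| 0.2 | 0.19 | 194 | 509 ms | 33,147 ms | 12.3 ms | 15.4 ms | 13.1 s |
| 0.3 | 0.28 | 283 | 507 ms | 1,185 ms | 12.8 ms | 14.3 ms | 13.7 s |
| 0.4 | 0.36 | 367 | 506 ms | 1,099 ms | 13.4 ms | 15.7 ms | 14.4 s |
| 0.45 | 0.40 | 407 | 507 ms | 1,141 ms | 13.8 ms | 15.9 ms | 14.6 s |
| ∞ | 0.14 | 143 | 6,762 ms | 219,855 ms | 133.8 ms | 237.2 ms | 206.8 s |

Table 8. Qwen3.5-4B online serving results (ISL = OSL = 1024 tokens; ∞ = unlimited request rate).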

The MoE model (30B-A3B) has a stable working zone up to about 0.6 req/s, where p50 TPOT stays below 12 ms and p50 TTFT stays below 100 ms. In this zone, p99 TTFT is noisy, ranging from 190 ms to 2.7 seconds. At 0.8 req/s the server is already queuing: it completes only 0.62 req/s, p99 TTFT jumps to 6.3 seconds, and p50 TPOT rises from 11 ms to 19 ms. At the unlimited rate (∞), p99 and p50 TPOT are close (22.8 ms versus 22.3 ms), which suggests that the server stays consistently loaded.

The dense 32B model’s TPOT p50 is nearly flat from 0.15 to 0.6 req/s (18–21 ms), then doubles at saturation (43.2 ms). Its TTFT p99 is noisy across the sweep, suggesting occasional prefill bursts even at low load.

The 4B model holds stable from 0.2 to 0.45 req/s, with p50 TPOT between 12.3 ms and 13.8 ms (apart from a 33-second p99 TTFT outlier at 0.2 req/s). At 0.45 req/s it sustains 0.40 req/s and 407 output tok/s, with a p50 TPOT of 13.8 ms and a p50 E2E latency of 14.6 seconds. Under unlimited load it collapses: output throughput drops from 407 tok/s to 143 tok/s, p50 TPOT rises from 13.8 ms to 133.8 ms, and p50 E2E latency reaches 206.8 seconds. The p99 TTFT at unlimited load is 219,855 ms (about 220 seconds), which is a broken state rather than a degraded one. We treat the unlimited-load point as a single-chip concurrency limit, not as useful serving capacity.
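One way to make "stable operating point" concrete is a latency SLO: the highest swept request rate whose p50 TPOT stays under a target. The 15 ms target below is a hypothetical choice, not the criterion used in this post. With it, the selected rates match the stable points named above for Qwen3-30B-A3B (0.6 req/s) and Qwen3.5-4B (0.45 req/s).

```python
import math
from typing import Optional

# (offered rate req/s, p50 TPOT ms) from Tables 6-8; math.inf = unlimited rate
SWEEPS = {
    "Qwen3-30B-A3B": [(0.2, 7.9), (0.4, 9.2), (0.6, 11.2), (0.8, 18.8), (math.inf, 22.3)],
    "Qwen3-32B": [(0.15, 18.0), (0.3, 18.9), (0.45, 19.5), (0.6, 21.0), (math.inf, 43.2)],
    "Qwen3.5-4B": [(0.2, 12.3), (0.3, 12.8), (0.4, 13.4), (0.45, 13.8), (math.inf, 133.8)],
}


def max_rate_under_slo(points, tpot_slo_ms: float) -> Optional[float]:
    """Highest finite offered rate whose p50 TPOT meets the SLO (hypothetical criterion)."""
    ok = [rate for rate, tpot in points if math.isfinite(rate) and tpot <= tpot_slo_ms]
    return max(ok) if ok else None


for model, points in SWEEPS.items():
    print(model, max_rate_under_slo(points, tpot_slo_ms=15.0))
# Qwen3-30B-A3B 0.6
# Qwen3-32B None
# Qwen3.5-4B 0.45
```

The dense 32B model never meets a 15 ms target in this sweep; a looser SLO such as 25 ms would select 0.6 req/s for it.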

| Model | p50 TTFT | p50 E2E latency |
|---|---|---|
| Qwen3-30B-A3B | 712 ms | 23.6 s |
| Qwen3-32B | 1,190 ms | 45.4 s |
| Qwen3.5-4B | 507 ms | 14.6 s |

Table 9. p50 TTFT and p50 E2E latency at the unlimited-rate point for Qwen3-30B-A3B and Qwen3-32B, and at the 0.45 req/s stable point for Qwen3.5-4B.

Figure 7. MoE-versus-dense efficiency summary relative to the dense Qwen3-32B baseline. Higher is better.

Across all three tests, the MoE model beats the dense 32B model at similar total parameter count:

| Metric | Qwen3-30B-A3B | Qwen3-32B |
|---|---|---|
| Best prefill throughput | 26,063 tok/s | 19,714 tok/s |
| BS1 decode TPOT | 7.0 ms | 17.9 ms |
| Serving capacity | 1.27 req/s | 0.88 req/s |
| Output throughput at saturation | 1,303 tok/s | 901 tok/s |

Table 10. Summary comparison of Qwen3-30B-A3B versus Qwen3-32B across all three tests.

The MoE model gives about 2.5× lower single-stream decode latency, about 1.4× higher serving capacity, and about 1.3× higher prefill throughput.
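Each ratio above is a division of Table 10 values, oriented so that a value above 1 favors the MoE model. This sketch recomputes them before rounding.

```python
# Table 10 values: (Qwen3-30B-A3B, Qwen3-32B)
SUMMARY = {
    "prefill tok/s": (26063, 19714),   # higher is better
    "serving req/s": (1.27, 0.88),     # higher is better
    "output tok/s": (1303, 901),       # higher is better
    "BS1 TPOT ms": (7.0, 17.9),        # lower is better
}
LOWER_IS_BETTER = {"BS1 TPOT ms"}


def moe_advantage(metric: str) -> float:
    """How many times better the MoE model is on `metric` (above 1 means MoE wins)."""
    moe, dense = SUMMARY[metric]
    return dense / moe if metric in LOWER_IS_BETTER else moe / dense


for metric in SUMMARY:
    print(f"{metric}: {moe_advantage(metric):.2f}x")
# prefill tok/s: 1.32x
# serving req/s: 1.44x
# output tok/s: 1.45x
# BS1 TPOT ms: 2.56x
```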

# 4. Discussion

Three points stand out.

First, active parameter count matters a lot for decode. The 30B-A3B MoE and the 32B dense model have similar total size, but very different active size. The MoE model uses 3B parameters per token. The dense model uses 32B. This explains the large decode gap: 7.0 ms versus 17.9 ms at batch size 1.

Second, the MoE model also wins in prefill. Its peak prefill throughput is 26,063 tok/s, or about 6.5k tok/s per chip. The dense 32B model reaches 19,714 tok/s, or about 4.9k tok/s per chip. On this host, the MoE model is faster in both prefill and decode.

Third, the small 4B model is limited by single-chip concurrency. It has good low-load latency, but one chip cannot absorb unlimited requests. Its practical serving capacity is about 0.40 req/s, reached at an offered rate of 0.45 req/s. To serve more traffic, the better path is to add chips or replicas rather than push more concurrency onto one chip.
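If replicas scaled linearly, which this study did not measure, sizing a replicated single-chip deployment would be simple arithmetic. This sketch uses the 0.40 req/s that the 4B model completed at its stable point in Table 8; the `replicas_needed` helper and the 1.0 req/s target are hypothetical.

```python
import math

PER_REPLICA_REQ_S = 0.40  # completed req/s for Qwen3.5-4B at 0.45 req/s offered (Table 8)


def replicas_needed(target_req_s: float, per_replica: float = PER_REPLICA_REQ_S) -> int:
    """Single-chip replicas for a target rate, assuming linear scaling (not measured here)."""
    if target_req_s <= 0:
        return 0
    return math.ceil(target_req_s / per_replica)


print(replicas_needed(1.0))  # hypothetical 1.0 req/s target -> 3
```

Real deployments also need load-balancer headroom and a check that per-replica latency holds, so treat this as a lower bound on replica count.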

# 5. Conclusion

On one TPU v6e-4 host, using vLLM 0.20.0, tpu-inference, and an fp8_e5m2 KV cache, the large MoE model is the fastest of the three models we tested.

Compared with the dense 32B model, Qwen3-30B-A3B has:

  • higher prefill throughput: 26,063 vs 19,714 tok/s,
  • lower batch-size-1 decode latency: 7.0 vs 17.9 ms,
  • higher serving capacity: 1.27 vs 0.88 req/s,
  • higher output throughput at saturation: 1,303 vs 901 tok/s.

The 4B dense model has good latency at low load, but on one chip it is concurrency-limited. Its stable serving capacity is about 0.40 req/s, reached at an offered rate of 0.45 req/s.

# Limitations

These results are for one host only: four TPU v6e chips in one VM. We did not test multi-host scaling.

We did not run a GPU baseline, so this study makes no TPU-versus-GPU claim.

The online serving test uses one input/output shape: 1024 input tokens and 1024 output tokens. Other workloads may shift the balance between prefill and decode.

The results are version-specific. TPU support in vLLM changes quickly, so the numbers here apply to vLLM 0.20.0.

All runs use an fp8_e5m2 KV cache. We did not test other KV cache formats.

# Reproducibility

The main numbers are in ./data.json.

Figures are generated by ./figures/plot.py

Per-model results are at ./rawdata.zip