New architecture TMT: dynamic graph attention + adaptive depth routing, 29.4 PPL at 48% compute (120M params)

#148, opened by vigneshwar234

TemporalMesh Transformer (TMT v3): a new efficient transformer with 5 innovations, 29.4 PPL at 48% compute

Hi everyone! I'm releasing TemporalMesh Transformer (TMT v3), an open-source PyTorch transformer architecture that achieves state-of-the-art efficiency at 120M parameters.

The problem with current transformers

Standard transformers have three core inefficiencies that haven't been addressed together:

  1. Attention is O(S²): quadratic in sequence length
  2. Attention topology is static: fully connected, never adapting to semantic content
  3. Every token passes through all N layers regardless of its complexity

TMT fixes all three simultaneously

Innovation      | What it does                                                         | Cost
Mesh Attention  | Dynamic kNN graph rebuilt per layer from cosine similarity           | O(S·k)
Temporal Decay  | Learned multiplicative attenuation of distant tokens, post-softmax   | ~0 overhead
Adaptive Exit   | Per-token gate: punctuation exits at layer 2, rare words at layer 12 | −52% compute
Dual-Stream FFN | Parallel syntax and semantic streams with sigmoid fusion             | Same FLOPs
EMA Anchors     | 16 persistent fast-weight vectors for cross-sequence recall          | 32 KB of parameters
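The Mesh Attention row can be illustrated with a minimal pure-Python sketch, assuming "mesh attention" means: build a kNN graph over the current layer's hidden states by cosine similarity, then run ordinary scaled dot-product softmax attention restricted to each token's k neighbours. The helper names (`knn_graph`, `mesh_attention`) are hypothetical, not TMT's API. Note that this naive graph construction compares every pair of tokens and is therefore itself O(S²); only the attention step is O(S·k). Reaching O(S·k) end to end would need something like approximate nearest-neighbour search, which is an assumption, since the post doesn't say how TMT builds the graph.

```python
import math
from typing import List

Vec = List[float]

def cosine(u: Vec, v: Vec) -> float:
    """Cosine similarity between two vectors (epsilon avoids division by zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-8)

def knn_graph(h: List[Vec], k: int) -> List[List[int]]:
    """For each token, the indices of its k most cosine-similar tokens (itself included).

    Naive all-pairs version: O(S^2) similarity computations.
    """
    graph = []
    for hi in h:
        ranked = sorted(range(len(h)), key=lambda j: cosine(hi, h[j]), reverse=True)
        graph.append(sorted(ranked[:k]))
    return graph

def mesh_attention(q: List[Vec], key: List[Vec], v: List[Vec],
                   graph: List[List[int]]) -> List[Vec]:
    """Scaled dot-product attention where token i only scores its neighbours graph[i]: O(S*k)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, nbrs in enumerate(graph):
        scores = [scale * sum(a * b for a, b in zip(q[i], key[j])) for j in nbrs]
        m = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        out.append([sum(e / z * v[j][c] for e, j in zip(exps, nbrs)) for c in range(len(v[0]))])
    return out
```

With four 2-d vectors forming two tight clusters and k = 2, each token's neighbourhood is just its own cluster, so attention never mixes the clusters.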
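The Temporal Decay row can be sketched as follows, assuming the attenuation has the form exp(−λ·|i − j|) with a single learned rate λ kept positive by a softplus. Both the functional form and the names (`temporal_decay`, `raw_rate`) are assumptions; the post only says the attenuation is learned, multiplicative and applied after the softmax. Because it is one elementwise multiply per attention weight, its cost is small next to computing the scores, which is consistent with the "~0 overhead" entry.

```python
import math
from typing import List

def softplus(x: float) -> float:
    """Maps an unconstrained learned parameter to a positive decay rate."""
    return math.log1p(math.exp(x))

def temporal_decay(scores: List[float], query_pos: int, key_pos: List[int],
                   raw_rate: float) -> List[float]:
    """Softmax over one query's scores, then multiply each weight by exp(-rate * distance).

    key_pos holds the sequence positions of the keys this query attends to.
    Weights are left unnormalised after the decay (an assumption; the post does not say).
    """
    rate = softplus(raw_rate)
    m = max(scores)  # stable softmax
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [(e / z) * math.exp(-rate * abs(query_pos - p)) for e, p in zip(exps, key_pos)]
```

For equal scores over positions 0, 1, 2 seen from position 2 with `raw_rate=0.0` (so λ = ln 2), the weights come out as 1/12, 1/6 and 1/3: each step of distance halves the weight.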
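The Adaptive Exit row can be sketched like this, assuming each token computes a confidence sigmoid(gate(h)) after every layer and stops being updated once that confidence exceeds a threshold. The `gate`, `threshold` and layer callables are hypothetical stand-ins. For simplicity the layers here are per-token functions; in a real transformer, exited tokens would presumably still be visible as keys and values to tokens that keep going. Compute is counted as layer evaluations actually run divided by S·N, so under this (assumed) accounting the post's "48% compute" would correspond to a value of about 0.48.

```python
import math
from typing import Callable, List, Tuple

Vec = List[float]

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def run_with_early_exit(
    h: List[Vec],
    layers: List[Callable[[Vec], Vec]],
    gate: Callable[[Vec], float],
    threshold: float = 0.5,
) -> Tuple[List[Vec], List[int], float]:
    """Returns final token states, the 1-indexed layer each token exited at, and the compute fraction."""
    n_layers = len(layers)
    states = [list(x) for x in h]
    exit_layer = [n_layers] * len(h)  # tokens that never exit use every layer
    active = [True] * len(h)
    for depth, layer in enumerate(layers, start=1):
        for t in range(len(states)):
            if not active[t]:
                continue  # exited tokens skip all remaining layers
            states[t] = layer(states[t])
            if sigmoid(gate(states[t])) > threshold:
                active[t] = False
                exit_layer[t] = depth
    compute = sum(exit_layer) / (n_layers * len(h))
    return states, exit_layer, compute
```

For instance, with 12 identical layers that add 1.0 to a 1-d state and a gate logit of x − 5, a token starting at 4.0 exits at layer 2 while one starting at −10.0 runs all 12 layers, giving a compute fraction of 14/24 ≈ 0.58.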
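One way a Dual-Stream FFN could keep "Same FLOPs" is to split a standard FFN's hidden width d_ff into two streams of d_ff/2 each and blend their outputs with a sigmoid gate. The split, the per-channel gate and all names below are assumptions, not taken from TMT. Nothing in this sketch makes one stream "syntactic" and the other "semantic"; that would have to come from training or inductive biases the post doesn't describe. A content-dependent gate (sigmoid of a projection of x) is a natural alternative, but it would add a d_model × d_model matmul.

```python
import math
import random
from typing import List

Vec = List[float]
Mat = List[List[float]]

def matvec(W: Mat, x: Vec) -> Vec:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v: Vec) -> Vec:
    return [max(0.0, a) for a in v]

def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))

class DualStreamFFN:
    """Two half-width FFN streams fused per channel: out = g*syn + (1-g)*sem, g = sigmoid(gate_logit)."""

    def __init__(self, d_model: int, d_ff: int, seed: int = 0):
        rng = random.Random(seed)
        half = d_ff // 2  # each stream gets half the hidden width of a standard FFN

        def init(rows: int, cols: int) -> Mat:
            return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]

        self.syn_in, self.syn_out = init(half, d_model), init(d_model, half)
        self.sem_in, self.sem_out = init(half, d_model), init(d_model, half)
        self.gate_logit = [0.0] * d_model  # learned per channel; sigmoid(0) = 0.5 blends equally

    def __call__(self, x: Vec) -> Vec:
        syn = matvec(self.syn_out, relu(matvec(self.syn_in, x)))
        sem = matvec(self.sem_out, relu(matvec(self.sem_in, x)))
        return [sigmoid(g) * a + (1.0 - sigmoid(g)) * b
                for g, a, b in zip(self.gate_logit, syn, sem)]

def standard_ffn_weights(d_model: int, d_ff: int) -> int:
    """Weight count (= multiply-accumulates per token) of a plain d_model -> d_ff -> d_model FFN."""
    return 2 * d_model * d_ff
```

The four half-width matrices together hold exactly 2·d_model·d_ff weights, the same as one full-width FFN, so the matmul cost matches apart from the cheap elementwise fusion.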
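The EMA Anchors row can be sketched as a small bank of vectors that survives across sequences. Tokens read from it with softmax attention, and after each sequence every anchor moves toward the mean of the tokens most similar to it via an exponential moving average. This nearest-anchor EMA update is a swapped-in technique, similar in spirit to EMA codebook updates in vector quantization, and not necessarily how TMT updates its anchors; `EMAAnchors`, `decay` and the one-hot initialisation are hypothetical. The storage arithmetic is at least consistent with the table, assuming the anchors live at d_model = 512 in FP32: 16 × 512 × 4 bytes = 32,768 bytes = 32 KiB.

```python
import math
from typing import List

Vec = List[float]

def dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))

class EMAAnchors:
    """Hypothetical bank of persistent anchor vectors kept across sequences."""

    def __init__(self, n_anchors: int, d_model: int, decay: float = 0.99):
        # One-hot initialisation so anchors start distinct (an arbitrary choice for this sketch).
        self.anchors = [[1.0 if c == m % d_model else 0.0 for c in range(d_model)]
                        for m in range(n_anchors)]
        self.decay = decay

    def read(self, h: Vec) -> Vec:
        """Softmax attention from one token vector over the anchors."""
        scores = [dot(h, a) for a in self.anchors]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        return [sum(e / z * a[c] for e, a in zip(exps, self.anchors)) for c in range(len(h))]

    def update(self, tokens: List[Vec]) -> None:
        """After a sequence, EMA-move each anchor toward the mean of the tokens most similar to it (by dot product)."""
        buckets: List[List[Vec]] = [[] for _ in self.anchors]
        for t in tokens:
            nearest = max(range(len(self.anchors)), key=lambda i: dot(t, self.anchors[i]))
            buckets[nearest].append(t)
        for i, bucket in enumerate(buckets):
            if bucket:  # anchors with no assigned tokens are left unchanged
                mean = [sum(col) / len(bucket) for col in zip(*bucket)]
                self.anchors[i] = [self.decay * a + (1.0 - self.decay) * b
                                   for a, b in zip(self.anchors[i], mean)]

def anchor_bytes(n_anchors: int, d_model: int, bytes_per_param: int = 4) -> int:
    return n_anchors * d_model * bytes_per_param
```

Because `update` only runs between sequences and the bank is never reset, what one sequence writes can be read back by the next, which is the cross-sequence recall the table describes.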

Key results (120M params; seeds 42, 1337 and 2024)

  • WikiText-2: 29.4 PPL (vs 42.1 vanilla, 31.8 Mamba, 33.1 RWKV)
  • WikiText-103: 36.1 PPL (vs 51.3 vanilla, 38.4 Mamba)
  • LongBench: 53.4 avg score (vs 51.3 Mamba, 49.8 Longformer)
  • C4: 27.4 PPL, The Pile: 35.8 PPL, OpenWebText: 30.1 PPL
  • Throughput: 138K tokens/sec on an A100 in FP16
  • Superadditive gain: the full model improves PPL by 12.7, versus 8.6 when the individual components' gains are summed
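The superadditivity claim reduces to simple arithmetic. Assuming the 12.7 refers to WikiText-2 (it matches the gap between the vanilla and TMT numbers above, though the post doesn't say so explicitly) and writing Δ_i for the PPL gain of component i added on its own:

```latex
\Delta_{\text{full}} = 42.1 - 29.4 = 12.7, \qquad
\sum_{i=1}^{5} \Delta_i = 8.6, \qquad
\Delta_{\text{full}} - \sum_{i=1}^{5} \Delta_i = 12.7 - 8.6 = 4.1
```

So the combination contributes about 4.1 PPL beyond what the five components deliver individually.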

Quick start

import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

model = TMTModel(TMTConfig(vocab_size=50257, d_model=512, n_heads=8, n_layers=12))
tokens = torch.randint(0, 50257, (1, 128))  # example token IDs; (batch, seq_len) input shape assumed
out = model(tokens)  # out.logits, out.exit_masks, out.graph_edges, out.confidences

Links

Happy to answer questions on the architecture, training, or ablations!
