metadata
tags:
  - robotics
  - vla
  - rl-token

pi05-so101-armnetbench-tool-insert-rlt-v50

RL Token (RLT) encoder-decoder trained on the SO101 tool-insertion task on top of the frozen step_24999 checkpoint of lorenzouttini/pi05-so101-armnetbench-tool-insert-isambard-v50.

What is this?

This model is a lightweight transformer encoder-decoder that takes its inputs from a frozen Pi-05 VLA backbone. The encoder compresses the VLA's final-layer prefix embeddings into a single RL token via a learned query. The decoder then autoregressively reconstructs the original embeddings from this token alone, which forces the token to act as an information bottleneck. See Xu et al. (2026), "Precise Manipulation with Efficient Online RL", for the method.
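The bottleneck idea can be illustrated with a minimal NumPy sketch. It is deliberately simplified and is not this model's code. It uses one single-head attention layer per side with no MLPs, residuals or normalization, toy sizes (D=16, T=10) and random weights. It also assumes teacher-forced decoding with a mean-squared-error reconstruction loss; the card does not say how the decoder is trained or which loss the Train Loss column reports. All names (`init_params`, `encode`, `decode`, `reconstruction_loss`, `query`, ...) are hypothetical.

```python
import numpy as np

# Toy sizes (assumption); the real model uses 2048-dim embeddings, 8 heads and 2 layers.
D, T = 16, 10


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v, mask=None):
    """Single-head scaled dot-product attention; q: (Tq, D), k and v: (Tk, D)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # blocked positions get ~zero weight
    return softmax(scores) @ v


def init_params(rng, d=D):
    """Random weights for the hypothetical one-layer encoder and decoder."""
    def proj():
        return rng.normal(scale=d ** -0.5, size=(d, d))

    return {
        "query": rng.normal(size=(1, d)),  # learned query that becomes the RL token
        "enc_k": proj(), "enc_v": proj(),
        "dec_q": proj(), "dec_k": proj(), "dec_v": proj(), "dec_out": proj(),
    }


def encode(params, prefix):
    """Compress (T, D) prefix embeddings into a single (1, D) RL token."""
    return attention(params["query"], prefix @ params["enc_k"], prefix @ params["enc_v"])


def decode(params, rl_token, prefix):
    """Teacher-forced autoregressive reconstruction (teacher forcing is an assumption).

    Output row t predicts prefix[t] and, through the causal mask, sees only
    [rl_token, prefix[0], ..., prefix[t-1]].
    """
    inp = np.concatenate([rl_token, prefix[:-1]], axis=0)  # (T, D)
    n = inp.shape[0]
    causal = np.tril(np.ones((n, n), dtype=bool))
    h = attention(inp @ params["dec_q"], inp @ params["dec_k"], inp @ params["dec_v"], causal)
    return h @ params["dec_out"]  # (T, D)


def reconstruction_loss(params, prefix):
    """MSE between reconstruction and the frozen embeddings (loss choice is an assumption)."""
    recon = decode(params, encode(params, prefix), prefix)
    return np.mean((recon - prefix) ** 2)


rng = np.random.default_rng(0)
params = init_params(rng)
prefix = rng.normal(size=(T, D))  # stand-in for the frozen VLA prefix embeddings
print(encode(params, prefix).shape, reconstruction_loss(params, prefix))
```

In this sketch the first output row attends only to the RL token, so `prefix[0]` has to be recovered from the token alone. Later rows also see earlier ground-truth embeddings, which is the usual trade-off of teacher forcing.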

Training

  • Config: pi05_rlt_armnetbench_tool_insert
  • VLA backbone: lorenzouttini/pi05-so101-armnetbench-tool-insert-isambard-v50 step_24999 (frozen, rl_vla_loss_weight=0.0)
  • Encoder-decoder: 2-layer transformer, 8 heads, 8192 MLP dim, 2048 embedding dim
  • Dataset: villekuosmanen/armnetbench_tool_insert
  • Batch size: 32
  • LR: peak 2.5e-5, cosine schedule (1k warmup steps, 20k decay steps)
  • Steps: 20,000
  • Runtime: ~4h14m on 4x GH200 (Isambard)
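The LR line can be read as a warmup-plus-cosine schedule, sketched here in plain Python. The sketch assumes the optax `warmup_cosine_decay_schedule` convention, where the 20k decay steps count from step 0 and include the warmup. It also assumes a linear warmup starting from 0 and a final LR of 0. The card states none of these details, and `warmup_cosine_lr` is a hypothetical helper, not openpi's implementation.

```python
import math


def warmup_cosine_lr(step, peak_lr=2.5e-5, warmup_steps=1_000, decay_steps=20_000, end_lr=0.0):
    """Linear warmup from 0 to peak_lr, then cosine decay to end_lr.

    decay_steps is the total schedule length including warmup (optax-style
    convention, assumed). end_lr=0.0 is an assumption; the card does not give it.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    span = decay_steps - warmup_steps
    progress = min(step - warmup_steps, span) / span  # clamps to 1.0 after decay ends
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 500, 1_000, 10_500, 19_999):
    print(s, warmup_cosine_lr(s))
```

Under these assumptions the LR peaks at 2.5e-5 at step 1,000, halves around step 10,500 and is close to 0 at the final step, 19,999.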

No validation split was used because the dataset is too small to hold out an evaluation set.

Loss progression (train)

Step      Train loss
0         10754.3
1,000     873.7
5,000     552.0
10,000    430.0
15,000    377.6
19,900    356.4

Checkpoints

Step     Recommended    Params SHA256
19999    Yes            d9ddbbbefc07b3700f7df5dd161c14de3291bfcf805f71e39ababb902e1501b2

Verifying checkpoint hashes

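Run this from the repository root. The printed digest should match the Params SHA256 listed under Checkpoints:

cd checkpoints/19999 && find params -type f | sort | xargs sha256sum | sha256sum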
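Where GNU coreutils is unavailable, the shell pipeline can be approximated in Python. This sketch hashes every regular file under `params/` and sorts the relative paths. It then writes one line per file in `sha256sum`'s default format (`<hex digest>  <path>`, two spaces) and hashes the concatenated listing. `checkpoint_digest` and `sha256_file` are hypothetical names. Python sorts by code point, which matches `sort` only under the C locale. In a locale such as en_US.UTF-8, `sort` can order names containing `_` or `.` differently, so this sketch may not reproduce the published digest if that digest was computed under such a locale. File names containing newlines or backslashes, which `sha256sum` escapes, are not handled.

```python
import hashlib
import os


def sha256_file(path, chunk_size=1 << 20):
    """Hex SHA-256 of one file, read in chunks so large weight shards fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def checkpoint_digest(checkpoint_dir, subdir="params"):
    """Approximates: cd checkpoint_dir && find params -type f | sort | xargs sha256sum | sha256sum"""
    rel_paths = []
    for dirpath, _dirnames, filenames in os.walk(os.path.join(checkpoint_dir, subdir)):
        for name in filenames:
            full = os.path.join(dirpath, name)
            if os.path.isfile(full) and not os.path.islink(full):  # like `find -type f`
                rel = os.path.relpath(full, checkpoint_dir).replace(os.sep, "/")
                rel_paths.append(rel)
    rel_paths.sort()  # code-point order; equals `sort` under LC_ALL=C for ASCII names
    listing = "".join(
        f"{sha256_file(os.path.join(checkpoint_dir, p))}  {p}\n" for p in rel_paths
    )
    return hashlib.sha256(listing.encode("utf-8")).hexdigest()
```

For example, calling `checkpoint_digest("checkpoints/19999")` from the repository root returns a value to compare against the Params SHA256 in the Checkpoints table.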

Repo layout

assets/                         # Norm stats, valid indices
checkpoints/19999/params/       # Step 19999 model weights (recommended)

W&B

Training curves: https://wandb.ai/pravsels/pi05_rlt_armnetbench_tool_insert/runs/jbdqrmu0

Usage

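import numpy as np
import openpi.models.model as _model
import openpi.training.config as _config

# Look up the training config used for this checkpoint.
config = _config.get_config("pi05_rlt_armnetbench_tool_insert")
# Restore the step-19999 weights as NumPy arrays (path relative to a local copy of this repo).
params = _model.restore_params("checkpoints/19999/params", restore_type=np.ndarray)
# Build the model from its config and the restored weights.
model = config.model.load(params)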