Cosmos3-Super-Text2Image NVIDIA ModelOpt FP8 Transformer

This repository contains a transformer-only NVIDIA ModelOpt FP8 quantization for nvidia/Cosmos3-Super-Text2Image.

This card does not repeat the original model card. For NVIDIA's model card, prompt-format guidance, license, and safety notes, see nvidia/Cosmos3-Super-Text2Image.

Only transformer/ is provided as a weight artifact. The VAE, scheduler, tokenizers, safety checker, and other components are loaded from the base model.

Recipe

| Setting | Value |
|---|---|
| Quantizer | NVIDIA ModelOpt |
| ModelOpt version | 0.44.0 |
| Quant type | FP8_DEFAULT_CFG |
| Weight-only | True |
| Compressed | True |
| Quantized modules inserted | 2709 |
| Quantization time | 1.34s |
| Compress time | 0.45s |
| Save time | 65.99s |
| Transformer checkpoint size | 61.06 GiB |

The checkpoint includes ModelOpt state in transformer/modelopt_state.pth.
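As we understand ModelOpt's FP8_DEFAULT_CFG, it uses the FP8 E4M3 format with a single per-tensor scale obtained by max calibration; with weight-only quantization and compression, each quantized weight is stored as FP8 values plus that scale. The pure-Python sketch below illustrates only that number format and the role of the scale on a plain list of floats. It is not ModelOpt's implementation, and the function names (round_to_e4m3, quantize_per_tensor, dequantize) are hypothetical.

```python
import math

E4M3_MAX = 448.0  # largest finite value of FP8 E4M3 (the "fn" variant without infinities)


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest FP8 E4M3 value, saturating at +/-448 (ties to even)."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)      # saturate instead of overflowing
    _, exp = math.frexp(a)         # a = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)           # binade exponent; below 2**-6 values are subnormal
    step = 2.0 ** (e - 3)          # 3 mantissa bits -> 8 representable steps per binade
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_per_tensor(weights: list[float]) -> tuple[list[float], float]:
    """Max calibration: one scale maps the tensor's absolute max onto E4M3_MAX."""
    amax = max((abs(w) for w in weights), default=0.0)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(w / scale) for w in weights], scale


def dequantize(q: list[float], scale: float) -> list[float]:
    """Recover approximate weights: FP8 value times the per-tensor scale."""
    return [v * scale for v in q]


weights = [0.5, -1.0, 0.25, 0.1]
q, scale = quantize_per_tensor(weights)
print(q, scale, dequantize(q, scale))
```

ModelOpt performs the equivalent operation on GPU tensors; the list-based version is only meant to show why FP8 storage needs one byte per weight plus a small scale.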

Assemble the Pipeline

Install ModelOpt in the same environment as Diffusers:

pip install "nvidia_modelopt[hf]"

With the tested runtime, restoring ModelOpt QTensorWrapper parameters through Diffusers and Accelerate requires a small compatibility helper, included in the example below. Important: load the quantized transformer without passing torch_dtype; otherwise Diffusers casts the FP8 tensors back to BF16 during state-dict loading.
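The compatibility helper in the example follows a common monkey-patching pattern: wrap the library function, tag the wrapper so a second call is a no-op, and rebind the wrapper in every module that holds its own reference to the function. The example rebinds set_module_tensor_to_device in both accelerate.utils.modeling and diffusers.models.model_loading_utils, which indicates that Diffusers keeps a separate reference obtained at import time. The self-contained sketch below reproduces only that pattern with hypothetical stand-in modules (lib, consumer) and a hypothetical set_tensor function; it does not touch Accelerate, Diffusers, or ModelOpt.

```python
import types

# Hypothetical stand-in for a library module (plays the role of accelerate.utils.modeling).
lib = types.ModuleType("lib")


def _original_set_tensor(name, value):
    return f"original:{name}"


lib.set_tensor = _original_set_tensor

# Hypothetical stand-in for a module that ran `from lib import set_tensor` at import time
# (plays the role of diffusers.models.model_loading_utils). It keeps its own reference,
# so rebinding lib.set_tensor alone would not change what this module calls.
consumer = types.ModuleType("consumer")
consumer.set_tensor = lib.set_tensor


def consumer_load(name):
    # Looks the function up on the consumer module, as a module-level global would be.
    return consumer.set_tensor(name, None)


def install_patch():
    original = lib.set_tensor
    # The marker attribute makes the patch idempotent: a second call must not wrap the wrapper.
    if getattr(original, "_demo_patched", False):
        return

    def patched(name, value):
        # Handle the special case here; delegate everything else unchanged.
        if name.endswith(".special"):
            return f"patched:{name}"
        return original(name, value)

    patched._demo_patched = True
    # Rebind in every module that holds a reference to the function.
    lib.set_tensor = patched
    consumer.set_tensor = patched
```

Calling install_patch() twice leaves a single wrapper in place, and consumer_load then sees the patched behavior even though it never looks at lib directly.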

import json
import torch
from diffusers import Cosmos3OmniPipeline, Cosmos3OmniTransformer
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from modelopt.torch.quantization.qtensor.base_qtensor import QTensorWrapper
import modelopt.torch.opt as mto


def patch_modelopt_qtensor_loader():
    """Make Accelerate/Diffusers keep QTensorWrapper parameters (and their metadata) when loading."""
    import accelerate.utils.modeling as accelerate_modeling
    import diffusers.models.model_loading_utils as diffusers_loading

    original = accelerate_modeling.set_module_tensor_to_device
    if getattr(original, "_cosmos3_modelopt_patch", False):
        return

    def patched(module, tensor_name, device, value=None, dtype=None, fp16_statistics=None,
                tied_params_map=None, non_blocking=False, clear_cache=True):
        leaf_module = module
        leaf_name = tensor_name
        if "." in tensor_name:
            parts = tensor_name.split(".")
            for part in parts[:-1]:
                leaf_module = getattr(leaf_module, part)
            leaf_name = parts[-1]
        old_value = getattr(leaf_module, leaf_name) if hasattr(leaf_module, leaf_name) else None
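        # For ModelOpt-compressed weights, re-wrap the loaded data in a new QTensorWrapper
        # so the wrapper's metadata is preserved; everything else takes the default path.
        if isinstance(old_value, QTensorWrapper) and value is not None: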
            leaf_module._parameters[leaf_name] = QTensorWrapper(
                value.to(device, non_blocking=non_blocking),
                metadata=old_value.metadata,
            )
            return
        return original(module, tensor_name, device, value, dtype, fp16_statistics,
                        tied_params_map, non_blocking, clear_cache)

    patched._cosmos3_modelopt_patch = True
    accelerate_modeling.set_module_tensor_to_device = patched
    diffusers_loading.set_module_tensor_to_device = patched


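def cast_modelopt_runtime_tensors(model, dtype=torch.bfloat16):
    """Record dtype in each QTensorWrapper's metadata and cast all other floating-point
    parameters and buffers to dtype; FP8 storage itself is left untouched."""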
    for module in model.modules():
        for name, param in list(module._parameters.items()):
            if isinstance(param, QTensorWrapper):
                param.metadata["dtype"] = dtype
            elif param is not None and param.is_floating_point():
                module._parameters[name] = torch.nn.Parameter(
                    param.detach().to(dtype),
                    requires_grad=param.requires_grad,
                )
        for name, buf in list(module._buffers.items()):
            if buf is not None and buf.is_floating_point():
                module._buffers[name] = buf.to(dtype)
    return model


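# Apply the loader patch and let ModelOpt hook Hugging Face-style loading, so that
# from_pretrained can restore the saved ModelOpt state (transformer/modelopt_state.pth).
patch_modelopt_qtensor_loader()
mto.enable_huggingface_checkpointing()

# Load without torch_dtype so the FP8 weights are not cast back to BF16.
transformer = Cosmos3OmniTransformer.from_pretrained(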
    "WaveCut/Cosmos3-Super-Text2Image-ModelOpt-FP8-Transformer",
    subfolder="transformer",
    use_safetensors=False,
)
transformer = cast_modelopt_runtime_tensors(transformer, torch.bfloat16)

pipe = Cosmos3OmniPipeline.from_pretrained(
    "nvidia/Cosmos3-Super-Text2Image",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    enable_safety_checker=True,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=3.0)
pipe.to("cuda")

json_caption = {
    "subjects": [],
    "background_setting": "A concise scene description.",
    "comprehensive_t2i_caption": "A detailed natural-language caption.",
    "resolution": {"H": 1024, "W": 1024},
    "aspect_ratio": "1,1",
}

with torch.autocast("cuda", dtype=torch.bfloat16):
    result = pipe(
        prompt=json.dumps(json_caption),
        negative_prompt="",
        num_frames=1,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=4.0,
        generator=torch.Generator(device="cuda").manual_seed(1143),
    )
result.video[0].save("cosmos3_modelopt_fp8.png")
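The prompt in the example is a Python dict serialized with json.dumps, and its resolution field repeats the height and width passed to the pipeline. A small helper can keep those values in one place. This is a hypothetical sketch: build_caption and pipeline_kwargs are not part of Diffusers or the base repository, they fill only the fields shown in the example, and the aspect_ratio format is an assumption (reduced width,height; the 1024x1024 example's "1,1" does not reveal the order). NVIDIA's prompt-format guidance remains the authoritative schema.

```python
import json
from math import gcd


def build_caption(background, caption, height, width, subjects=None):
    """Hypothetical helper: returns a JSON-caption dict using the fields from the example."""
    g = gcd(width, height)
    return {
        "subjects": list(subjects or []),
        "background_setting": background,
        "comprehensive_t2i_caption": caption,
        "resolution": {"H": height, "W": width},
        # Assumption: reduced "W,H" order; not verified against NVIDIA's guidance.
        "aspect_ratio": f"{width // g},{height // g}",
    }


def pipeline_kwargs(caption_dict):
    """Hypothetical helper: derive prompt/height/width for the pipeline from one caption dict."""
    res = caption_dict["resolution"]
    return {"prompt": json.dumps(caption_dict), "height": res["H"], "width": res["W"]}


kwargs = pipeline_kwargs(
    build_caption("A concise scene description.", "A detailed natural-language caption.", 1024, 1024)
)
print(kwargs["height"], kwargs["width"], kwargs["prompt"])
```

In the example, the result could be passed as pipe(**kwargs, negative_prompt="", num_frames=1, ...), so the caption's resolution and the pipeline's height and width cannot drift apart.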

Benchmarks

Measured on one RunPod NVIDIA B200 instance with local container storage, cached model files, PyTorch 2.9.1+cu130, 1024x1024 image generation, 50 inference steps, guidance scale 4.0, flow_shift=3.0, system prompt enabled. The ModelOpt FP8 runtime uses BF16 autocast around the pipeline forward.

Transformer Component Load

| Variant | Load to CUDA | VRAM after load | Torch allocated | Torch reserved | Transformer weights |
|---|---|---|---|---|---|
| BF16 base transformer | 41.83s | 122,758 MiB | 122,121 MiB | 122,132 MiB | 119.21 GiB |
| NVIDIA ModelOpt FP8 transformer | 21.95s | 63,550 MiB | 62,907 MiB | 62,924 MiB | 61.06 GiB |
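The transformer weights column allows a rough consistency check on how much of the transformer ended up in FP8. The sketch below assumes that BF16 stores 2 bytes per parameter, that quantized weights store 1 byte, and that per-tensor scales and ModelOpt metadata are negligible; none of these assumptions is stated in the benchmark, so treat the result as an estimate only. The function name is hypothetical.

```python
GIB = 2**30


def fp8_fraction_estimate(bf16_gib: float, fp8_gib: float) -> tuple[float, float]:
    """Return (approx. parameter count, approx. fraction of parameters stored as FP8)."""
    # Assumption: BF16 weights are 2 bytes per parameter with no other overhead.
    params = bf16_gib * GIB / 2
    # Assumption: FP8 params take 1 byte, the rest stay at 2 bytes, scales ignored:
    #   fp8_bytes = params * (1 * x + 2 * (1 - x)) = params * (2 - x)
    # Dividing by bf16_bytes = 2 * params gives fp8 / bf16 = (2 - x) / 2.
    fraction_fp8 = 2 - 2 * fp8_gib / bf16_gib
    return params, fraction_fp8


params, frac = fp8_fraction_estimate(119.21, 61.06)  # values from the table above
print(f"~{params / 1e9:.1f}B parameters, ~{frac:.1%} stored as FP8")
```

With the table values this gives about 64.0B parameters and roughly 97.6% of them in FP8, which fits the FP8 weights being slightly more than half the BF16 size.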

Full Pipeline Generation

The stress set consists of ten handwritten JSON-caption prompts that target Cyrillic text, reflections, multi-object composition, anatomy, small details, and scene following.

| Variant | Full pipeline load | VRAM after load | Torch allocated after load | Avg generation time | Min / max generation time | Peak sampled VRAM | Images |
|---|---|---|---|---|---|---|---|
| BF16 base pipeline | 31.31s | 125,134 MiB | 124,386 MiB | 16.05s | 15.51s / 17.97s | 141,104 MiB | 10 |
| NVIDIA ModelOpt FP8 pipeline | 35.49s | 65,810 MiB | 65,171 MiB | 45.57s | 45.07s / 47.28s | 81,854 MiB | 10 |
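The memory/speed trade-off in this table is easier to read as ratios. The snippet below only redoes the arithmetic on the reported numbers; it measures nothing, and the dictionary keys are illustrative names.

```python
# Values copied from the Full Pipeline Generation table above.
BF16 = {"vram_after_load_mib": 125_134, "peak_vram_mib": 141_104, "avg_generation_s": 16.05}
FP8 = {"vram_after_load_mib": 65_810, "peak_vram_mib": 81_854, "avg_generation_s": 45.57}


def reduction_pct(base: float, new: float) -> float:
    """Percentage saved when going from base to new (positive means smaller)."""
    return 100.0 * (1.0 - new / base)


load_saving = reduction_pct(BF16["vram_after_load_mib"], FP8["vram_after_load_mib"])
peak_saving = reduction_pct(BF16["peak_vram_mib"], FP8["peak_vram_mib"])
slowdown = FP8["avg_generation_s"] / BF16["avg_generation_s"]

print(f"VRAM after load: {load_saving:.1f}% lower")
print(f"Peak sampled VRAM: {peak_saving:.1f}% lower")
print(f"Average generation time: {slowdown:.2f}x the BF16 time")
```

On these numbers the FP8 pipeline uses about 47% less VRAM after load and about 42% less at peak, while average generation takes about 2.84x as long as BF16 on this B200 setup.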

Original NVIDIA Example Caption

The original model repository provides assets/example_caption.json. The outputs below were generated locally from that JSON caption with seed 1143, 1024x1024, 50 steps, and guidance scale 4.0.

| Variant | Pipeline load | Generation time | Peak sampled VRAM |
|---|---|---|---|
| BF16 base pipeline | 35.41s | 18.01s | 141,098 MiB |
| NVIDIA ModelOpt FP8 pipeline | 35.28s | 47.20s | 71,470 MiB |

BF16 reference output:

BF16 output for NVIDIA example caption

NVIDIA ModelOpt FP8 output:

NVIDIA ModelOpt FP8 output for NVIDIA example caption

Stress Prompt Outputs

| Stress prompt | NVIDIA ModelOpt FP8 output |
|---|---|
| 01 metro archive reading room | (image) |
| 02 arctic greenhouse night shift | (image) |
| 03 control room restoration | (image) |
| 04 rain market cross section | (image) |
| 05 manuscript restoration table | (image) |
| 06 robotic assembly line signage | (image) |
| 07 kitchen storm chess table | (image) |
| 08 orbital cockpit cyrillic ui | (image) |
| 09 flood command center | (image) |
| 10 cyrillic newspaper press | (image) |

Notes

  • Treat this as an experimental ModelOpt FP8 transformer artifact. The upstream NVIDIA card documents BF16 as the tested precision.
  • Do not pass torch_dtype=torch.bfloat16 when loading this quantized transformer; cast runtime metadata after loading as shown above.
  • The safety checker is not included in this repository; load it from the base model if your use case requires it.
  • Text rendering, especially exact Cyrillic text, remains a hard case for this model family and should be evaluated visually for the target prompt distribution.