
IVRL/wsr.pytorch


Weight Space Representation Learning via Neural Field Adaptation

Official code for the CVPR 2026 paper Weight Space Representation Learning via Neural Field Adaptation by Zhuoqian Yang, Mathieu Salzmann, and Sabine Süsstrunk (EPFL).

Paper | Project Page | ShapeNetSDF dataset


TL;DR

  1. Properly constrained weights exhibit semantic structure and serve as effective data representations.
  2. Weight-space geometry correlates strongly with diffusion generation quality.

Repository structure

The pipeline (Algorithm 1 in the paper) spans two codebases that sit side by side in this repository:

Repo                     Pipeline role
neural_field             Stage 1: base neural field training (variational autodecoder, 2D and 3D).
                         Stage 2: per-instance fitting (MLP / LoRA / mLoRA, with optional asymmetric masking) to build the weight-space dataset.
                         Evaluation: reconstruction, weight-space structure, and discriminative (classification / clustering).
weight_space_diffusion   Stage 3: Diffusion Transformer trained on the weight representations, with a hierarchical LoRA layer encoder.
                         Sampling and generation metrics.
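In Stage 2, each instance is represented by a small set of adapter weights fitted on top of the frozen Stage 1 field. Below is a minimal PyTorch sketch of that idea on a plain nn.Linear layer; the repository's adapters (model/adapter_blocks.py) wrap modulated conv / linear layers instead, which this sketch does not reproduce. The class name LoRALinear is hypothetical, and the multiplicative branch assumes mLoRA rescales the frozen weight elementwise as W ⊙ (1 + BA); the paper's exact mLoRA formulation may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """Hypothetical sketch: frozen base linear layer plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 4, multiplicative: bool = False):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # the Stage 1 base field stays frozen
        out_f, in_f = base.weight.shape
        # Common LoRA init: A small and random, B zero, so the update starts at zero.
        self.A = nn.Parameter(torch.randn(rank, in_f) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_f, rank))
        self.multiplicative = multiplicative

    def delta(self) -> torch.Tensor:
        # Low-rank update of shape (out_f, in_f) with rank at most `rank`.
        return self.B @ self.A

    def effective_weight(self) -> torch.Tensor:
        W = self.base.weight
        if self.multiplicative:
            # ASSUMED multiplicative form: elementwise rescaling of the frozen weight.
            return W * (1.0 + self.delta())
        return W + self.delta()  # additive LoRA: W + BA

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.effective_weight(), self.base.bias)


# Usage: only A and B are trainable, so they form the per-instance representation.
layer = LoRALinear(nn.Linear(16, 32), rank=4)
y = layer(torch.randn(8, 16))
```

Because B starts at zero, the adapted field initially reproduces the base field exactly, and per-instance fitting only has to learn a small deviation from it.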

neural_field/ (Stages 1 and 2)

neural_field/
├── apps/                          # entry-point scripts
│   ├── train_2d_vae.py            # Stage 1: train 2D base field (FFHQ), variational autodecoder
│   ├── train_3d_vae.py            # Stage 1: train 3D base field (ShapeNet, occupancy/SDF)
│   ├── overfit.py                 # Stage 2: fit per-instance MLP / LoRA / mLoRA weights
│   ├── infer_overfitted.py        # Stage 2: reconstruction quality (Table 1)
│   ├── cluster_weights.py         # discriminative eval: clustering + classification (Table 4)
│   ├── calculate_fid.py           # distributional metrics FD / MMD-G / MMD-P
│   ├── demo_3d.py                 # extract / visualize meshes via marching cubes
│   ├── render_meshes.py           # render meshes (PyTorch3D): multi-view or turntable
│   └── sanity_check_data.py       # quick dataset sanity check
├── model/                         # neural field architecture
│   ├── Generators.py              # NFres / NFresAdapter (2D modulated field + LoRA adapter)
│   ├── Generator3D.py             # NFres3D / NFres3DAdapter (3D field + LoRA adapter)
│   ├── MLP.py                     # standalone MLP field (the non-adapter representation)
│   ├── blocks.py                  # synthesis blocks: modulated conv / linear, Fourier features
│   ├── adapter_blocks.py          # additive and multiplicative (mLoRA) LoRA adapters
│   ├── loramod_blocks.py          # base modulated conv / linear the adapters wrap
│   ├── asym_mask_utils.py         # asymmetric masking (-Asym): breaks LoRA rank-permutation symmetry
│   ├── dataset2d.py               # FFHQ image-point dataset
│   ├── dataset3d.py               # ShapeNet SDF / occupancy point dataset
│   ├── mesh_renderer.py           # PyTorch3D mesh rendering
│   ├── point_cloud_utils.py       # surface sampling + point-cloud metrics
│   └── chamfer_distance.py        # Chamfer distance (3D reconstruction metric)
├── op/                            # fused-bias / upfirdn2d ops (native PyTorch fallback)
├── scripts/                       # data preparation
├── distributed.py                 # multi-GPU helpers (torchrun)
├── util.py                        # misc training utilities
├── tensor_transforms.py           # coordinate / tensor transforms
├── data/                          # datasets live here (gitignored)
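The -Asym variant (model/asym_mask_utils.py) targets a symmetry of LoRA: permuting the rank dimension of B and A together leaves the update BA unchanged, so many different weight vectors encode the same field. The sketch below demonstrates that symmetry and one hypothetical way a fixed, rank-specific mask on A removes it. The random binary mask and the helper names lora_delta / permute_ranks are assumptions for illustration, not necessarily the masking scheme used in the repository.

```python
import torch


def lora_delta(B, A, mask=None):
    """Low-rank update B @ A, optionally with a fixed (non-trainable) mask applied to A."""
    if mask is not None:
        A = A * mask
    return B @ A


def permute_ranks(B, A, perm):
    """Reorder the rank dimension: columns of B and rows of A together."""
    return B[:, perm], A[perm, :]


torch.manual_seed(0)
out_f, in_f, rank = 32, 16, 4
B = torch.randn(out_f, rank)
A = torch.randn(rank, in_f)
perm = torch.tensor([2, 0, 3, 1])
B_p, A_p = permute_ranks(B, A, perm)

# Plain LoRA: two different parameter sets produce the same update.
sym = torch.allclose(lora_delta(B, A), lora_delta(B_p, A_p), atol=1e-5)

# HYPOTHETICAL asymmetric mask: a fixed random binary pattern that differs per rank row,
# so rank k is tied to its own mask row and can no longer be swapped freely.
mask = (torch.rand(rank, in_f) > 0.5).float()
asym = torch.allclose(lora_delta(B, A, mask), lora_delta(B_p, A_p, mask), atol=1e-5)
print(sym, asym)
```

With the mask held fixed during fitting, each rank slot has a distinct role, which is the property that makes the fitted weights more directly comparable across instances.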

weight_space_diffusion/ (Stage 3)

weight_space_diffusion/
├── apps/                          # entry-point scripts
│   ├── main.py                    # Stage 3: train the Diffusion Transformer on weights (DDPM)
│   ├── inference.py               # sample weights via DDIM and decode them (multi-GPU)
│   ├── evaluate_metrics.py        # FD / MMD-G / MMD-P + 3D mMD / COV / 1-NNA (Tables 2 and 3)
│   ├── novelty_check.py           # nearest-training-neighbor retrieval (Fig. 9)
│   └── calculate_normalization_factor.py  # per-dataset weight-normalization stats -> config
├── configs/diffusion_configs/     # 21 configs = 6 representations x 3 datasets + ablation
├── lib/                           # diffusion model + training
│   ├── hyperdiffusion.py          # LightningModule: training / sampling loop
│   ├── transformer.py             # Diffusion Transformer (GPT-style) backbone
│   ├── embedder.py                # hierarchical LoRA layer encoder
│   ├── utils.py                   # config + worker base classes
│   └── diffusion/                 # DDPM / DDIM core (gaussian_diffusion, losses, nn)
├── dataio/
│   ├── dataset.py                 # weight dataset (+ point dataset for 3D eval)
│   └── atomizer.py                # flatten / unflatten state_dicts to and from token sequences
├── metric/
│   ├── clip_embedder.py           # 2D CLIP feature extractor
│   ├── pointnet2_embedder.py      # 3D PointNet++ feature extractor
│   ├── distribution_metric.py     # FD and MMD (MMD-G, MMD-P)
│   └── evaluation_metrics_3d.py   # 3D distance metrics: mMD, COV, 1-NNA
├── external/
│   ├── nf/utils.py                # decode generated weights -> neural field (imports neural_field/model)
│   └── pointnet2/                 # PointNet++ extractor + checkpoints/
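dataio/atomizer.py converts state_dicts to and from token sequences for the Diffusion Transformer. Below is a minimal sketch of such a round trip, assuming one token per layer built by concatenating that layer's flattened parameters and zero-padding to a common width. The functions flatten_to_tokens / unflatten_from_tokens and the lora_A / lora_B key names are hypothetical; the repository's atomizer and hierarchical LoRA layer encoder may tokenize differently.

```python
from collections import OrderedDict

import torch


def flatten_to_tokens(state_dict, layer_names):
    """Hypothetical layout: one zero-padded token per layer, plus (name, shape) specs."""
    specs, rows = [], []
    for layer in layer_names:
        # The trailing "." keeps "layer1" from also matching "layer10".
        keys = [k for k in state_dict if k.startswith(layer + ".")]
        rows.append(torch.cat([state_dict[k].flatten() for k in keys]))
        specs.append([(k, state_dict[k].shape) for k in keys])
    width = max(r.numel() for r in rows)
    tokens = torch.zeros(len(rows), width)
    for i, r in enumerate(rows):
        tokens[i, : r.numel()] = r  # remaining entries stay zero (padding)
    return tokens, specs


def unflatten_from_tokens(tokens, specs):
    """Invert flatten_to_tokens using the recorded names and shapes; padding is dropped."""
    out = OrderedDict()
    for token, layer_spec in zip(tokens, specs):
        offset = 0
        for key, shape in layer_spec:
            n = shape.numel()
            out[key] = token[offset : offset + n].reshape(shape)
            offset += n
    return out


torch.manual_seed(0)
sd = OrderedDict([
    ("layer0.lora_A", torch.randn(4, 16)),
    ("layer0.lora_B", torch.randn(32, 4)),
    ("layer1.lora_A", torch.randn(4, 32)),
    ("layer1.lora_B", torch.randn(3, 4)),
])
tokens, specs = flatten_to_tokens(sd, ["layer0", "layer1"])
print(tokens.shape)  # torch.Size([2, 192])
restored = unflatten_from_tokens(tokens, specs)
```

Recording the (name, shape) pairs is what makes the round trip lossless: padding is discarded on decode, so generated tokens can be reshaped back into a state_dict before being decoded into a neural field.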

Datasets

  • FFHQ: first 5,000 face images at 128 x 128.
  • ShapeNet airplane: all 4,045 models.
  • ShapeNet multi: the 10 largest categories, 500 instances each.

Preprocessed ShapeNetSDF data can be downloaded here.

Getting started

  1. Install the environment: see INSTALL.md.
  2. Reproduce experiments: see REPRODUCE.md.

Citation

@inproceedings{yang2026wsr,
  title     = {Weight Space Representation Learning via Neural Field Adaptation},
  author    = {Yang, Zhuoqian and Salzmann, Mathieu and S{\"u}sstrunk, Sabine},
  booktitle = {Proceedings of the IEEE/CVF Conference on
               Computer Vision and Pattern Recognition (CVPR)},
  year      = {2026}
}

About

Official PyTorch implementation of "Weight Space Representation Learning via Neural Field Adaptation" (CVPR 2026).
