Official code for the CVPR 2026 paper "Weight Space Representation Learning via Neural Field Adaptation" by Zhuoqian Yang, Mathieu Salzmann, and Sabine Süsstrunk (EPFL).
Paper | Project Page | ShapeNetSDF dataset
- Properly constrained weights exhibit semantic structure and serve as effective data representations.
- Weight-space geometry correlates strongly with diffusion generation quality.
The pipeline (Algorithm 1 in the paper) spans two repos that sit side by side:
| Repo | Pipeline role |
|---|---|
| neural_field | Stage 1: base neural field training (variational autodecoder, 2D and 3D). Stage 2: per-instance fitting (MLP / LoRA / mLoRA, with optional asymmetric masking) to build the weight-space dataset. Reconstruction, weight-space-structure, and discriminative (classification / clustering) evaluation. |
| weight_space_diffusion | Stage 3: Diffusion Transformer trained on the weight representations, with a hierarchical LoRA layer encoder. Sampling and generation metrics. |
neural_field/
├── apps/ # entry-point scripts
│ ├── train_2d_vae.py # Stage 1: train 2D base field (FFHQ), variational autodecoder
│ ├── train_3d_vae.py # Stage 1: train 3D base field (ShapeNet, occupancy/SDF)
│ ├── overfit.py # Stage 2: fit per-instance MLP / LoRA / mLoRA weights
│ ├── infer_overfitted.py # Stage 2: reconstruction quality (Table 1)
│ ├── cluster_weights.py # discriminative eval: clustering + classification (Table 4)
│ ├── calculate_fid.py # distributional metrics FD / MMD-G / MMD-P
│ ├── demo_3d.py # extract / visualize meshes via marching cubes
│ ├── render_meshes.py # render meshes (PyTorch3D): multi-view or turntable
│ └── sanity_check_data.py # quick dataset sanity check
├── model/ # neural field architecture
│ ├── Generators.py # NFres / NFresAdapter (2D modulated field + LoRA adapter)
│ ├── Generator3D.py # NFres3D / NFres3DAdapter (3D field + LoRA adapter)
│ ├── MLP.py # standalone MLP field (the non-adapter representation)
│ ├── blocks.py # synthesis blocks: modulated conv / linear, Fourier features
│ ├── adapter_blocks.py # additive and multiplicative (mLoRA) LoRA adapters
│ ├── loramod_blocks.py # base modulated conv / linear the adapters wrap
│ ├── asym_mask_utils.py # asymmetric masking (-Asym): breaks LoRA rank-permutation symmetry
│ ├── dataset2d.py # FFHQ image-point dataset
│ ├── dataset3d.py # ShapeNet SDF / occupancy point dataset
│ ├── mesh_renderer.py # PyTorch3D mesh rendering
│ ├── point_cloud_utils.py # surface sampling + point-cloud metrics
│ └── chamfer_distance.py # Chamfer distance (3D reconstruction metric)
├── op/ # fused-bias / upfirdn2d ops (native PyTorch fallback)
├── scripts/ # data preparation
├── distributed.py # multi-GPU helpers (torchrun)
├── util.py # misc training utilities
├── tensor_transforms.py # coordinate / tensor transforms
├── data/ # datasets live here (gitignored)
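
The note on `asym_mask_utils.py` refers to the rank-permutation symmetry of LoRA. The following is a minimal, dependency-free sketch of that symmetry only: for a low-rank update `B @ A`, reordering the rank dimension (columns of `B`, rows of `A`) changes the stored factors but not the resulting update. The helper names (`matmul`, `permute_rank`) and the shapes are illustrative assumptions, and the sketch does not reproduce the masking scheme the repo uses to break the symmetry.

```python
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def permute_rank(B, A, perm):
    """Reorder the rank dimension: columns of B and rows of A (hypothetical helper)."""
    B_perm = [[row[k] for k in perm] for row in B]
    A_perm = [A[k] for k in perm]
    return B_perm, A_perm

random.seed(0)
d_out, d_in, rank = 4, 5, 3  # toy sizes, not taken from the paper
B = [[random.gauss(0, 1) for _ in range(rank)] for _ in range(d_out)]  # d_out x rank
A = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(rank)]   # rank x d_in

B_perm, A_perm = permute_rank(B, A, perm=[2, 0, 1])

delta = matmul(B, A)                 # low-rank weight update
delta_perm = matmul(B_perm, A_perm)  # same update from permuted factors

# Largest entry-wise difference; only floating-point summation order differs.
max_diff = max(abs(x - y) for r1, r2 in zip(delta, delta_perm) for x, y in zip(r1, r2))
factors_differ = B != B_perm
print(max_diff, factors_differ)
```

Because many distinct factor pairs encode the same update, the raw LoRA weights are an ambiguous representation unless something fixes the ordering of the rank components.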
weight_space_diffusion/
├── apps/ # entry-point scripts
│ ├── main.py # Stage 3: train the Diffusion Transformer on weights (DDPM)
│ ├── inference.py # sample weights via DDIM and decode them (multi-GPU)
│ ├── evaluate_metrics.py # FD / MMD-G / MMD-P + 3D mMD / COV / 1-NNA (Tables 2 to 3)
│ ├── novelty_check.py # nearest-training-neighbor retrieval (Fig. 9)
│ └── calculate_normalization_factor.py # per-dataset weight-normalization stats -> config
├── configs/diffusion_configs/ # 21 configs = 6 representations x 3 datasets + ablation
├── lib/ # diffusion model + training
│ ├── hyperdiffusion.py # LightningModule: training / sampling loop
│ ├── transformer.py # Diffusion Transformer (GPT-style) backbone
│ ├── embedder.py # hierarchical LoRA layer encoder
│ ├── utils.py # config + worker base classes
│ └── diffusion/ # DDPM / DDIM core (gaussian_diffusion, losses, nn)
├── dataio/
│ ├── dataset.py # weight dataset (+ point dataset for 3D eval)
│ └── atomizer.py # flatten / unflatten state_dicts to and from token sequences
├── metric/
│ ├── clip_embedder.py # 2D CLIP feature extractor
│ ├── pointnet2_embedder.py # 3D PointNet++ feature extractor
│ ├── distribution_metric.py # FD and MMD (MMD-G, MMD-P)
│ └── evaluation_metrics_3d.py # 3D distance metrics: mMD, COV, 1-NNA
├── external/
│ ├── nf/utils.py # decode generated weights -> neural field (imports neural_field/model)
│ └── pointnet2/ # PointNet++ extractor + checkpoints/
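
`dataio/atomizer.py` is described as flattening and unflattening state_dicts to and from token sequences. The sketch below illustrates one way such a round trip can work, under loud assumptions: parameters are plain lists of floats rather than tensors, tokens are fixed-size chunks with zero padding, and the function names, key names, and `token_dim` are hypothetical. The repo's actual tokenization may differ (for example, one token per layer).

```python
def flatten_state(state, token_dim):
    """Concatenate parameters in key order, zero-pad, and split into tokens."""
    layout = [(name, len(values)) for name, values in state.items()]
    flat = [v for values in state.values() for v in values]
    flat = flat + [0.0] * ((-len(flat)) % token_dim)  # pad to a multiple of token_dim
    tokens = [flat[i:i + token_dim] for i in range(0, len(flat), token_dim)]
    return tokens, layout

def unflatten_state(tokens, layout):
    """Invert flatten_state using the recorded (name, size) layout."""
    flat = [v for token in tokens for v in token]
    state, offset = {}, 0
    for name, size in layout:
        state[name] = flat[offset:offset + size]
        offset += size
    return state  # trailing padding is dropped implicitly

# Toy "state_dict" with hypothetical key names and already-flattened values.
state = {
    "layer0.lora_A": [0.1, 0.2, 0.3],
    "layer0.lora_B": [0.4, 0.5],
    "layer1.lora_A": [0.6, 0.7],
}
tokens, layout = flatten_state(state, token_dim=4)
restored = unflatten_state(tokens, layout)
print(len(tokens), restored == state)
```

With real tensors, the layout would also need to record each parameter's shape so that the flat slice can be reshaped on the way back.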
- FFHQ: first 5,000 face images at 128 x 128.
- ShapeNet airplane: all 4,045 models.
- ShapeNet multi: the 10 largest categories, 500 instances each.
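
As a rough illustration of the "ShapeNet multi" split above, the following sketch picks the largest categories and caps each at a fixed number of instances. The function name and the toy data are hypothetical, and the rule for which instances are kept (here: the first ones by sorted ID) is an assumption, since the list above does not state it.

```python
def select_subset(instances_by_category, num_categories=10, per_category=500):
    """Keep the `num_categories` largest categories, `per_category` instances each."""
    largest = sorted(
        instances_by_category,
        key=lambda cat: len(instances_by_category[cat]),
        reverse=True,
    )[:num_categories]
    # Assumption: take the first instances by sorted ID; the real split may differ.
    return {cat: sorted(instances_by_category[cat])[:per_category] for cat in largest}

# Toy example with made-up IDs, scaled down to 2 categories x 3 instances.
toy = {
    "airplane": ["a1", "a2", "a3", "a4", "a5", "a6"],
    "chair": ["c1", "c2", "c3", "c4", "c5"],
    "lamp": ["l1", "l2"],
}
subset = select_subset(toy, num_categories=2, per_category=3)
print(subset)
```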
Preprocessed ShapeNetSDF data can be downloaded via the ShapeNetSDF dataset link at the top of this README.
- Install the environment: see INSTALL.md.
- Reproduce experiments: see REPRODUCE.md.
@inproceedings{yang2026wsr,
title = {Weight Space Representation Learning via Neural Field Adaptation},
author = {Yang, Zhuoqian and Salzmann, Mathieu and S{\"u}sstrunk, Sabine},
booktitle = {Proceedings of the IEEE/CVF Conference on
Computer Vision and Pattern Recognition (CVPR)},
year = {2026}
}