Skip to content

Commit 203ecd7

Browse files
fix(envs): preserve original Fetch MJCF visuals on unmodified resets (galilai-group#202)
_apply_visual_variations() ran on every reset, overwriting the original skybox/floor/table/object with the wrapper's init_value defaults even when no variation was requested. Now only variations whose key is in active_variations are applied. Also fixes broken <img> paths in docs/envs/gymnasium_robotics.md (../assets -> ../../assets, needed under use_directory_urls: true) and refreshes the four default Fetch gifs against the fix.
1 parent 1bd0490 commit 203ecd7

6 files changed

Lines changed: 108 additions & 81 deletions

File tree

docs/assets/fetch_pickandplace.gif

89.9 KB
Loading

docs/assets/fetch_push.gif

31.8 KB
Loading

docs/assets/fetch_reach.gif

164 KB
Loading

docs/assets/fetch_slide.gif

46.7 KB
Loading

docs/envs/gymnasium_robotics.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ world = swm.World('swm/FetchPickAndPlace-v3', num_envs=4, image_shape=(224, 224)
3636
## Fetch Manipulation Suite
3737

3838
<div style="display: flex; gap: 10px; margin-bottom: 20px;">
39-
<img src="../../assets/env/gymnasium_robotics/fetch/push.gif" alt="fetch push" style="width: 24%; object-fit: contain;">
40-
<img src="../../assets/env/gymnasium_robotics/fetch/slide.gif" alt="fetch slide" style="width: 24%; object-fit: contain;">
41-
<img src="../../assets/env/gymnasium_robotics/fetch/pick_and_place.gif" alt="fetch pick and place" style="width: 24%; object-fit: contain;">
42-
<img src="../../assets/env/gymnasium_robotics/fetch/reach.gif" alt="fetch reach" style="width: 24%; object-fit: contain;">
39+
<img src="../../assets/fetch_push.gif" alt="fetch push" style="width: 24%; object-fit: contain;">
40+
<img src="../../assets/fetch_slide.gif" alt="fetch slide" style="width: 24%; object-fit: contain;">
41+
<img src="../../assets/fetch_pickandplace.gif" alt="fetch pick and place" style="width: 24%; object-fit: contain;">
42+
<img src="../../assets/fetch_reach.gif" alt="fetch reach" style="width: 24%; object-fit: contain;">
4343
</div>
4444

4545
An agent controls a 7-DoF Fetch robotic arm. The agent manipulates explicit Cartesian coordinates to move the gripper and actuate the fingers to interact with the environment.

stable_worldmodel/envs/gymnasium_robotics/fetch.py

Lines changed: 104 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def reset(self, seed=None, options=None):
216216

217217
obs, info = super().reset(seed=seed, options=options)
218218

219-
self._apply_visual_variations()
219+
self._apply_visual_variations(active_variations)
220220

221221
changed_physics = False
222222
if any(
@@ -327,15 +327,40 @@ def _get_geoms_for_material(self, model, mat_name):
327327
return []
328328
return [i for i in range(model.ngeom) if model.geom_matid[i] == mat_id]
329329

330-
def _apply_visual_variations(self):
331-
"""Modifies the underlying MuJoCo model to apply visual variations."""
330+
def _apply_visual_variations(self, active_variations):
331+
"""Modifies the underlying MuJoCo model to apply visual variations.
332+
333+
Only variations whose key appears in ``active_variations`` are applied —
334+
anything else is left untouched so the original MJCF defaults (skybox,
335+
floor, table, etc.) are preserved on unmodified resets.
336+
"""
332337
if mujoco is None:
333338
return
334339

335340
model = self.env.unwrapped.model
336341
if model is None:
337342
return
338343

344+
active = set(active_variations)
345+
needs_table = 'table.color' in active
346+
needs_bg = 'background.color' in active
347+
needs_object = 'object.color' in active
348+
needs_light = 'light.intensity' in active
349+
needs_camera = 'camera.angle_delta' in active
350+
needs_arm = 'rendering.transparent_arm' in active
351+
352+
if not any(
353+
[
354+
needs_table,
355+
needs_bg,
356+
needs_object,
357+
needs_light,
358+
needs_camera,
359+
needs_arm,
360+
]
361+
):
362+
return
363+
339364
if not hasattr(self, '_table_geoms'):
340365
self._table_geoms = self._get_geoms_for_material(
341366
model, 'table_mat'
@@ -366,85 +391,87 @@ def _apply_visual_variations(self):
366391
):
367392
model.geom_matid[i] = -1
368393

369-
# Now we can safely modify geom_rgba
370-
table_color = self.variation_space['table']['color'].value
371-
for i in self._table_geoms:
372-
model.geom_rgba[i][:3] = table_color
373-
374-
bg_color = self.variation_space['background']['color'].value
375-
for i in self._floor_geoms:
376-
model.geom_rgba[i][:3] = bg_color
377-
378-
if getattr(self, '_skybox_tex_id', -1) >= 0:
379-
skybox_tex_id = self._skybox_tex_id
380-
bg_color_uint8 = (bg_color * 255).astype(np.uint8)
381-
start_idx = model.tex_adr[skybox_tex_id]
382-
channels = model.tex_nchannel[skybox_tex_id]
383-
num_pixels = (
384-
model.tex_width[skybox_tex_id]
385-
* model.tex_height[skybox_tex_id]
386-
)
387-
388-
if channels >= 3:
389-
view = model.tex_data[
390-
start_idx : start_idx + num_pixels * channels
391-
].reshape(-1, channels)
392-
view[:, :3] = bg_color_uint8[:3]
394+
if needs_table:
395+
table_color = self.variation_space['table']['color'].value
396+
for i in self._table_geoms:
397+
model.geom_rgba[i][:3] = table_color
398+
399+
if needs_bg:
400+
bg_color = self.variation_space['background']['color'].value
401+
for i in self._floor_geoms:
402+
model.geom_rgba[i][:3] = bg_color
403+
404+
if getattr(self, '_skybox_tex_id', -1) >= 0:
405+
skybox_tex_id = self._skybox_tex_id
406+
bg_color_uint8 = (bg_color * 255).astype(np.uint8)
407+
start_idx = model.tex_adr[skybox_tex_id]
408+
channels = model.tex_nchannel[skybox_tex_id]
409+
num_pixels = (
410+
model.tex_width[skybox_tex_id]
411+
* model.tex_height[skybox_tex_id]
412+
)
393413

394-
if hasattr(self.env, 'unwrapped') and hasattr(
395-
self.env.unwrapped, 'mujoco_renderer'
396-
):
397-
renderer = self.env.unwrapped.mujoco_renderer
398-
if renderer is not None and hasattr(renderer, 'viewer'):
399-
viewer = renderer.viewer
400-
if getattr(viewer, 'con', None) is not None:
401-
mujoco.mjr_uploadTexture(
402-
model, viewer.con, skybox_tex_id
403-
)
404-
405-
object_color = self.variation_space['object']['color'].value
406-
for i in self._object_geoms:
407-
model.geom_rgba[i][:3] = object_color
408-
409-
# Apply light intensity
410-
light_id = mujoco.mj_name2id(
411-
model, mujoco.mjtObj.mjOBJ_LIGHT, 'light0'
412-
)
413-
if light_id >= 0:
414-
intensity = self.variation_space['light']['intensity'].value[0]
415-
model.light_diffuse[light_id][:3] = np.array(
416-
[intensity, intensity, intensity]
414+
if channels >= 3:
415+
view = model.tex_data[
416+
start_idx : start_idx + num_pixels * channels
417+
].reshape(-1, channels)
418+
view[:, :3] = bg_color_uint8[:3]
419+
420+
if hasattr(self.env, 'unwrapped') and hasattr(
421+
self.env.unwrapped, 'mujoco_renderer'
422+
):
423+
renderer = self.env.unwrapped.mujoco_renderer
424+
if renderer is not None and hasattr(renderer, 'viewer'):
425+
viewer = renderer.viewer
426+
if getattr(viewer, 'con', None) is not None:
427+
mujoco.mjr_uploadTexture(
428+
model, viewer.con, skybox_tex_id
429+
)
430+
431+
if needs_object:
432+
object_color = self.variation_space['object']['color'].value
433+
for i in self._object_geoms:
434+
model.geom_rgba[i][:3] = object_color
435+
436+
if needs_light:
437+
light_id = mujoco.mj_name2id(
438+
model, mujoco.mjtObj.mjOBJ_LIGHT, 'light0'
417439
)
440+
if light_id >= 0:
441+
intensity = self.variation_space['light']['intensity'].value[0]
442+
model.light_diffuse[light_id][:3] = np.array(
443+
[intensity, intensity, intensity]
444+
)
418445

419-
# Apply camera angle perturbation (azimuth, elevation offsets in degrees)
420-
angle_delta = self.variation_space['camera']['angle_delta'].value[0]
421-
for cam_id in range(model.ncam):
422-
model.cam_pos[cam_id] = self._default_cam_pos[cam_id]
423-
model.cam_quat[cam_id] = self._default_cam_quat[cam_id]
446+
if needs_camera:
447+
angle_delta = self.variation_space['camera']['angle_delta'].value[
448+
0
449+
]
450+
for cam_id in range(model.ncam):
451+
model.cam_pos[cam_id] = self._default_cam_pos[cam_id]
452+
model.cam_quat[cam_id] = self._default_cam_quat[cam_id]
424453

425-
azimuth_rad = np.radians(angle_delta[0])
426-
elevation_rad = np.radians(angle_delta[1])
454+
azimuth_rad = np.radians(angle_delta[0])
455+
elevation_rad = np.radians(angle_delta[1])
427456

428-
# Rotate camera position around the look-at point using azimuth offset
429-
pos = model.cam_pos[cam_id].copy()
430-
cos_az, sin_az = np.cos(azimuth_rad), np.sin(azimuth_rad)
431-
x, y = pos[0], pos[1]
432-
pos[0] = x * cos_az - y * sin_az
433-
pos[1] = x * sin_az + y * cos_az
457+
pos = model.cam_pos[cam_id].copy()
458+
cos_az, sin_az = np.cos(azimuth_rad), np.sin(azimuth_rad)
459+
x, y = pos[0], pos[1]
460+
pos[0] = x * cos_az - y * sin_az
461+
pos[1] = x * sin_az + y * cos_az
434462

435-
# Apply elevation offset by tilting Z
436-
cos_el, sin_el = np.cos(elevation_rad), np.sin(elevation_rad)
437-
z, r = pos[2], np.sqrt(pos[0] ** 2 + pos[1] ** 2)
438-
pos[2] = z * cos_el + r * sin_el
463+
cos_el, sin_el = np.cos(elevation_rad), np.sin(elevation_rad)
464+
z, r = pos[2], np.sqrt(pos[0] ** 2 + pos[1] ** 2)
465+
pos[2] = z * cos_el + r * sin_el
439466

440-
model.cam_pos[cam_id] = pos
467+
model.cam_pos[cam_id] = pos
441468

442-
# Optional: Manipulate the alpha (transparency) channels of the entire robotic arm
443-
is_transparent = (
444-
self.variation_space['rendering']['transparent_arm'].value == 1
445-
)
446-
alpha_val = 0.3 if is_transparent else 1.0
447-
for i in range(model.ngeom):
448-
name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_GEOM, i)
449-
if name and 'robot0:' in name:
450-
model.geom_rgba[i][3] = alpha_val
469+
if needs_arm:
470+
is_transparent = (
471+
self.variation_space['rendering']['transparent_arm'].value == 1
472+
)
473+
alpha_val = 0.3 if is_transparent else 1.0
474+
for i in range(model.ngeom):
475+
name = mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_GEOM, i)
476+
if name and 'robot0:' in name:
477+
model.geom_rgba[i][3] = alpha_val

0 commit comments

Comments
 (0)