@@ -216,7 +216,7 @@ def reset(self, seed=None, options=None):
216216
217217 obs , info = super ().reset (seed = seed , options = options )
218218
219- self ._apply_visual_variations ()
219+ self ._apply_visual_variations (active_variations )
220220
221221 changed_physics = False
222222 if any (
@@ -327,15 +327,40 @@ def _get_geoms_for_material(self, model, mat_name):
327327 return []
328328 return [i for i in range (model .ngeom ) if model .geom_matid [i ] == mat_id ]
329329
330- def _apply_visual_variations (self ):
331- """Modifies the underlying MuJoCo model to apply visual variations."""
330+ def _apply_visual_variations (self , active_variations ):
331+ """Modifies the underlying MuJoCo model to apply visual variations.
332+
333+ Only variations whose key appears in ``active_variations`` are applied —
334+ anything else is left untouched so the original MJCF defaults (skybox,
335+ floor, table, etc.) are preserved on unmodified resets.
336+ """
332337 if mujoco is None :
333338 return
334339
335340 model = self .env .unwrapped .model
336341 if model is None :
337342 return
338343
344+ active = set (active_variations )
345+ needs_table = 'table.color' in active
346+ needs_bg = 'background.color' in active
347+ needs_object = 'object.color' in active
348+ needs_light = 'light.intensity' in active
349+ needs_camera = 'camera.angle_delta' in active
350+ needs_arm = 'rendering.transparent_arm' in active
351+
352+ if not any (
353+ [
354+ needs_table ,
355+ needs_bg ,
356+ needs_object ,
357+ needs_light ,
358+ needs_camera ,
359+ needs_arm ,
360+ ]
361+ ):
362+ return
363+
339364 if not hasattr (self , '_table_geoms' ):
340365 self ._table_geoms = self ._get_geoms_for_material (
341366 model , 'table_mat'
@@ -366,85 +391,87 @@ def _apply_visual_variations(self):
366391 ):
367392 model .geom_matid [i ] = - 1
368393
369- # Now we can safely modify geom_rgba
370- table_color = self .variation_space ['table' ]['color' ].value
371- for i in self ._table_geoms :
372- model .geom_rgba [i ][:3 ] = table_color
373-
374- bg_color = self .variation_space ['background' ]['color' ].value
375- for i in self ._floor_geoms :
376- model .geom_rgba [i ][:3 ] = bg_color
377-
378- if getattr (self , '_skybox_tex_id' , - 1 ) >= 0 :
379- skybox_tex_id = self ._skybox_tex_id
380- bg_color_uint8 = (bg_color * 255 ).astype (np .uint8 )
381- start_idx = model .tex_adr [skybox_tex_id ]
382- channels = model .tex_nchannel [skybox_tex_id ]
383- num_pixels = (
384- model .tex_width [skybox_tex_id ]
385- * model .tex_height [skybox_tex_id ]
386- )
387-
388- if channels >= 3 :
389- view = model .tex_data [
390- start_idx : start_idx + num_pixels * channels
391- ].reshape (- 1 , channels )
392- view [:, :3 ] = bg_color_uint8 [:3 ]
394+ if needs_table :
395+ table_color = self .variation_space ['table' ]['color' ].value
396+ for i in self ._table_geoms :
397+ model .geom_rgba [i ][:3 ] = table_color
398+
399+ if needs_bg :
400+ bg_color = self .variation_space ['background' ]['color' ].value
401+ for i in self ._floor_geoms :
402+ model .geom_rgba [i ][:3 ] = bg_color
403+
404+ if getattr (self , '_skybox_tex_id' , - 1 ) >= 0 :
405+ skybox_tex_id = self ._skybox_tex_id
406+ bg_color_uint8 = (bg_color * 255 ).astype (np .uint8 )
407+ start_idx = model .tex_adr [skybox_tex_id ]
408+ channels = model .tex_nchannel [skybox_tex_id ]
409+ num_pixels = (
410+ model .tex_width [skybox_tex_id ]
411+ * model .tex_height [skybox_tex_id ]
412+ )
393413
394- if hasattr (self .env , 'unwrapped' ) and hasattr (
395- self .env .unwrapped , 'mujoco_renderer'
396- ):
397- renderer = self .env .unwrapped .mujoco_renderer
398- if renderer is not None and hasattr (renderer , 'viewer' ):
399- viewer = renderer .viewer
400- if getattr (viewer , 'con' , None ) is not None :
401- mujoco .mjr_uploadTexture (
402- model , viewer .con , skybox_tex_id
403- )
404-
405- object_color = self .variation_space ['object' ]['color' ].value
406- for i in self ._object_geoms :
407- model .geom_rgba [i ][:3 ] = object_color
408-
409- # Apply light intensity
410- light_id = mujoco .mj_name2id (
411- model , mujoco .mjtObj .mjOBJ_LIGHT , 'light0'
412- )
413- if light_id >= 0 :
414- intensity = self .variation_space ['light' ]['intensity' ].value [0 ]
415- model .light_diffuse [light_id ][:3 ] = np .array (
416- [intensity , intensity , intensity ]
414+ if channels >= 3 :
415+ view = model .tex_data [
416+ start_idx : start_idx + num_pixels * channels
417+ ].reshape (- 1 , channels )
418+ view [:, :3 ] = bg_color_uint8 [:3 ]
419+
420+ if hasattr (self .env , 'unwrapped' ) and hasattr (
421+ self .env .unwrapped , 'mujoco_renderer'
422+ ):
423+ renderer = self .env .unwrapped .mujoco_renderer
424+ if renderer is not None and hasattr (renderer , 'viewer' ):
425+ viewer = renderer .viewer
426+ if getattr (viewer , 'con' , None ) is not None :
427+ mujoco .mjr_uploadTexture (
428+ model , viewer .con , skybox_tex_id
429+ )
430+
431+ if needs_object :
432+ object_color = self .variation_space ['object' ]['color' ].value
433+ for i in self ._object_geoms :
434+ model .geom_rgba [i ][:3 ] = object_color
435+
436+ if needs_light :
437+ light_id = mujoco .mj_name2id (
438+ model , mujoco .mjtObj .mjOBJ_LIGHT , 'light0'
417439 )
440+ if light_id >= 0 :
441+ intensity = self .variation_space ['light' ]['intensity' ].value [0 ]
442+ model .light_diffuse [light_id ][:3 ] = np .array (
443+ [intensity , intensity , intensity ]
444+ )
418445
419- # Apply camera angle perturbation (azimuth, elevation offsets in degrees)
420- angle_delta = self .variation_space ['camera' ]['angle_delta' ].value [0 ]
421- for cam_id in range (model .ncam ):
422- model .cam_pos [cam_id ] = self ._default_cam_pos [cam_id ]
423- model .cam_quat [cam_id ] = self ._default_cam_quat [cam_id ]
446+ if needs_camera :
447+ angle_delta = self .variation_space ['camera' ]['angle_delta' ].value [
448+ 0
449+ ]
450+ for cam_id in range (model .ncam ):
451+ model .cam_pos [cam_id ] = self ._default_cam_pos [cam_id ]
452+ model .cam_quat [cam_id ] = self ._default_cam_quat [cam_id ]
424453
425- azimuth_rad = np .radians (angle_delta [0 ])
426- elevation_rad = np .radians (angle_delta [1 ])
454+ azimuth_rad = np .radians (angle_delta [0 ])
455+ elevation_rad = np .radians (angle_delta [1 ])
427456
428- # Rotate camera position around the look-at point using azimuth offset
429- pos = model .cam_pos [cam_id ].copy ()
430- cos_az , sin_az = np .cos (azimuth_rad ), np .sin (azimuth_rad )
431- x , y = pos [0 ], pos [1 ]
432- pos [0 ] = x * cos_az - y * sin_az
433- pos [1 ] = x * sin_az + y * cos_az
457+ pos = model .cam_pos [cam_id ].copy ()
458+ cos_az , sin_az = np .cos (azimuth_rad ), np .sin (azimuth_rad )
459+ x , y = pos [0 ], pos [1 ]
460+ pos [0 ] = x * cos_az - y * sin_az
461+ pos [1 ] = x * sin_az + y * cos_az
434462
435- # Apply elevation offset by tilting Z
436- cos_el , sin_el = np .cos (elevation_rad ), np .sin (elevation_rad )
437- z , r = pos [2 ], np .sqrt (pos [0 ] ** 2 + pos [1 ] ** 2 )
438- pos [2 ] = z * cos_el + r * sin_el
463+ cos_el , sin_el = np .cos (elevation_rad ), np .sin (elevation_rad )
464+ z , r = pos [2 ], np .sqrt (pos [0 ] ** 2 + pos [1 ] ** 2 )
465+ pos [2 ] = z * cos_el + r * sin_el
439466
440- model .cam_pos [cam_id ] = pos
467+ model .cam_pos [cam_id ] = pos
441468
442- # Optional: Manipulate the alpha (transparency) channels of the entire robotic arm
443- is_transparent = (
444- self .variation_space ['rendering' ]['transparent_arm' ].value == 1
445- )
446- alpha_val = 0.3 if is_transparent else 1.0
447- for i in range (model .ngeom ):
448- name = mujoco .mj_id2name (model , mujoco .mjtObj .mjOBJ_GEOM , i )
449- if name and 'robot0:' in name :
450- model .geom_rgba [i ][3 ] = alpha_val
469+ if needs_arm :
470+ is_transparent = (
471+ self .variation_space ['rendering' ]['transparent_arm' ].value == 1
472+ )
473+ alpha_val = 0.3 if is_transparent else 1.0
474+ for i in range (model .ngeom ):
475+ name = mujoco .mj_id2name (model , mujoco .mjtObj .mjOBJ_GEOM , i )
476+ if name and 'robot0:' in name :
477+ model .geom_rgba [i ][3 ] = alpha_val
0 commit comments