Source code for causal_world.envs.scene.stage

from causal_world.envs.scene.observations import StageObservations
from causal_world.envs.scene.objects import Cuboid, StaticCuboid, MeshObject
from causal_world.envs.scene.silhouette import SCuboid, SSphere, SMeshObject
from causal_world.utils.state_utils import get_intersection, get_bounding_box_volume
import math
import numpy as np
import pybullet
from causal_world.configs.world_constants import WorldConstants
from collections import OrderedDict


class Stage(object):
    """
    Container for everything that lives in the arena: the rigid (tool)
    objects, the visual (silhouette / goal) objects, the arena level
    variables (floor/stage colour and friction, gravity) and the stage part
    of the observations.
    """

    def __init__(self, observation_mode, normalize_observations,
                 pybullet_client_full_id, pybullet_client_w_goal_id,
                 pybullet_client_w_o_goal_id, cameras, camera_indicies):
        """
        This class represents the stage object, where it handles all the arena
        functionalities including the objects and silhouettes existing
        in the arena.

        :param observation_mode: (str) should be "structured" or "pixel"
        :param normalize_observations: (bool) to normalize the observations
                                              or not.
        :param pybullet_client_full_id: (int) pybullet client if visualization
                                              is enabled.
        :param pybullet_client_w_goal_id: (int) pybullet client with the goal
                                                in the image without
                                                tool objects.
        :param pybullet_client_w_o_goal_id: (int) pybullet client without the
                                                  goal, only tool blocks.
        :param cameras: (list) list of causal_world.robot.Camera object
                               specifying the cameras mounted on top of the
                               trifinger robot.
        :param camera_indicies: (list) list of integers of the order of
                                       cameras to be used.
        """
        self._rigid_objects = OrderedDict()
        self._visual_objects = OrderedDict()
        self._observation_mode = observation_mode
        self._pybullet_client_full_id = pybullet_client_full_id
        self._pybullet_client_w_goal_id = pybullet_client_w_goal_id
        self._pybullet_client_w_o_goal_id = pybullet_client_w_o_goal_id
        self._camera_indicies = camera_indicies
        self._normalize_observations = normalize_observations
        # Created by finalize_stage(); None until then.
        self._stage_observations = None
        # Names of all objects (rigid and visual) in insertion order; this
        # order defines the layout of the flat state list.
        self._name_keys = []
        self._default_gravity = [0, 0, -9.81]
        self._current_gravity = np.array(self._default_gravity)
        # Visual objects are created in the "full" and "with goal" clients,
        # rigid objects in the "full" and "without goal" clients. The order
        # matters: index 0 of the rigid list is the client used for all
        # contact queries.
        self._visual_object_client_instances = []
        self._rigid_objects_client_instances = []
        if self._pybullet_client_full_id is not None:
            self._visual_object_client_instances.append(
                self._pybullet_client_full_id)
            self._rigid_objects_client_instances.append(
                self._pybullet_client_full_id)
        if self._pybullet_client_w_o_goal_id is not None:
            self._rigid_objects_client_instances.append(
                self._pybullet_client_w_o_goal_id)
        if self._pybullet_client_w_goal_id is not None:
            self._visual_object_client_instances.append(
                self._pybullet_client_w_goal_id)
        self._cameras = cameras
        self._goal_image = None
        return

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _find_object(self, key):
        """
        :param key: (str) name of a rigid or visual object.

        :return: the object registered under that name (rigid objects are
                 looked up first) or None if there is no such object.
        """
        if key in self._rigid_objects:
            return self._rigid_objects[key]
        if key in self._visual_objects:
            return self._visual_objects[key]
        return None

    def _require_object(self, key):
        """
        Same as _find_object but raises if the key is unknown.

        :param key: (str) name of a rigid or visual object.

        :return: the object registered under that name.
        """
        scene_object = self._find_object(key)
        if scene_object is None:
            raise Exception(
                "The key {} passed doesn't exist in the stage yet".format(key))
        return scene_object

    def _ensure_name_is_free(self, name):
        """
        Raises if an object with this name is already part of the stage.

        :param name: (str) candidate object name.

        :return:
        """
        if name in self._name_keys:
            raise Exception("name already exists as key for scene objects")
        return

    def _refresh_goal_image_if_pixel(self):
        """
        Re-renders the cached goal image when running in "pixel" mode.

        :return:
        """
        if self._observation_mode == "pixel":
            self.update_goal_image()
        return

    def _all_client_ids(self):
        """
        :return: (list) every pybullet client id used by the stage, visual
                        clients first, each id listed once.
        """
        return list(OrderedDict.fromkeys(
            self._visual_object_client_instances +
            self._rigid_objects_client_instances))

    def _set_body_color(self, body_id, rgb):
        """
        Sets the colour of the base link of a body in every client.

        :param body_id: (int) pybullet body id.
        :param rgb: (list) r, g, b values; alpha is fixed to 1.

        :return:
        """
        for client in self._all_client_ids():
            pybullet.changeVisualShape(body_id,
                                       -1,
                                       rgbaColor=np.append(rgb, 1),
                                       physicsClientId=client)
        return

    def _set_body_friction(self, body_id, friction):
        """
        Sets the lateral friction of the base link of a body in every client
        that holds rigid objects.

        :param body_id: (int) pybullet body id.
        :param friction: (float) new lateral friction value.

        :return:
        """
        for client in self._rigid_objects_client_instances:
            pybullet.changeDynamics(bodyUniqueId=body_id,
                                    linkIndex=-1,
                                    lateralFriction=friction,
                                    physicsClientId=client)
        return

    def _contact_points(self):
        """
        :return: (list) contact points reported by pybullet for the first
                        rigid objects client.
        """
        return pybullet.getContactPoints(
            physicsClientId=self._rigid_objects_client_instances[0])

    def _first_contact_between(self, body_id_a, body_id_b):
        """
        :param body_id_a: (int) pybullet id of the first body.
        :param body_id_b: (int) pybullet id of the second body.

        :return: the first contact point involving both bodies (in either
                 order) or None if they are not touching.
        """
        for contact in self._contact_points():
            # contact[1] / contact[2] are the ids of the two bodies.
            if (contact[1] == body_id_a and contact[2] == body_id_b) or \
                    (contact[2] == body_id_a and contact[1] == body_id_b):
                return contact
        return None

    # ------------------------------------------------------------------
    # simple getters
    # ------------------------------------------------------------------
    def get_floor_height(self):
        """
        :return: (float) returns the floor height.
        """
        return WorldConstants.FLOOR_HEIGHT

    def get_arena_bb(self):
        """
        :return: (list) list of the lower bound (x, y, z) and the upper bound
                        (x, y, z) so (2, 3) shape.
        """
        return WorldConstants.ARENA_BB

    def get_rigid_objects(self):
        """
        :return: (dict) returns the rigid objects in the arena currently.
        """
        return self._rigid_objects

    def get_visual_objects(self):
        """
        :return: (dict) returns the visual objects in the arena currently.
        """
        return self._visual_objects

    # ------------------------------------------------------------------
    # full environment state (re-creation level)
    # ------------------------------------------------------------------
    def get_full_env_state(self):
        """
        :return: (dict) returns a dict specifying everything about the current
                        state of the arena. It has the keys 'rigid_objects'
                        and 'visual_objects' (lists of
                        [shape tag, recreation params]) and
                        'arena_variable_values'.
        """
        rigid_tags = ((Cuboid, 'cube'),
                      (StaticCuboid, 'static_cube'),
                      (MeshObject, 'mesh'))
        visual_tags = ((SCuboid, 'cube'),
                       (SSphere, 'sphere'),
                       (SMeshObject, 'mesh'))
        env_state = {}
        env_state['rigid_objects'] = []
        for rigid_object in self._rigid_objects.values():
            # every matching type contributes an entry (independent checks,
            # not an elif chain).
            for object_type, tag in rigid_tags:
                if isinstance(rigid_object, object_type):
                    env_state['rigid_objects'].append(
                        [tag, rigid_object.get_recreation_params()])
        env_state['visual_objects'] = []
        for visual_object in self._visual_objects.values():
            for object_type, tag in visual_tags:
                if isinstance(visual_object, object_type):
                    env_state['visual_objects'].append(
                        [tag, visual_object.get_recreation_params()])
        env_state['arena_variable_values'] = \
            self.get_current_variable_values_for_arena()
        return env_state

    def set_full_env_state(self, env_state):
        """
        Removes everything from the arena and rebuilds it from env_state.

        :param env_state: (dict) dict specifying everything about the current
                                 state of the arena, usually obtained through
                                 get_full_env_state function.

        :return:
        """
        self.remove_everything()
        for rigid_object_info in env_state['rigid_objects']:
            if rigid_object_info[0] == 'mesh':
                self.add_rigid_mesh_object(**rigid_object_info[1])
            else:
                self.add_rigid_general_object(shape=rigid_object_info[0],
                                              **rigid_object_info[1])
        for visual_object_info in env_state['visual_objects']:
            if visual_object_info[0] == 'mesh':
                self.add_silhoutte_mesh_object(**visual_object_info[1])
            else:
                self.add_silhoutte_general_object(
                    shape=visual_object_info[0], **visual_object_info[1])
        self.apply_interventions(env_state['arena_variable_values'])
        # The observations object only exists after finalize_stage().
        if self._stage_observations is not None:
            self._stage_observations.rigid_objects = self._rigid_objects
            self._stage_observations.visual_objects = self._visual_objects
        return

    # ------------------------------------------------------------------
    # adding / removing objects
    # ------------------------------------------------------------------
    def add_rigid_general_object(self, name, shape, **object_params):
        """
        :param name: (str) a str specifying a unique name of the rigid object.
        :param shape: (str) specifying "cube" or "static_cube" for now.
        :param object_params: (params) depends on the parameters used for
                                       constructing the corresponding object.

        :return:
        """
        self._ensure_name_is_free(name)
        if shape == "cube":
            object_class = Cuboid
        elif shape == "static_cube":
            object_class = StaticCuboid
        else:
            raise Exception("shape is not yet implemented")
        self._rigid_objects[name] = object_class(
            self._rigid_objects_client_instances, name, **object_params)
        # Register the name only once the object really exists so a failed
        # add never leaves a dangling name behind.
        self._name_keys.append(name)
        return

    def remove_general_object(self, name):
        """
        :param name: (str) a str specifying a unique name of the object to
                           remove from the arena.

        :return:
        """
        if name not in self._name_keys:
            raise Exception("name does not exists as key for scene objects")
        if name in self._rigid_objects:
            self._rigid_objects[name].remove()
            del self._rigid_objects[name]
        elif name in self._visual_objects:
            self._visual_objects[name].remove()
            del self._visual_objects[name]
        self._name_keys.remove(name)
        return

    def remove_everything(self):
        """
        removes all the objects and visuals from the arena, visual objects
        first and each group in reverse insertion order.

        :return:
        """
        current_objects = list(self._rigid_objects.keys()) + \
            list(self._visual_objects.keys())
        for name in reversed(current_objects):
            self.remove_general_object(name)
        return

    def add_rigid_mesh_object(self, name, filename, **object_params):
        """
        :param name: (str) a str specifying a unique name of the mesh object.
        :param filename: (str) a str specifying the location of the .obj file.
        :param object_params: (params) depends on the parameters used for
                                       constructing the corresponding object.

        :return:
        """
        self._ensure_name_is_free(name)
        self._rigid_objects[name] = MeshObject(
            self._rigid_objects_client_instances, name, filename,
            **object_params)
        self._name_keys.append(name)
        return

    def add_silhoutte_general_object(self, name, shape, **object_params):
        """
        :param name: (str) specifying a unique name of the visual object.
        :param shape: (str) specifying "cube" or "sphere" for now.
        :param object_params: (params) depends on the parameters used for
                                       constructing the corresponding object.

        :return:
        """
        self._ensure_name_is_free(name)
        if shape == "cube":
            object_class = SCuboid
        elif shape == "sphere":
            object_class = SSphere
        else:
            raise Exception("shape is not implemented yet")
        self._visual_objects[name] = object_class(
            self._visual_object_client_instances, name, **object_params)
        self._name_keys.append(name)
        return

    def add_silhoutte_mesh_object(self, name, filename, **object_params):
        """
        :param name: (str) specifying a unique name of the mesh visual object.
        :param filename: (str) a str specifying the location of the .obj file.
        :param object_params: (params) depends on the parameters used for
                                       constructing the corresponding object.

        :return:
        """
        self._ensure_name_is_free(name)
        self._visual_objects[name] = SMeshObject(
            self._visual_object_client_instances, name, filename,
            **object_params)
        self._name_keys.append(name)
        return

    # ------------------------------------------------------------------
    # observations
    # ------------------------------------------------------------------
    def finalize_stage(self):
        """
        finalizes the observation space of the environment after adding all
        the objects and visuals in the stage.

        :return:
        """
        if self._observation_mode == "pixel":
            self._stage_observations = StageObservations(
                self._rigid_objects,
                self._visual_objects,
                self._observation_mode,
                self._normalize_observations,
                cameras=self._cameras,
                camera_indicies=self._camera_indicies)
            self.update_goal_image()
        else:
            self._stage_observations = StageObservations(
                self._rigid_objects, self._visual_objects,
                self._observation_mode, self._normalize_observations)
        return

    def select_observations(self, observation_keys):
        """
        selects the observations to be returned by the environment.

        :param observation_keys: (list) list of str that specifies the
                                        observation keys to be returned and
                                        calculated by the environment.

        :return:
        """
        self._stage_observations.reset_observation_keys()
        self._stage_observations.initialize_observations()
        for key in observation_keys:
            self._stage_observations.add_observation(key)
        self._stage_observations.set_observation_spaces()

    # ------------------------------------------------------------------
    # object level state
    # ------------------------------------------------------------------
    def get_full_state(self, state_type='list'):
        """
        :param state_type: (str) 'list' or 'dict' specifying to return the
                                 state as a dict with the state name as a key
                                 or just a concatenated list.

        :return: (list or dict) depending on the arg state_type, returns
                                the full state of the stage itself.
        """
        if state_type == 'list':
            stage_state = []
        elif state_type == 'dict':
            stage_state = dict()
        else:
            raise Exception("type is not supported")
        for name in self._name_keys:
            scene_object = self._find_object(name)
            if scene_object is None:
                raise Exception("possible error here")
            if state_type == 'list':
                stage_state.extend(scene_object.get_state(state_type='list'))
            else:
                stage_state[name] = scene_object.get_state(state_type='dict')
        return stage_state

    def set_full_state(self, new_state):
        """
        :param new_state: (list or nd.array) flat state of all the objects in
                                             the arena, laid out in the same
                                             order as
                                             get_full_state(state_type='list').

        :return:
        """
        #TODO: under the assumption that the new state has the same number of objects
        start = 0
        for name in self._name_keys:
            scene_object = self._find_object(name)
            if scene_object is None:
                # nothing to consume for a name without an object.
                continue
            end = start + scene_object.get_state_size()
            scene_object.set_full_state(new_state[start:end])
            start = end
        self._refresh_goal_image_if_pixel()
        return

    def set_objects_pose(self, names, positions, orientations):
        """
        :param names: (list) list of object names to set their
                             positions and orientations.
        :param positions: (list) corresponding list of positions of
                                 objects to be set.
        :param orientations: (list) corresponding list of orientations of
                                    objects to be set.

        :return:
        """
        for index, name in enumerate(names):
            scene_object = self._find_object(name)
            if scene_object is None:
                raise Exception("Object {} doesnt exist".format(name))
            scene_object.set_pose(positions[index], orientations[index])
        self._refresh_goal_image_if_pixel()
        return

    def get_current_observations(self, helper_keys):
        """
        :param helper_keys: (list) list of observation keys that are not part
                                   of the observation space but still needed
                                   to be returned to calculate further
                                   observations or for a dense reward function
                                   calculation.

        :return: (dict) returns the current observations where the keys
                        corresponds to the observation key.
        """
        return self._stage_observations.get_current_observations(helper_keys)

    def get_observation_spaces(self):
        """
        :return: (gym.spaces.Box) returns the current observation space of the
                                  environment.
        """
        return self._stage_observations.get_observation_spaces()

    def random_position(self,
                        height_limits=(0.05, 0.15),
                        angle_limits=(-2 * math.pi, 2 * math.pi),
                        radius_limits=(0.0, 0.15),
                        allowed_section=np.array([[-0.5, -0.5, 0],
                                                  [0.5, 0.5, 0.5]])):
        """
        Samples a position uniformly (with respect to the disc area) in polar
        coordinates and rejects samples until one lies strictly inside
        allowed_section. Note that the loop does not terminate if the limits
        can never satisfy allowed_section.

        :param height_limits: (tuple) tuple of two values for low bound
                                      and upper bound, or a single number to
                                      use as a fixed height.
        :param angle_limits: (tuple) tuple of two values for low bound
                                     and upper bound, theta in polar
                                     coordinates.
        :param radius_limits: (tuple) tuple of two values for low bound
                                      and upper bound, radius in polar
                                      coordinates.
        :param allowed_section: (nd.array) array of two sublists for low bound
                                           and upper bound (x, y, z) of
                                           restricted area for sampling,
                                           so the shape of the input is (2, 3).

        :return: (list) returns a cartesian random position in the arena
                        (x, y, z).
        """
        # Plain nested lists would otherwise be compared lexicographically.
        section = np.asarray(allowed_section)
        lower_bound, upper_bound = section[0], section[1]
        fixed_height = isinstance(height_limits,
                                  (int, float, np.integer, np.floating))
        while True:
            angle = np.random.uniform(*angle_limits)
            # for uniform sampling with respect to the disc area use scaling
            radial_distance = np.sqrt(
                np.random.uniform(radius_limits[0]**2, radius_limits[1]**2))
            if fixed_height:
                height_z = height_limits
            else:
                height_z = np.random.uniform(*height_limits)
            object_position = [
                radial_distance * math.cos(angle),
                radial_distance * math.sin(angle),
                height_z,
            ]
            candidate = np.asarray(object_position)
            if np.all(candidate > lower_bound) and \
                    np.all(candidate < upper_bound):
                return object_position

    def get_current_object_keys(self):
        """
        :return: (list) returns the names of the rigid objects and visual
                        objects concatenated.
        """
        return list(self._rigid_objects.keys()) + \
            list(self._visual_objects.keys())

    # ------------------------------------------------------------------
    # interventions and exposed variables
    # ------------------------------------------------------------------
    def object_intervention(self, key, interventions_dict):
        """
        :param key: (str) the unique name of the rigid or visual object to
                          intervene on.
        :param interventions_dict: (dict) dict specifying the intervention to
                                          be performed.

        :return:
        """
        scene_object = self._require_object(key)
        scene_object.apply_interventions(interventions_dict)
        self._refresh_goal_image_if_pixel()
        return

    def get_current_variable_values_for_arena(self):
        """
        :return: (dict) returns all the exposed variables and their values
                        in the environment's stage except for objects.
        """
        # Query the client without the goal when available, otherwise the
        # full (visualization) client.
        if self._pybullet_client_w_o_goal_id is not None:
            client = self._pybullet_client_w_o_goal_id
        else:
            client = self._pybullet_client_full_id
        variable_params = dict()
        # getVisualShapeData()[0][7] is the rgba colour of the first shape,
        # getDynamicsInfo()[1] is the lateral friction.
        variable_params["floor_color"] = \
            pybullet.getVisualShapeData(WorldConstants.FLOOR_ID,
                                        physicsClientId=client)[0][7][:3]
        variable_params["floor_friction"] = \
            pybullet.getDynamicsInfo(WorldConstants.FLOOR_ID,
                                     -1,
                                     physicsClientId=client)[1]
        variable_params["stage_color"] = \
            pybullet.getVisualShapeData(WorldConstants.STAGE_ID,
                                        physicsClientId=client)[0][7][:3]
        variable_params["stage_friction"] = \
            pybullet.getDynamicsInfo(WorldConstants.STAGE_ID,
                                     -1,
                                     physicsClientId=client)[1]
        variable_params["gravity"] = \
            self._current_gravity
        return variable_params

    def get_current_variable_values_for_objects(self):
        """
        :return: (dict) returns all the exposed variables and their values
                        in the environment's stage for objects only.
        """
        return self.get_full_state(state_type='dict')

    def get_current_variable_values(self):
        """
        :return: (dict) returns all the exposed variables and their values
                        in the environment's stage.
        """
        variable_params = self.get_current_variable_values_for_arena()
        variable_params.update(self.get_current_variable_values_for_objects())
        return variable_params

    def apply_interventions(self, interventions_dict):
        """
        :param interventions_dict: (dict) dict specifying the intervention to
                                          be performed. A dict value is
                                          forwarded to the object named by the
                                          key; otherwise the key must be one of
                                          "floor_color", "stage_color",
                                          "stage_friction", "floor_friction"
                                          or "gravity".

        :return:
        """
        for intervention, value in interventions_dict.items():
            if isinstance(value, dict):
                self.object_intervention(intervention, value)
            elif intervention == "floor_color":
                self._set_body_color(WorldConstants.FLOOR_ID, value)
            elif intervention == "stage_color":
                self._set_body_color(WorldConstants.STAGE_ID, value)
            elif intervention == "stage_friction":
                self._set_body_friction(WorldConstants.STAGE_ID, value)
            elif intervention == "floor_friction":
                self._set_body_friction(WorldConstants.FLOOR_ID, value)
            elif intervention == "gravity":
                for client in self._rigid_objects_client_instances:
                    pybullet.setGravity(value[0],
                                        value[1],
                                        value[2],
                                        physicsClientId=client)
                # keep the same type as the initial value and do not alias
                # the caller's data.
                self._current_gravity = np.array(value)
            else:
                raise Exception("The intervention on stage "
                                "is not supported yet")
        self._refresh_goal_image_if_pixel()
        return

    # ------------------------------------------------------------------
    # single object queries
    # ------------------------------------------------------------------
    def get_object_full_state(self, key):
        """
        :param key: (str) specifying the name of the object to return its
                          state.

        :return: (dict) specifies the state of the object queried.
        """
        return self._require_object(key).get_state('dict')

    def get_object_state(self, key, state_variable):
        """
        :param key: (str) specifying the name of the object to return its
                          state's variable value.
        :param state_variable: (str) specifying the variable's name of the
                                     object.

        :return: (nd.array) returns the variable's value of the object queried.
        """
        return np.array(
            self._require_object(key).get_variable_state(state_variable))

    def get_object(self, key):
        """
        :param key: (str) specifying the name of the object to return it.

        :return: (causal_world.RigidObject or causal_world.SilhouetteObject)
                  object to return.
        """
        return self._require_object(key)

    # ------------------------------------------------------------------
    # contact queries (all use the first rigid objects client)
    # ------------------------------------------------------------------
    def are_blocks_colliding(self, block1, block2):
        """
        :param block1: (causal_world.RigidObject) first block.
        :param block2: (causal_world.RigidObject) second block.

        :return: (bool) true if the two blocks passed are colliding.
        """
        return self._first_contact_between(block1._block_ids[0],
                                           block2._block_ids[0]) is not None

    def check_stage_free_of_colliding_blocks(self):
        """
        :return: (bool) true if the stage is free of collisions of blocks,
                        i.e. no contact exists where both body ids are
                        greater than 3.
        """
        # NOTE(review): presumably body ids 0-3 are reserved for the robot,
        # floor and stage - confirm against WorldConstants.
        for contact in self._contact_points():
            if contact[1] > 3 and contact[2] > 3:
                return False
        return True

    def is_colliding_with_stage(self, block1):
        """
        :param block1: (causal_world.RigidObject) first block.

        :return: (bool) true if the block is colliding with the stage.
        """
        return self._first_contact_between(
            block1._block_ids[0], WorldConstants.STAGE_ID) is not None

    def is_colliding_with_floor(self, block1):
        """
        :param block1: (causal_world.RigidObject) first block.

        :return: (bool) true if the block is colliding with the floor.
        """
        return self._first_contact_between(
            block1._block_ids[0], WorldConstants.FLOOR_ID) is not None

    def get_normal_interaction_force_between_blocks(self, block1, block2):
        """
        :param block1: (causal_world.RigidObject) first block.
        :param block2: (causal_world.RigidObject) second block.

        :return: (nd.array) normal interaction force between blocks (the
                            normal force of the first contact found scaled
                            along its contact normal) or None if no
                            interaction.
        """
        contact = self._first_contact_between(block1._block_ids[0],
                                              block2._block_ids[0])
        if contact is None:
            return None
        # contact[9] is the normal force magnitude, contact[7] the normal.
        return contact[9] * np.array(contact[7])

    # ------------------------------------------------------------------
    # thin wrappers around the stage observations
    # ------------------------------------------------------------------
    def add_observation(self,
                        observation_key,
                        lower_bound=None,
                        upper_bound=None):
        """
        :param observation_key: (str) new observation key to be added.
        :param lower_bound: (nd.array) low bound of the observation.
        :param upper_bound: (nd.array) upper bound of the observation.

        :return:
        """
        self._stage_observations.add_observation(observation_key, lower_bound,
                                                 upper_bound)
        return

    def normalize_observation_for_key(self, observation, key):
        """
        :param observation: (nd.array) observation to normalize.
        :param key: (str) observation key to be normalized.

        :return: (nd.array) normalized observation.
        """
        return self._stage_observations.normalize_observation_for_key(
            observation, key)

    def denormalize_observation_for_key(self, observation, key):
        """
        :param observation: (nd.array) observation to denormalize.
        :param key: (str) observation key to be denormalized.

        :return: (nd.array) denormalized observation.
        """
        return self._stage_observations.denormalize_observation_for_key(
            observation, key)

    def get_current_goal_image(self):
        """
        :return: (nd.array) returns the goal images concatenated if 'pixel'
                            mode is enabled.
        """
        return self._goal_image

    def update_goal_image(self):
        """
        updates the cached goal image from the stage observations.

        :return:
        """
        self._goal_image = self._stage_observations.get_current_goal_image()
        return

    # ------------------------------------------------------------------
    # feasibility
    # ------------------------------------------------------------------
    def check_feasiblity_of_stage(self):
        """
        This function checks the feasibility of the current state of the
        stage (i.e checks if any of the bodies in the simulation are in a
        penetration mode)

        :return: (bool) False if any contact has a contact distance below
                        -0.03, if less than half of the volume of any visual
                        object intersects the stage bounding box, or if the
                        lower z bound of any visual object is below -0.01;
                        True otherwise.
        """
        for contact in self._contact_points():
            # contact[8] is the contact distance, negative for penetration.
            if contact[8] < -0.03:
                return False
        #check if all the visual objects are within the bb of the available arena
        stage_bb = self.get_stage_bb()
        for visual_object in self._visual_objects.values():
            bounding_box = visual_object.get_bounding_box()
            inside_ratio = get_intersection(bounding_box, stage_bb) / \
                visual_object.get_volume()
            if inside_ratio < 0.50:
                return False
            if visual_object.get_bounding_box()[0][-1] < -0.01:
                return False
        return True

    def get_stage_bb(self):
        """
        :return: (tuple) first element indicates the lower bound the stage
                         arena (x,y, z) and second element indicated the
                         upper bound similarly.
        """
        return (tuple(WorldConstants.ARENA_BB[0]),
                tuple(WorldConstants.ARENA_BB[1]))