Source code for causal_world.wrappers.protocol_wrapper

import gym


class ProtocolWrapper(gym.Wrapper):
    """
    Gym wrapper that applies the interventions of an evaluation protocol.

    After every ``reset`` and every ``step`` the wrapper asks the protocol
    for an intervention for the current episode and timestep and, when the
    protocol returns one, applies it to the wrapped environment.
    """

    # Maximum number of replacement interventions requested from the protocol
    # after the environment reports that an intervention did not succeed.
    _MAX_INVALID_INTERVENTIONS = 5

    def __init__(self, env, protocol):
        """
        :param env: (causal_world.CausalWorld) the environment to convert.
        :param protocol: (causal_world.evaluation.ProtocolBase) protocol to
                         evaluate.
        """
        super(ProtocolWrapper, self).__init__(env)
        self.protocol = protocol
        self.env.add_wrapper_info(
            {'evaluation_environment': self.protocol.get_name()})
        # Episodes are counted from the first reset(); timesteps are counted
        # within the current episode and restart at every reset().
        self._elapsed_episodes = 0
        self._elapsed_timesteps = 0

    def step(self, action):
        """
        Used to step through the environment, then applies the protocol's
        intervention (if any) for the current episode and timestep.

        :param action: (nd.array) specifies which action should be taken by
                                  the robot, should follow the same action
                                  mode specified.

        :return: (tuple) ``(observation, reward, done, info)`` as returned by
                         the wrapped environment's ``step``. The observation
                         is replaced by the one returned from the
                         intervention whenever an intervention was applied.
        """
        observation, reward, done, info = self.env.step(action)
        self._elapsed_timesteps += 1
        # The protocol is queried with the timestep inside the current
        # episode, not with the episode counter.
        observation = self._apply_protocol_intervention(
            timestep=self._elapsed_timesteps, observation=observation)
        return observation, reward, done, info

    def reset(self):
        """
        Resets the environment to the current starting state of the
        environment, then applies the protocol's intervention (if any) for
        timestep 0 of the new episode.

        :return: (nd.array) specifies the observations returned after
                            resetting the environment. Again, it follows the
                            observation_mode specified.
        """
        self._elapsed_episodes += 1
        self._elapsed_timesteps = 0
        observation = self.env.reset()
        return self._apply_protocol_intervention(
            timestep=0, observation=observation)

    def _apply_protocol_intervention(self, timestep, observation):
        """
        Asks the protocol for an intervention and applies it to the
        environment, requesting a replacement intervention a bounded number
        of times while the environment reports failure.

        :param timestep: (int) timestep passed to the protocol together with
                               the current episode counter.
        :param observation: (nd.array) observation to return unchanged when
                                       the protocol has no intervention.

        :return: (nd.array) the observation returned by the last applied
                            intervention, or the given observation if no
                            intervention was applied.
        """
        interventions_dict = self.protocol.get_intervention(
            episode=self._elapsed_episodes, timestep=timestep)
        if interventions_dict is None:
            return observation
        success_signal, observation = self.env.do_intervention(
            interventions_dict=interventions_dict)
        invalid_interventions = 0
        while (not success_signal and
               invalid_interventions < self._MAX_INVALID_INTERVENTIONS):
            invalid_interventions += 1
            interventions_dict = self.protocol.get_intervention(
                episode=self._elapsed_episodes, timestep=timestep)
            if interventions_dict is None:
                # The protocol has nothing further to offer; keep the
                # observation from the last attempted intervention.
                break
            success_signal, observation = self.env.do_intervention(
                interventions_dict=interventions_dict)
        return observation