Source code for causal_world.evaluation.protocols.fully_random_protocol

from causal_world.evaluation.protocols.protocol import ProtocolBase
import numpy as np


class FullyRandomProtocol(ProtocolBase):
    """
    Evaluation protocol that randomizes the exposed variables at the start
    of every episode.

    At timestep 0 a value is drawn with ``np.random.uniform`` between the
    two bounds stored for each variable of the selected intervention space.
    No intervention is produced at any other timestep.

    The environment is read from ``self.env``, which is not assigned in
    this class (presumably provided by ``ProtocolBase`` - confirm there).
    """

    # Values of ``variable_space`` this protocol knows how to handle.
    _VARIABLE_SPACES = ('space_a_b', 'space_a', 'space_b')

    def __init__(self, name, variable_space='space_a_b'):
        """
        This specifies a fully random protocol, where an intervention is
        produced on every exposed variable by uniformly sampling the
        intervention space.

        :param name: (str) specifies the name of the protocol to be reported.
        :param variable_space: (str) "space_a", "space_b" or "space_a_b".
        """
        super().__init__(name)
        self._variable_space = variable_space

    def get_intervention(self, episode, timestep):
        """
        Returns the interventions that are applied at a given timestep of
        the episode.

        :param episode: (int) episode number of the protocol
        :param timestep: (int) time step within episode

        :return: (dict) intervention dictionary when ``timestep`` is 0,
                 otherwise None.
        :raises ValueError: if the protocol was created with a
                            ``variable_space`` that is not one of
                            "space_a", "space_b" or "space_a_b".
        """
        # Interventions are only produced on the first step of an episode.
        if timestep != 0:
            return None

        intervention_space = self._get_intervention_space()
        interventions_dict = {}
        # Coin flip deciding which of 'size' / 'cylindrical_position' is
        # sampled for variables that expose sub-variables; exactly one of
        # the two is skipped below.
        intervene_on_size = np.random.choice([0, 1], p=[0.5, 0.5])
        # p=[1, 0] makes this draw always 0, so 'joint_positions' is never
        # intervened on. The call is kept so that the global numpy random
        # stream is consumed exactly as before.
        intervene_on_joint_positions = np.random.choice([0, 1], p=[1, 0])

        for variable in intervention_space:
            bounds = intervention_space[variable]
            if isinstance(bounds, dict):
                # The entry is created before filtering, so it stays present
                # (possibly empty) even if every sub-variable is skipped.
                interventions_dict[variable] = {}
                for subvariable_name in bounds:
                    if subvariable_name == 'cylindrical_position' and \
                            intervene_on_size:
                        continue
                    if subvariable_name == 'size' and not intervene_on_size:
                        continue
                    sub_bounds = bounds[subvariable_name]
                    interventions_dict[variable][subvariable_name] = \
                        np.random.uniform(sub_bounds[0], sub_bounds[1])
            else:
                if not intervene_on_joint_positions and \
                        variable == 'joint_positions':
                    continue
                interventions_dict[variable] = np.random.uniform(
                    bounds[0], bounds[1])
        return interventions_dict

    def _get_intervention_space(self):
        """
        Fetches the intervention space matching ``self._variable_space``
        from the environment.

        :return: the object returned by the matching
                 ``self.env.get_intervention_space_*`` getter.
        :raises ValueError: if ``self._variable_space`` is not a supported
                            value. Without this check the caller would fail
                            later with an unbound local variable.
        """
        if self._variable_space == 'space_a_b':
            return self.env.get_intervention_space_a_b()
        if self._variable_space == 'space_a':
            return self.env.get_intervention_space_a()
        if self._variable_space == 'space_b':
            return self.env.get_intervention_space_b()
        raise ValueError(
            "Unknown variable_space {!r}; expected one of {}".format(
                self._variable_space, self._VARIABLE_SPACES))

    def _init_protocol_helper(self):
        """
        Used by the protocols to initialize some variables further after
        the environment is passed..etc.

        Sets the intervention space of the environment to the one this
        protocol was created with. An unsupported ``variable_space`` leaves
        the environment untouched.

        :return:
        """
        if self._variable_space in self._VARIABLE_SPACES:
            self.env.set_intervention_space(
                variables_space=self._variable_space)
        return