Source code for causal_world.evaluation.evaluation

import os
from causal_world.task_generators.task import generate_task
from causal_world.envs.causalworld import CausalWorld

from causal_world.metrics.mean_last_fractional_success import \
    MeanLastFractionalSuccess
from causal_world.metrics.mean_full_integrated_fractional_success import \
    MeanFullIntegratedFractionalSuccess
from causal_world.metrics.mean_last_integrated_fractional_success import \
    MeanLastIntegratedFractionalSuccess

from causal_world.loggers.data_recorder import DataRecorder
from causal_world.wrappers.protocol_wrapper import ProtocolWrapper
from causal_world.loggers.tracker import Tracker
import json
import logging


class EvaluationPipeline(object):
    """
    This class provides functionalities to evaluate a trained policy on a
    set of protocols.

    :param evaluation_protocols: (list) defines the protocols that will be
                                        evaluated in this pipeline.
    :param tracker_path: (str) if a tracker was stored during training, the
                               path of its directory can be passed here and
                               the task and world parameters are restored
                               from it.
    :param world_params: (dict) the world_params to set up the environment,
                                including skip_frame, normalization
                                params, etc.
    :param task_params: (dict) the task_params of the Task on which the
                               policy is going to be evaluated.
    :param visualize_evaluation: (bool) if the evaluation is visualized in
                                        the GUI.
    :param initial_seed: (int) the random seed of the evaluation for
                               reproducibility.
    """
    def __init__(self,
                 evaluation_protocols,
                 tracker_path=None,
                 world_params=None,
                 task_params=None,
                 visualize_evaluation=False,
                 initial_seed=0):
        self.initial_seed = initial_seed
        self.data_recorder = DataRecorder(output_directory=None)
        if tracker_path is not None:
            self.tracker = Tracker(
                file_path=os.path.join(tracker_path, 'tracker'))
            task_stats = self.tracker.task_stats_log[0]
            del task_stats.task_params['variables_space']
            del task_stats.task_params['task_name']
            self.task = generate_task(task_generator_id=task_stats.task_name,
                                      **task_stats.task_params,
                                      variables_space='space_a_b')
        else:
            if 'variables_space' in task_params:
                del task_params['task_name']
                del task_params['variables_space']
            self.task = generate_task(**task_params,
                                      variables_space='space_a_b')
        if tracker_path:
            if 'seed' in self.tracker.world_params:
                del self.tracker.world_params['seed']
            if 'wrappers' in self.tracker.world_params:
                del self.tracker.world_params['wrappers']
            self.env = CausalWorld(self.task,
                                   **self.tracker.world_params,
                                   seed=self.initial_seed,
                                   data_recorder=self.data_recorder,
                                   enable_visualization=visualize_evaluation)
        else:
            if world_params is not None:
                if 'seed' in world_params:
                    del world_params['seed']
                self.env = CausalWorld(
                    self.task,
                    **world_params,
                    seed=self.initial_seed,
                    data_recorder=self.data_recorder,
                    enable_visualization=visualize_evaluation)
            else:
                self.env = CausalWorld(
                    self.task,
                    seed=self.initial_seed,
                    data_recorder=self.data_recorder,
                    enable_visualization=visualize_evaluation)
        evaluation_episode_length_in_secs = \
            self.task.get_default_max_episode_length()
        self.time_steps_for_evaluation = \
            int(evaluation_episode_length_in_secs / self.env.dt)
        self.evaluation_env = self.env
        self.evaluation_protocols = evaluation_protocols
        self.metrics_list = []
        self.metrics_list.append(MeanFullIntegratedFractionalSuccess())
        self.metrics_list.append(MeanLastIntegratedFractionalSuccess())
        self.metrics_list.append(MeanLastFractionalSuccess())
        return
    def run_episode(self, policy_fn):
        """
        Returns the episode information that is accumulated when running
        a policy.

        :param policy_fn: (func) the policy_fn that takes an observation as
                                 argument and returns the inferred action.

        :return: (causal_world.loggers.Episode) returns the recorded episode.
        """
        obs = self.evaluation_env.reset()
        done = False
        while not done:
            desired_action = policy_fn(obs)
            obs, rew, done, info = self.evaluation_env.step(desired_action)
        return self.data_recorder.get_current_episode()
    def process_metrics(self, episode):
        """
        Processes an episode to compute all the metrics of the evaluation
        pipeline.

        :param episode: (causal_world.loggers.Episode) the episode to be
                                                       processed.

        :return: (None)
        """
        for metric in self.metrics_list:
            metric.process_episode(episode)
        return
    def get_metric_scores(self):
        """
        Returns the metric scores of all metrics in the evaluation pipeline.

        :return: (dict) a score dictionary containing the score for each
                        metric name as key.
        """
        metrics = dict()
        for metric in self.metrics_list:
            mean, std = metric.get_metric_score()
            metrics['mean_' + metric.name] = mean
            metrics['std_' + metric.name] = std
        return metrics
    def reset_metric_scores(self):
        """
        Resets the metric scores of each metric object.

        :return: (None)
        """
        for metric in self.metrics_list:
            metric.reset()
    def evaluate_policy(self, policy, fraction=1):
        """
        Runs the evaluation of a policy and returns an evaluation dictionary
        with all the scores for each metric for each protocol.

        :param policy: (func) the policy_fn that takes an observation as
                              argument and returns the inferred action.
        :param fraction: (float) fraction of episodes to be evaluated with
                                 respect to the default number (can be
                                 higher than one).

        :return: (dict) scores dict for each metric for each protocol.
        """
        pipeline_scores = dict()
        for evaluation_protocol in self.evaluation_protocols:
            logging.info('Applying the following protocol now, ' +
                         str(evaluation_protocol.get_name()))
            self.evaluation_env = ProtocolWrapper(self.env,
                                                  evaluation_protocol)
            evaluation_protocol.init_protocol(env=self.env,
                                              tracker=self.env.get_tracker(),
                                              fraction=fraction)
            episodes_in_protocol = evaluation_protocol.get_num_episodes()
            for _ in range(episodes_in_protocol):
                current_episode = self.run_episode(policy)
                self.process_metrics(current_episode)
                self.data_recorder.clear_recorder()
            scores = self.get_metric_scores()
            scores['total_intervention_steps'] = \
                self.env.get_tracker().get_total_intervention_steps()
            scores['total_interventions'] = \
                self.env.get_tracker().get_total_interventions()
            scores['total_timesteps'] = \
                self.env.get_tracker().get_total_time_steps()
            scores['total_resets'] = \
                self.env.get_tracker().get_total_resets()
            pipeline_scores[evaluation_protocol.get_name()] = scores
            self.reset_metric_scores()
        self.evaluation_env.close()
        self.pipeline_scores = pipeline_scores
        return pipeline_scores
    def save_scores(self, evaluation_path, prefix=None):
        """
        Saves the scores dict as json.

        :param evaluation_path: (str) the path where the scores are saved.
        :param prefix: (str) an optional prefix to the file name.

        :return: (None)
        """
        if not os.path.isdir(evaluation_path):
            os.makedirs(evaluation_path)
        if prefix is None:
            file_path = os.path.join(evaluation_path, 'scores.json')
        else:
            file_path = os.path.join(evaluation_path,
                                     '{}_scores.json'.format(prefix))
        with open(file_path, "w") as json_file:
            json.dump(self.pipeline_scores, json_file, indent=4)
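A minimal usage sketch of the pipeline above. It is not part of the module: the protocol class (protocols.FullyRandomProtocol), the 'pushing' task id, the skip_frame value, the output path and the random placeholder policy are all illustrative assumptions; substitute the protocols, task parameters and trained policy you actually want to evaluate.

    from causal_world.evaluation.evaluation import EvaluationPipeline
    from causal_world.evaluation import protocols

    # Assumed protocol and parameters for illustration only.
    evaluator = EvaluationPipeline(
        evaluation_protocols=[protocols.FullyRandomProtocol(name='P0')],
        task_params={'task_generator_id': 'pushing'},
        world_params={'skip_frame': 3},
        visualize_evaluation=False,
        initial_seed=0)

    def random_policy(obs):
        # Placeholder policy_fn: maps an observation to a (random) action.
        # Replace with your trained policy's inference call.
        return evaluator.env.action_space.sample()

    # Run every protocol, collect the per-metric scores and save them as JSON.
    scores = evaluator.evaluate_policy(random_policy, fraction=0.1)
    evaluator.save_scores('./evaluation_output', prefix='random_baseline')
    print(scores)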