Source code for causal_world.loggers.data_loader

import pickle
import os
import json


[docs]class DataLoader:
[docs] def __init__(self, episode_directory): """ This initializes a data loader that loads recorded episodes using a causal_world.loggers.DataRecorder object. :param episode_directory: (str) directory where it holds all the logged episodes. """ if os.path.isdir(episode_directory): self.episode_directory = episode_directory info_path = os.path.join(self.episode_directory, "info.json") with open(info_path, "r") as json_file: info_dict = json.load(json_file) self.max_episode_index = info_dict["max_episode_index"] self.dumb_frequency = info_dict["dumb_frequency"] else: raise ValueError("data_path does not exist")
[docs] def get_episodes(self, indices): """ :param indices: (list) list of indicies of the episodes. :return: (list) list of causal_world.loggers.Episode objects corresponding to the specified episode indicies. """ episodes = [] for index in indices: episodes.append(self.get_episode(index)) return episodes
[docs] def get_episode(self, index): """ :param index: (int) index of the episode requested. :return: (causal_world.loggers.Episode) corresponding episode requested. """ if index > self.max_episode_index: raise Exception("Episode doesnt exist") infile_index_episode = index % self.dumb_frequency floor_index_episode = index - infile_index_episode ceil_index_episode = floor_index_episode + self.dumb_frequency - 1 if ceil_index_episode > self.max_episode_index: ceil_index_episode = self.max_episode_index episode_file = "episode_{}_{}".format(floor_index_episode, ceil_index_episode) episodes_path = os.path.join(self.episode_directory, episode_file) if os.path.isfile(episodes_path): with open(episodes_path, "rb") as file: episodes = pickle.load(file) return episodes[infile_index_episode] else: raise Exception( "Error: Log file with requested episode does not exist")