Getting a Kedro custom dataset for SunPy Maps to read from and write to S3

Question:

I’m currently attempting to define a custom dataset to read/write .fits files to/from S3 as SunPy Maps.

The closest thing to this already in the data catalog is `pillow.ImageDataSet`, which supports passing a file object when loading:
https://pillow.readthedocs.io/en/stable/reference/Image.html.

I’m unsure whether Maps are flexible enough about their inputs to justify a similar approach. My attempt so far, modifying the `pillow.ImageDataSet` `_load` method to include

smap = Map(fs_file)
return smap

results in the following error:

DataSetError: Failed while loading data from data set SunPyMapDataSet(filepath=sunspots/data/01_raw/map_sample.fits, protocol=s3, save_args={'overwrite': True}).
Invalid input: <File-like object S3FileSystem, sunspots/data/01_raw/map_sample.fits>

How might I get things working here?

Asked By: Jordan Barlow

||

Answers:

I am unfamiliar with the SunPy library, but I think your approach is correct so far.

`fs_file` is a file handle, and you need the correct way to read from it. You are probably getting this error because `Map(fs_file)` is not the right way to load a file: as the error message shows, `Map` rejects the file-like object.

You should probably look for functions that load a Map object from a file.

Answered By: mediumnok

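Months ago I wrote a custom Kedro dataset for SunPy that uses Astropy as an intermediary, but I forgot to come back and answer this question. `astropy.io.fits.open` accepts the file-like object that fsspec returns, so the dataset below reads the FITS file with Astropy. It then builds the Map from the image HDU's `(data, header)` pair. It may be worth opening a PR to add this to the new kedro-datasets package for SunPy users.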

import warnings
from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any, Dict
import fsspec
from kedro.io.core import (
    AbstractVersionedDataSet,
    DataSetError,
    Version,
    get_filepath_str,
    get_protocol_and_path,
)
import numpy as np
from astropy.io import fits
from sunpy.map import Map


class SunPyMapDataSet(AbstractVersionedDataSet):
    """Kedro dataset that loads and saves SunPy ``Map`` objects as FITS files.

    Files are accessed through ``fsspec``, so the ``filepath`` may carry any
    protocol that fsspec supports (e.g. a local path or an ``s3://`` URL).
    ``Map`` itself rejects fsspec file objects. Loading therefore reads the
    file with ``astropy.io.fits``, which accepts file-like objects, and builds
    the ``Map`` from the image HDU's ``(data, header)`` pair. Saving writes
    the map's data and ``fits_header`` as a FITS ``ImageHDU``.

    Args:
        filepath: Path or URL of the FITS file, optionally with a protocol
            prefix.
        save_args: Extra keyword arguments for ``ImageHDU.writeto``. They are
            merged over ``DEFAULT_SAVE_ARGS``.
        version: Kedro version specification. ``None`` disables versioning.
        credentials: Passed as keyword arguments to ``fsspec.filesystem``.
        fs_args: Extra keyword arguments for ``fsspec.filesystem``. The
            optional keys ``open_args_load`` and ``open_args_save`` are
            removed from this dict and passed to ``fs.open`` when loading and
            saving, respectively.
    """

    DEFAULT_SAVE_ARGS = {"overwrite": False}
    # Index of the HDU that holds the image when loading. _save writes a bare
    # ImageHDU. Per the astropy docs, ImageHDU.writeto prepends a default
    # primary HDU, so the saved image ends up at index 1 and round-trips.
    # Subclasses may override this for files that keep the image elsewhere.
    IMAGE_HDU_INDEX = 1

    def __init__(
        self,
        filepath: str,
        save_args: Dict[str, Any] = None,
        version: Version = None,
        credentials: Dict[str, Any] = None,
        fs_args: Dict[str, Any] = None,
    ) -> None:
        """Create the fsspec filesystem and record load/save options."""
        # Deep-copy caller-supplied dicts so the pops below never mutate them.
        _fs_args = deepcopy(fs_args) or {}
        _fs_open_args_load = _fs_args.pop("open_args_load", {})
        _fs_open_args_save = _fs_args.pop("open_args_save", {})
        _credentials = deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath, version)
        if protocol == "file":
            # Let the local filesystem create missing parent directories on save.
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

        # FITS is a binary format, so saves default to binary write mode.
        _fs_open_args_save.setdefault("mode", "wb")
        self._fs_open_args_load = _fs_open_args_load
        self._fs_open_args_save = _fs_open_args_save

    def _describe(self) -> Dict[str, Any]:
        """Return the attributes shown in this dataset's repr/log output."""
        return dict(
            filepath=self._filepath,
            protocol=self._protocol,
            save_args=self._save_args,
            version=self._version,
        )

    def _load(self) -> Map:
        """Read the FITS file and return its image HDU as a SunPy ``Map``."""
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
            # Use the HDU list as a context manager so it is always closed.
            # Copy only the HDU we need: copying the whole HDUList would read
            # and duplicate the data of every HDU in the file.
            with fits.open(fs_file) as hdu_list:
                image_hdu = hdu_list[self.IMAGE_HDU_INDEX]
                image_hdu.verify("fix")
                # Take independent copies while the file is still open, so the
                # Map does not depend on the (soon closed) file handle.
                data = np.array(image_hdu.data)
                header = image_hdu.header.copy()
        return Map((data, header))

    def _save(self, data: Map) -> None:
        """Write ``data`` to the save path as a FITS ``ImageHDU``."""
        save_path = get_filepath_str(self._get_save_path(), self._protocol)
        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
            hdu = fits.ImageHDU()
            hdu.header = data.fits_header
            hdu.data = data.data
            hdu.writeto(fs_file, **self._save_args)
        # The file listing changed; drop stale fsspec cache entries.
        self._invalidate_cache()

    def _exists(self) -> bool:
        """Return whether the (versioned) load path exists on the filesystem."""
        try:
            load_path = get_filepath_str(self._get_load_path(), self._protocol)
        except DataSetError:
            # No resolvable load version, so treat the dataset as absent.
            return False
        return self._fs.exists(load_path)

    def _release(self) -> None:
        """Release base-class state and clear fsspec caches for this path."""
        super()._release()
        self._invalidate_cache()

    def _invalidate_cache(self) -> None:
        """Invalidate underlying filesystem caches."""
        filepath = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(filepath)

Answered By: Jordan Barlow