# Source code for afnio.serialization

import os
import pickle
import zipfile
from typing import IO, Any, BinaryIO, Type, Union

from typing_extensions import TypeAlias, TypeGuard

DEFAULT_PROTOCOL = 2

FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    return isinstance(name_or_buffer, (str, os.PathLike))


class _opener:
    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_zipfile_writer_file(_opener):
    """Context manager that creates a deflate-compressed zip archive at a path.

    Entering yields the writable ``zipfile.ZipFile``. Leaving the ``with``
    block closes the archive and, if this object opened the underlying file
    itself, closes that file as well.

    Args:
        name: Destination path (``str`` or path-like); it is converted with
            ``str()`` before use.
    """

    def __init__(self, name) -> None:
        # Set only when we open the file ourselves (non-ASCII path branch).
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode("ascii")
        except UnicodeEncodeError:
            # Original rationale: "ZipFile only supports ASCII filenames", so
            # non-ASCII paths are opened with the built-in open() and the
            # stream is handed to ZipFile instead of the path.
            # NOTE(review): confirm this workaround is still required.
            self.file_stream = open(self.name, mode="wb")
            target = self.file_stream
        else:
            target = self.name
        try:
            archive = zipfile.ZipFile(
                target, mode="w", compression=zipfile.ZIP_DEFLATED
            )
        except BaseException:
            # Do not leak the handle we opened if the archive cannot be created.
            if self.file_stream is not None:
                self.file_stream.close()
            raise
        super().__init__(archive)

    def __exit__(self, *args) -> None:
        try:
            self.file_like.close()
        finally:
            # Close our own stream even if finalising the archive raised.
            if self.file_stream is not None:
                self.file_stream.close()


class _open_zipfile_writer_buffer(_opener):
    """Context manager that writes a deflate-compressed zip archive into a buffer.

    The buffer must expose a callable ``write``. On exit the archive is closed
    and the buffer is flushed; this class never calls ``buffer.close()``.

    Raises:
        AttributeError: If ``buffer`` has no ``write`` attribute.
        TypeError: If ``buffer.write`` exists but is not callable.
    """

    def __init__(self, buffer) -> None:
        write_attr = getattr(buffer, "write", None)
        if not callable(write_attr):
            msg = (
                f"Buffer of {str(type(buffer)).strip('<>')} "
                f"has no callable attribute 'write'"
            )
            # Missing attribute -> AttributeError; present but unusable -> TypeError.
            error_cls = TypeError if hasattr(buffer, "write") else AttributeError
            raise error_cls(msg)
        self.buffer = buffer
        archive = zipfile.ZipFile(
            self.buffer, mode="w", compression=zipfile.ZIP_DEFLATED
        )
        super().__init__(archive)

    def __exit__(self, *args) -> None:
        archive = self.file_like
        archive.close()
        self.buffer.flush()


class _open_zipfile_reader(_opener):
    """Context manager that opens a zip archive for reading.

    Entering yields the ``zipfile.ZipFile``. On exit the archive is closed, and
    a file that was opened here from a path is closed too; a caller-supplied
    buffer is never closed by this class.

    Args:
        name_or_buffer: A ``str``/``os.PathLike`` path, or an object that
            ``zipfile.ZipFile`` can read from directly.
    """

    def __init__(self, name_or_buffer) -> None:
        # File handle opened by this object (path input only). It is tracked
        # separately because ``_opener.__init__`` rebinds ``self.file_like`` to
        # the ZipFile, and ZipFile.close() does not close a file object that
        # was passed in.
        self._owned_file = None
        if _is_path(name_or_buffer):
            self._owned_file = open(name_or_buffer, "rb")
            source = self._owned_file
        else:
            source = name_or_buffer
        try:
            self.zipfile = zipfile.ZipFile(source, "r")
        except BaseException:
            # E.g. zipfile.BadZipFile: release our handle before propagating.
            if self._owned_file is not None:
                self._owned_file.close()
            raise
        super().__init__(self.zipfile)

    def __exit__(self, *args) -> None:
        try:
            self.zipfile.close()
        finally:
            # Close the file we opened even if closing the archive raised.
            if self._owned_file is not None:
                self._owned_file.close()


def _open_zipfile_writer(name_or_buffer):
    """Return the zip-writer context manager matching ``name_or_buffer``.

    Paths (``str``/``os.PathLike``) get ``_open_zipfile_writer_file``; anything
    else is handed to ``_open_zipfile_writer_buffer``.
    """
    writer_cls: Type[_opener] = (
        _open_zipfile_writer_file
        if _is_path(name_or_buffer)
        else _open_zipfile_writer_buffer
    )
    return writer_cls(name_or_buffer)


def _save(obj, zip_file, pickle_protocol):
    """Helper function to save objects into a zip file."""
    with zip_file.open("data.pkl", "w") as f:
        pickle.dump(obj, f, protocol=pickle_protocol)


def _check_save_filelike(f):
    """Validate that ``f`` is a path or exposes a ``write`` attribute.

    Raises:
        AttributeError: If ``f`` is neither a ``str``/``os.PathLike`` nor an
            object with a ``write`` attribute.
    """
    if _is_path(f) or hasattr(f, "write"):
        return
    raise AttributeError(
        "Expected 'f' to be string, path, or a file-like object with "
        "a 'write' attribute"
    )


def save(
    obj: object,
    f: FILE_LIKE,
    pickle_protocol: int = DEFAULT_PROTOCOL,
) -> None:
    """
    Saves an object to a disk file using zip compression and pickle serialization.

    The object is pickled into a single ``data.pkl`` member of a
    deflate-compressed zip archive.

    Args:
        obj: The object to be saved.
        f: A file-like object (must implement write/flush) or a string or
           os.PathLike object containing a file name.
        pickle_protocol: Pickle protocol version.

    Raises:
        AttributeError: If ``f`` is neither a path nor an object with a
            ``write`` attribute.

    .. note::
        A common Afnio convention is to save variables using .hf file extension.

    Example:
        >>> # Save to file
        >>> x = hf.Variable(data="You are a doctor.", role="system prompt")
        >>> hf.save(x, 'variable.hf')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> hf.save(x, buffer)
    """
    # Fail early with a clear message before any archive is created.
    _check_save_filelike(f)

    with _open_zipfile_writer(f) as opened_zipfile:
        _save(
            obj,
            opened_zipfile,
            pickle_protocol,
        )


def load(f: FILE_LIKE) -> Any:
    """
    Loads an object from a disk file using zip compression and pickle serialization.

    Args:
        f: A file-like object (must implement `read`) or a string or
           os.PathLike object containing a file name.

    Returns:
        The deserialized object.

    Raises:
        RuntimeError: If the archive has no ``data.pkl`` member.

    .. warning::
        The payload is deserialized with ``pickle.load``, which can execute
        arbitrary code. Only load files from sources you trust.

    Example:
        >>> # Load from file
        >>> obj = hf.load('model.hf')
        >>> # Load from io.BytesIO buffer
        >>> with open('model.hf', 'rb') as fh:
        ...     buffer = io.BytesIO(fh.read())
        >>> obj = hf.load(buffer)
    """
    with _open_zipfile_reader(f) as zip_reader:
        if "data.pkl" not in zip_reader.namelist():
            raise RuntimeError(
                "Missing 'data.pkl' in archive. File might be corrupted."
            )

        # Read the serialized object. A distinct name is used so the ``f``
        # argument is not shadowed.
        # SECURITY: pickle.load runs arbitrary code embedded in the payload.
        with zip_reader.open("data.pkl", "r") as pickled:
            obj = pickle.load(pickled)
    return obj