xopr.matlab_attribute_utils

  1from collections.abc import Iterable
  2
  3import h5py
  4import numpy as np
  5import scipy.io
  6
  7#
  8# HDF5-format MATLAB files
  9#
 10
 11def dereference_h5value(value, h5file, make_array=True):
 12    if isinstance(value, h5py.Reference):
 13        return dereference_h5value(h5file[value], h5file=h5file)
 14    elif isinstance(value, h5py.Group):
 15        # Pass back to decode_hdf5_matlab_variable to handle groups
 16        return decode_hdf5_matlab_variable(value, h5file=h5file)
 17    elif isinstance(value, Iterable):
 18        v = [dereference_h5value(v, h5file=h5file) for v in value]
 19        if make_array:
 20            try:
 21                return np.squeeze(np.array(v))
 22            except:
 23                return v
 24        else:
 25            return v
 26    elif isinstance(value, np.number):
 27        return value.item()
 28    else:
 29        return value
 30
 31def decode_hdf5_matlab_variable(h5var, skip_variables=False, debug_path="", skip_errors=True, h5file=None):
 32    """
 33    Decode a MATLAB variable stored in an HDF5 file.
 34    This function assumes the variable is stored as a byte string.
 35    """
 36    if h5file is None:
 37        h5file = h5var.file
 38    matlab_class = h5var.attrs.get('MATLAB_class', None)
 39
 40    # Handle MATLAB_class as either bytes or string
 41    if matlab_class and (matlab_class == b'cell' or matlab_class == 'cell'):
 42        return dereference_h5value(h5var[:], h5file=h5file, make_array=False)
 43    elif matlab_class and (matlab_class == b'char' or matlab_class == 'char'):
 44        # Check if this is an empty MATLAB char array
 45        if h5var.attrs.get('MATLAB_empty', 0):
 46            return ''
 47
 48        # MATLAB stores char arrays as uint16 (Unicode code points)
 49        # or sometimes uint8 (ASCII). Handle both cases properly.
 50        data = h5var[:]
 51
 52        if data.dtype == np.dtype('uint16'):
 53            # Each uint16 value is a Unicode code point (UCS-2/UTF-16)
 54            # Convert to string by treating each value as a character code
 55            chars = [chr(c) for c in data.flatten() if c != 0]
 56            return ''.join(chars).rstrip()
 57        elif data.dtype == np.dtype('uint8'):
 58            # uint8 data can be decoded directly as UTF-8
 59            return data.tobytes().decode('utf-8').rstrip('\x00')
 60        else:
 61            # Fallback for unexpected dtypes (including uint64 for empty arrays)
 62            # First check if it's all zeros (empty string)
 63            if np.all(data == 0):
 64                return ''
 65            # Try the old method that may work for some cases
 66            try:
 67                return data.astype(dtype=np.uint8).tobytes().decode('utf-8').rstrip('\x00')
 68            except UnicodeDecodeError:
 69                # If that fails, try to convert assuming Unicode code points
 70                chars = [chr(min(c, 0x10FFFF)) for c in data.flatten() if c != 0]
 71                return ''.join(chars).rstrip()
 72    elif isinstance(h5var, (h5py.Group, h5py.File)):
 73        attrs = {}
 74        for k in h5var:
 75            if k.startswith('#'):
 76                continue
 77            if 'api_key' in k:
 78                attrs[k] = "API_KEY_REMOVED"
 79                continue
 80            if isinstance(h5var[k], h5py.Dataset):
 81                if not skip_variables:
 82                    try:
 83                        attrs[k] = decode_hdf5_matlab_variable(h5var[k], debug_path=debug_path + "/" + k, skip_errors=skip_errors, h5file=h5file)
 84                    except Exception as e:
 85                        print(f"Failed to decode variable {k} at {debug_path}: {e}")
 86                        if not skip_errors:
 87                            raise e
 88            else:
 89                attrs[k] = decode_hdf5_matlab_variable(h5var[k], debug_path=debug_path + "/" + k, skip_errors=skip_errors, h5file=h5file)
 90        return attrs
 91    elif isinstance(h5var, h5py.Dataset):
 92        if h5var.dtype == 'O':
 93            return dereference_h5value(h5var[:], h5file=h5file)
 94        else:
 95            return np.squeeze(h5var[:])
 96    else:
 97        return h5var[:]
 98
 99#
100# Legacy MATLAB files (non-HDF5)
101#
102
def extract_legacy_mat_attributes(file, skip_keys=None, skip_errors=True):
    """
    Load a legacy (non-HDF5) MAT file into a plain dict of variables.

    Parameters
    ----------
    file : str | path-like | file-like
        Passed straight to ``scipy.io.loadmat``.
    skip_keys : container of str, optional
        Top-level variable names to omit. Defaults to omitting nothing.
    skip_errors : bool
        Currently unused by this function.

    Returns
    -------
    dict
        Top-level variables, excluding keys starting with ``__`` (scipy's
        ``__header__`` / ``__version__`` / ``__globals__`` metadata), with
        API keys redacted and object ndarrays converted to lists.
    """
    # None instead of a mutable [] default, which would be shared across calls.
    if skip_keys is None:
        skip_keys = ()

    m = scipy.io.loadmat(file, mat_dtype=False, simplify_cells=True, squeeze_me=True)

    attrs = {key: value for key, value in m.items()
             if not key.startswith('__') and key not in skip_keys}

    attrs = strip_api_key(attrs)
    attrs = convert_object_ndarrays_to_lists(attrs)
    return attrs
112
def strip_api_key(attrs):
    """
    Return a redacted copy of ``attrs`` with API keys replaced.

    Any entry whose (string) key contains ``'api_key'`` gets the value
    ``"API_KEY_REMOVED"``. Redaction recurses into nested dicts, lists,
    tuples and object-dtype ndarrays, so keys hidden inside struct arrays
    or cell-like containers are caught too. The input is not modified.
    """
    attrs_clean = {}
    for key, value in attrs.items():
        if isinstance(key, str) and 'api_key' in key:
            attrs_clean[key] = "API_KEY_REMOVED"
        else:
            attrs_clean[key] = _strip_api_key_value(value)
    return attrs_clean


def _strip_api_key_value(value):
    """Redact API keys inside a single value, recursing through containers."""
    if isinstance(value, dict):
        return strip_api_key(value)
    if isinstance(value, list):
        return [_strip_api_key_value(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_strip_api_key_value(item) for item in value)
    if isinstance(value, np.ndarray) and value.dtype == object:
        # Build a new object array so the caller's array is left untouched.
        cleaned = np.empty_like(value)
        for index, item in np.ndenumerate(value):
            cleaned[index] = _strip_api_key_value(item)
        return cleaned
    return value
123
def convert_object_ndarrays_to_lists(attrs):
    """
    Convert object-dtype ndarray values in ``attrs`` to lists, in place.

    Recurses into nested dicts (mutated in place) and into lists, including
    lists produced by the conversion itself, so object arrays nested inside
    cell-like structures are converted too. Lists are rebuilt rather than
    mutated. Non-object (e.g. numeric) ndarrays are left unchanged.

    Returns
    -------
    dict
        ``attrs`` itself, for convenience.
    """
    for key, value in attrs.items():
        # Replacing values of existing keys is safe while iterating.
        attrs[key] = _convert_object_ndarray_value(value)
    return attrs


def _convert_object_ndarray_value(value):
    """Return ``value`` with object ndarrays at any nesting depth turned into lists."""
    if isinstance(value, np.ndarray) and value.dtype == 'object':
        value = value.tolist()
    if isinstance(value, dict):
        return convert_object_ndarrays_to_lists(value)
    if isinstance(value, list):
        return [_convert_object_ndarray_value(item) for item in value]
    return value
def dereference_h5value(value, h5file, make_array=True):
    """
    Recursively resolve HDF5 object references and containers into Python values.

    Parameters
    ----------
    value : object
        An ``h5py.Reference``, ``h5py.Group``, ``h5py.Dataset``, an iterable
        (e.g. a NumPy array read from a dataset), a NumPy scalar, or any
        other value.
    h5file : h5py.File
        File used to resolve object references.
    make_array : bool, default True
        When ``value`` is a generic iterable, try to pack the decoded
        elements into a squeezed NumPy array. If False, or if packing fails
        (e.g. ragged elements), a plain list is returned. Nested elements
        are always decoded with the default.

    Returns
    -------
    object
        Groups and datasets are decoded by ``decode_hdf5_matlab_variable``;
        iterables become a NumPy array or list; NumPy numbers become Python
        scalars; anything else is returned unchanged.
    """
    if isinstance(value, h5py.Reference):
        return dereference_h5value(h5file[value], h5file=h5file)
    if isinstance(value, (h5py.Group, h5py.Dataset)):
        # Referenced datasets carry MATLAB_class attributes just like groups
        # (e.g. the char arrays inside a cell), so let the MATLAB-aware
        # decoder handle them instead of iterating raw dataset rows.
        return decode_hdf5_matlab_variable(value, h5file=h5file)
    if isinstance(value, (str, bytes)):
        # Strings are Iterable, but a 1-char str iterates to itself, which
        # would recurse forever; treat text as an atomic value.
        return value
    if isinstance(value, Iterable):
        items = [dereference_h5value(item, h5file=h5file) for item in value]
        if not make_array:
            return items
        try:
            return np.squeeze(np.array(items))
        except (ValueError, TypeError):
            # Ragged / heterogeneous elements cannot form a regular array.
            return items
    if isinstance(value, np.number):
        return value.item()
    return value
def decode_hdf5_matlab_variable(h5var, skip_variables=False, debug_path="", skip_errors=True, h5file=None):
    """
    Decode a MATLAB variable stored in an HDF5-format (v7.3) MAT file.

    Dispatches on the ``MATLAB_class`` attribute and the h5py object type:

    * ``cell`` datasets hold object references; each one is dereferenced and
      the results are returned as a list (an empty cell returns ``[]``).
    * ``char`` datasets hold character codes and are returned as ``str``.
    * Groups / files are returned as a dict keyed by member name. Members
      whose names start with ``#`` are skipped (presumably MATLAB-internal
      storage such as ``#refs#``), and members whose names contain
      ``api_key`` are replaced by ``"API_KEY_REMOVED"``.
    * Other datasets are returned as squeezed NumPy arrays; object
      (reference) datasets are dereferenced first.

    Parameters
    ----------
    h5var : h5py.Dataset | h5py.Group | h5py.File
        Object to decode.
    skip_variables : bool
        If True, datasets that are direct children of ``h5var`` are omitted.
        This flag is not propagated to nested groups.
    debug_path : str
        Path prefix used only in error messages.
    skip_errors : bool
        If True, a dataset member that fails to decode is reported and left
        out of the result; if False the exception is re-raised.
    h5file : h5py.File, optional
        File used to resolve object references; defaults to ``h5var.file``.
    """
    if h5file is None:
        h5file = h5var.file

    matlab_class = h5var.attrs.get('MATLAB_class', None)
    # The attribute may come back as bytes or str; compare as str.
    if isinstance(matlab_class, bytes):
        matlab_class = matlab_class.decode('ascii', errors='replace')

    if matlab_class == 'cell':
        # An empty cell is stored as its dimensions with MATLAB_empty set,
        # not as references, so dereferencing it would yield those numbers.
        if h5var.attrs.get('MATLAB_empty', 0):
            return []
        return dereference_h5value(h5var[:], h5file=h5file, make_array=False)

    if matlab_class == 'char':
        # Check if this is an empty MATLAB char array
        if h5var.attrs.get('MATLAB_empty', 0):
            return ''
        # MATLAB stores char arrays as uint16 code points, or sometimes uint8.
        data = h5var[:]
        if data.dtype == np.dtype('uint16'):
            # Each uint16 value is a UCS-2/UTF-16 code unit; drop NUL padding.
            return ''.join(chr(c) for c in data.flatten() if c != 0).rstrip()
        if data.dtype == np.dtype('uint8'):
            return data.tobytes().decode('utf-8').rstrip('\x00')
        # Fallback for unexpected dtypes (including uint64 for empty arrays).
        if np.all(data == 0):
            return ''
        try:
            return data.astype(dtype=np.uint8).tobytes().decode('utf-8').rstrip('\x00')
        except UnicodeDecodeError:
            # Treat each value as a code point, clamped to the Unicode range.
            return ''.join(chr(min(c, 0x10FFFF)) for c in data.flatten() if c != 0).rstrip()

    if isinstance(h5var, (h5py.Group, h5py.File)):
        attrs = {}
        for k in h5var:
            if k.startswith('#'):
                continue
            if 'api_key' in k:
                attrs[k] = "API_KEY_REMOVED"
                continue
            member = h5var[k]
            child_path = debug_path + "/" + k
            if not isinstance(member, h5py.Dataset):
                attrs[k] = decode_hdf5_matlab_variable(member, debug_path=child_path, skip_errors=skip_errors, h5file=h5file)
                continue
            if skip_variables:
                continue
            try:
                attrs[k] = decode_hdf5_matlab_variable(member, debug_path=child_path, skip_errors=skip_errors, h5file=h5file)
            except Exception as e:
                print(f"Failed to decode variable {k} at {debug_path}: {e}")
                if not skip_errors:
                    raise
        return attrs

    if isinstance(h5var, h5py.Dataset):
        if h5var.dtype == 'O':
            return dereference_h5value(h5var[:], h5file=h5file)
        return np.squeeze(h5var[:])

    return h5var[:]

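Decode a MATLAB variable stored in an HDF5-format (v7.3) MAT file into Python objects: cell arrays become lists, char arrays become strings, structs (HDF5 groups) become dicts, and other datasets become NumPy arrays.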

def extract_legacy_mat_attributes(file, skip_keys=None, skip_errors=True):
    """
    Load a legacy (non-HDF5) MAT file into a plain dict of variables.

    Parameters
    ----------
    file : str | path-like | file-like
        Passed straight to ``scipy.io.loadmat``.
    skip_keys : container of str, optional
        Top-level variable names to omit. Defaults to omitting nothing.
    skip_errors : bool
        Currently unused by this function.

    Returns
    -------
    dict
        Top-level variables, excluding keys starting with ``__`` (scipy's
        ``__header__`` / ``__version__`` / ``__globals__`` metadata), with
        API keys redacted and object ndarrays converted to lists.
    """
    # None instead of a mutable [] default, which would be shared across calls.
    if skip_keys is None:
        skip_keys = ()

    m = scipy.io.loadmat(file, mat_dtype=False, simplify_cells=True, squeeze_me=True)

    attrs = {key: value for key, value in m.items()
             if not key.startswith('__') and key not in skip_keys}

    attrs = strip_api_key(attrs)
    attrs = convert_object_ndarrays_to_lists(attrs)
    return attrs
def strip_api_key(attrs):
    """
    Return a redacted copy of ``attrs`` with API keys replaced.

    Any entry whose (string) key contains ``'api_key'`` gets the value
    ``"API_KEY_REMOVED"``. Redaction recurses into nested dicts, lists,
    tuples and object-dtype ndarrays, so keys hidden inside struct arrays
    or cell-like containers are caught too. The input is not modified.
    """
    attrs_clean = {}
    for key, value in attrs.items():
        if isinstance(key, str) and 'api_key' in key:
            attrs_clean[key] = "API_KEY_REMOVED"
        else:
            attrs_clean[key] = _strip_api_key_value(value)
    return attrs_clean


def _strip_api_key_value(value):
    """Redact API keys inside a single value, recursing through containers."""
    if isinstance(value, dict):
        return strip_api_key(value)
    if isinstance(value, list):
        return [_strip_api_key_value(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_strip_api_key_value(item) for item in value)
    if isinstance(value, np.ndarray) and value.dtype == object:
        # Build a new object array so the caller's array is left untouched.
        cleaned = np.empty_like(value)
        for index, item in np.ndenumerate(value):
            cleaned[index] = _strip_api_key_value(item)
        return cleaned
    return value
def convert_object_ndarrays_to_lists(attrs):
    """
    Convert object-dtype ndarray values in ``attrs`` to lists, in place.

    Recurses into nested dicts (mutated in place) and into lists, including
    lists produced by the conversion itself, so object arrays nested inside
    cell-like structures are converted too. Lists are rebuilt rather than
    mutated. Non-object (e.g. numeric) ndarrays are left unchanged.

    Returns
    -------
    dict
        ``attrs`` itself, for convenience.
    """
    for key, value in attrs.items():
        # Replacing values of existing keys is safe while iterating.
        attrs[key] = _convert_object_ndarray_value(value)
    return attrs


def _convert_object_ndarray_value(value):
    """Return ``value`` with object ndarrays at any nesting depth turned into lists."""
    if isinstance(value, np.ndarray) and value.dtype == 'object':
        value = value.tolist()
    if isinstance(value, dict):
        return convert_object_ndarrays_to_lists(value)
    if isinstance(value, list):
        return [_convert_object_ndarray_value(item) for item in value]
    return value

Convert any object-dtype ndarray values in a dict (including those in nested dicts) to lists, modifying the dict in place.