Source code for mango.hdf5io

import tables
import numpy as np
from contextlib import suppress
from mango.constants import c

types = (int, float, bool, str, np.bool_,
         np.int8, np.int16, np.int32, np.int64,
         np.uint8, np.uint16, np.uint32, np.uint64,
         np.float16, np.float32, np.float64,
         np.complex64, np.complex128)


def save(filename, data, compression=None, datalength=None, append=set([]), name="data"):
    """
    Allow incremental saving of data to hdf5 files.

    Wrapping all information to be stored in a dictionary gives much more
    customisation of storage names and avoids overwriting data.

    Parameters
    ----------
    filename: string
        filename for the hdf5 file
    data: nearly anything
        data to be stored; currently works for most types, standard numpy
        arrays, tuples, lists and dictionaries.
    compression: tuple or list
        Compression is set by name, such as blosc, plus a compression level,
        e.g. ("blosc", 9). See the manual for your version of
        `pytables <http://www.pytables.org/usersguide/optimization.html?#compression-issues>`_
        for the full list.
    datalength: int
        The length of the array being stored (or a rough guess). Specify it to
        improve file I/O speeds, especially for appending.
    append: set
        For anything to be saved incrementally, a set of variable names needs
        to be provided. The append set works for any entry and all sub-entries,
        e.g. append=set(["energy"]) would apply to both ./energy and
        ./energy/\*. Attributes, such as single number values not in arrays,
        will always be overwritten.
    name: string
        This is not used for dictionaries. Otherwise the default name is
        "data".
    """
    if compression is not None:
        filters = tables.Filters(complib=compression[0], complevel=compression[1],
                                 shuffle=True,
                                 bitshuffle=True if compression[0] == 'blosc' else False)
    else:
        filters = None
    with tables.open_file(filename, "a") as hdf5:
        location = hdf5.root
        if isinstance(data, dict):
            for key, value in data.items():
                _save_type(hdf5, location, value, key, filters, datalength, append)
        else:
            _save_type(hdf5, location, data, name, filters, datalength, append)
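
A minimal usage sketch, assuming the module is importable as mango.hdf5io; the file name, keys and array shapes below are illustrative only, not part of the module:

    # Hypothetical usage sketch (file name, keys and shapes are made up).
    import numpy as np
    from mango import hdf5io

    step = {"energy": np.random.rand(10, 3),   # grows with every call (listed in append)
            "temperature": 300.0,              # plain number -> stored as an attribute, overwritten
            "positions": np.zeros((10, 3))}    # arrays not listed in append are left untouched on later saves
    hdf5io.save("run.hdf5", step,
                compression=("blosc", 9),      # complib name and level passed to tables.Filters
                datalength=10000,              # rough final row count, helps chunking
                append={"energy"})             # only "energy" (and its sub-entries) are appended
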

def _save_type(file, location, data, name, filters, datalength, append):
    backupcheck(location, name)
    if data is None:
        with suppress(tables.exceptions.NodeError):
            file.create_group(location, name, "nonetype")
    elif isinstance(data, (dict, list, tuple)):
        _save_ltd(file, location=location, data=data, name=name, filters=filters,
                  datalength=datalength, append=append,
                  ltd=str(type(data)).split("'")[1])
    elif isinstance(data, np.ndarray):
        _save_numpy(file, location=location, data=data, name=name, filters=filters,
                    datalength=datalength, append=append)
    elif isinstance(data, types):
        setattr(location._v_attrs, name, data)
    else:
        c.Error("W Saving {} {} not yet implemented, sorry".format(name, type(data)))


def _save_ltd(file, location, data, name, filters, datalength, append, ltd='list'):
    if 'dict' in ltd and data == {}:
        return
    try:
        new_entry = getattr(location, name)
        if new_entry._v_title == "nonetype":
            new_entry._v_title = ltd
    except tables.exceptions.NoSuchNodeError:
        new_entry = file.create_group(location, name, ltd)
    # Lists and tuples are stored as groups with children named a0, a1, ...
    for key, value in enumerate(data) if ltd in ['tuple', 'list'] else data.items():
        key = f"a{key}" if ltd in ['tuple', 'list'] else key
        _save_type(file, location=new_entry, data=value, name=key, filters=filters,
                   datalength=datalength, append=append)


def _save_numpy(file, location, data, name, filters, datalength, append):
    if np.isscalar(data):
        setattr(location._v_attrs, name, data)
        return
    try:
        node = getattr(location, name)
        if node._v_title == 'nonetype':
            node._f_remove()
            _create_array(file, location, data, name, filters, datalength, append)
        elif _append_check(node, name, append):
            node.append(data)
        # else:
        #     WARNING: existing numpy arrays are not updated, whereas attrs are
    except tables.exceptions.NoSuchNodeError:
        _create_array(file, location, data, name, filters, datalength, append)


def _create_array(file, location, data, name, filters, datalength, append):
    atom = tables.Atom.from_dtype(data.dtype)
    shape = list(data.shape)
    num_rows = shape[0]
    eshape = shape.copy()
    eshape[0] = 0
    if filters is not None and datalength > 300 and name in append:
        node = file.create_earray(location, name, atom=atom, shape=eshape,
                                  expectedrows=datalength, filters=filters,
                                  chunkshape=None)
    else:
        node = file.create_earray(location, name, atom=atom, shape=eshape,
                                  expectedrows=datalength if name in append else num_rows,
                                  chunkshape=shape if shape != [0] else None)
    node.append(data)


def _append_check(node, name, append):
    # Walk up the parent groups: appending applies to any name in `append`
    # and to everything stored below it.
    nodename = node
    while True:
        if name in append:
            check = True
            break
        nodename = nodename._v_parent
        name = str(nodename).split()[0].split("/")[-1]
        if name == "":
            check = False
            break
    return check

def backupcheck(location, name):
    """Backup variables."""
    if (location._v_name.split('/')[-1] == 'vars'
            and name in ['RandNoState', 'extra_iter', 'written', 'SCFcount']):
        backup(location, name)

def backup(location, name):
    """
    Backup useful variables.

    Parameters
    ----------
    location: node
    name: str
    """
    restart_no = 1
    newname = name + "{}"
    try:
        oldvar = getattr(location, name)
        while hasattr(location, newname.format(restart_no)):
            c.verb.print(newname.format(restart_no))
            restart_no += 1
        oldvar._f_rename(newname.format(restart_no))
    except tables.exceptions.NoSuchNodeError:
        with suppress(AttributeError):
            oldvar = getattr(location._v_attrs, name)
            while hasattr(location._v_attrs, newname.format(restart_no)):
                c.verb.print(newname.format(restart_no))
                restart_no += 1
            setattr(location._v_attrs, newname.format(restart_no), oldvar)
            location._f_delattr(name)

def annihilate(filename, location):
    """Remove a node or attribute from the file and return its loaded data."""
    nodedata = load(filename, location)
    with tables.open_file(filename, "a") as hdf5:
        node, sl, attr_root = get_allattr(hdf5.root, location)
        if isinstance(node, tables.attributeset.AttributeSet):
            delattr(attr_root, sl[-1])
        else:
            node.remove()
    return nodedata

def _mv(filename, location, newlocation):
    with tables.open_file(filename, 'a') as hdf5:
        node, sl, attr_root = get_allattr(hdf5, location)
        if isinstance(node, tables.attributeset.AttributeSet):
            setattr(attr_root, newlocation.rsplit('/', 1)[-1], node)
            delattr(attr_root, sl[-1])
        else:
            node._f_rename(newlocation.rsplit('/', 1)[-1])

def get_allattr(file, location):
    """Return the node or attribute at `location`, the split location, and the
    owning attribute set (None when `location` is a real node)."""
    split_loc = location.rsplit('/', 1)
    try:
        node = getattr(file, location)
        attr_root = None
    except tables.exceptions.NoSuchNodeError:
        if len(split_loc) > 1:
            attr_root = getattr(file, split_loc[0])._v_attrs
            node = attr_root[split_loc[-1]]
        else:
            attr_root = file._v_attrs
            node = attr_root[location]
    return node, split_loc, attr_root

def load(filename, location=None, chunk=None, keylist=False):
    """
    Load hdf5 datasets in (hopefully) the same format they were saved in.

    Parameters
    ----------
    filename: string
        file to load
    location: string
        location of the data within the file, e.g. "data/movement/position"
    chunk: tuple
        (start, stop) chunk indices along the main dimension; only that slice
        of the array is read
    keylist: bool
        if True, return the shape and chunk shape of arrays instead of the data
    """
    with tables.open_file(filename, mode='r') as hdf5:
        if location is None:
            data = _load_type(hdf5, hdf5.root, chunk, keylist)
        else:
            # replace getattr with get_allattr when properly tested
            data = _load_type(hdf5, getattr(hdf5.root, location), chunk, keylist)
    return data
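
A short, hedged example of reading back the data written in the save sketch above; the file name and keys are again illustrative:

    # Illustrative only: assumes "run.hdf5" and an "energy" array exist at the root.
    from mango import hdf5io  # as in the save sketch above

    everything = hdf5io.load("run.hdf5")                       # whole file as nested dicts/lists
    energy = hdf5io.load("run.hdf5", location="energy")        # a single array
    first_chunks = hdf5io.load("run.hdf5", location="energy",
                               chunk=(0, 2))                   # rows in the first two chunks
    info = hdf5io.load("run.hdf5", location="energy",
                       keylist=True)                           # {'shape': ..., 'chunkread': ...}
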

def _load_type(filename, location, chunk=None, keylist=False):
    if isinstance(location, tables.Group):
        store = {}
        for loc in location:
            newloc = _load_type(filename, loc, chunk, keylist)
            n = loc._v_name
            store[n] = newloc
        # Attributes overwrite nodes with the same name
        for name in location._v_attrs._f_list():
            v = location._v_attrs[name]
            store[name] = v
        return (None if not store
                else _return_list(store, location._v_title)
                if location._v_title.startswith(('tuple', 'list'))
                else store)
    elif isinstance(location, tables.Array):
        if keylist:
            return {'shape': location.shape, 'chunkread': location.chunkshape}
        if chunk is None:
            return location[:]
        else:
            md = location.maindim
            high = chunkindex(location.shape[md], location.chunkshape[md], chunk[1])
            low = chunkindex(location.shape[md], location.chunkshape[md], chunk[0])
            # TODO assumes index 0 is the extensible axis
            return location[low:high]
    else:
        c.Error("W Loading {} not yet implemented, sorry".format(type(location)))

def chunkindex(loc_s1, cs_ind1, c1):
    """Return positive array index for numpy array."""
    return (None if c1 is None
            else int(cs_ind1 * c1) if c1 >= 0
            else int(((loc_s1 // cs_ind1)
                      - abs(c1 + 1 if loc_s1 % cs_ind1 > 0 else c1)) * cs_ind1))
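
A worked illustration with made-up values, not taken from the module: for an array of 1000 rows stored with a chunk shape of 64 rows,

    chunkindex(1000, 64, 2)     # -> 128, the row where the third chunk starts
    chunkindex(1000, 64, -1)    # -> 960, the start of the last (partial) chunk
    chunkindex(1000, 64, None)  # -> None, i.e. an open-ended slice bound
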

def _return_list(data, lort):
    # Rebuild a list or tuple from group children stored as a0, a1, ...
    l_data = []
    for i in range(len(data)):
        l_data += [data[f'a{i}']]
    return l_data if lort.startswith('list') else tuple(l_data)