import tables
import numpy as np
from contextlib import suppress
from mango.constants import c
types = (int, float, bool, str, np.bool_,
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.complex64, np.complex128)
def save(filename, data, compression=None, datalength=None, append=set([]), name="data"):
"""
Allow incremental saving of data to hdf5 files.
Wrapping all information to be stored in a dictionary gives much more customisation of
storage names and avoids overwiting data
Parameters
----------
filename: string
filename for hdf5 file
data: nearly anything
data to be stored currently works for most types, standard numpy arrays,
tuples, lists and dictionaries.
compression: tuple or list
Compression is set by name, such as blosc and compression level
eg. (blosc,9)
see the manual for your version of `pytables <http://www.pytables.org/usersguide/optimization.html?#compression-issues>`_
for the full list.
datalength: int
The length of the array being stored (or a rough guess)
specfiy to improve file I/O speeds especially for appending
append: set
For anything to be saved incrementally a set of variable names needs to be provided.
The append list works for any entry and all sub entries
eg. append=set(["energy"]) would apply to both ./energy and ./energy/\*.
Attributes such as single number values not in arrays will always be overwritten
name: string
This is not used for dictionaries. Otherwise the default name is data.
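
    Examples
    --------
    A minimal sketch of typical use (the file name, keys and sizes below are
    purely illustrative):

    >>> import numpy as np
    >>> run = {"energy": np.random.rand(10), "params": {"dt": 0.1}}
    >>> save("run.h5", run, compression=("blosc", 9),
    ...      datalength=1000, append={"energy"})
    >>> # later calls append to ./energy instead of overwriting it
    >>> save("run.h5", {"energy": np.random.rand(10)}, append={"energy"})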
"""
if compression is not None:
        filters = tables.Filters(complib=compression[0], complevel=compression[1],
                                 shuffle=True, bitshuffle=(compression[0] == 'blosc'))
else:
filters = None
with tables.open_file(filename, "a") as hdf5:
location = hdf5.root
if isinstance(data, dict):
for key, value in data.items():
_save_type(hdf5, location, value, key, filters, datalength, append)
else:
_save_type(hdf5, location, data, name, filters, datalength, append)
def _save_type(file, location, data, name, filters, datalength, append):
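    # Dispatch on the type of ``data``: dicts/lists/tuples become sub-groups,
    # numpy arrays become (extendable) array nodes, plain scalars are stored as
    # HDF5 attributes on the enclosing group, and None becomes an empty group
    # titled "nonetype" as a placeholder.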
backupcheck(location, name)
if data is None:
with suppress(tables.exceptions.NodeError):
file.create_group(location, name, "nonetype")
    elif isinstance(data, (dict, list, tuple)):
_save_ltd(file, location=location, data=data, name=name,
filters=filters, datalength=datalength, append=append, ltd=str(type(data)).split("'")[1])
elif isinstance(data, np.ndarray):
_save_numpy(file, location=location, data=data, name=name,
filters=filters, datalength=datalength, append=append)
elif isinstance(data, types):
setattr(location._v_attrs, name, data)
else:
c.Error("W Saving {}{} not yet implemented, sorry".format(name, type(data)))
def _save_ltd(file, location, data, name, filters, datalength, append, ltd='list'):
if 'dict' in ltd and data == {}:
return
try:
new_entry = getattr(location, name)
if new_entry._v_title == "nonetype":
new_entry._v_title = ltd
except tables.exceptions.NoSuchNodeError:
new_entry = file.create_group(location, name, ltd)
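    # Lists and tuples are stored like dictionaries, using the index as the
    # key; the index is prefixed with "a" (a0, a1, ...) so the node names are
    # valid identifiers for PyTables natural naming.  load() strips the prefix
    # again when rebuilding the list or tuple.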
for key, value in enumerate(data) if ltd in ['tuple', 'list'] else data.items():
key = f"a{key}" if ltd in ['tuple', 'list'] else key
_save_type(file, location=new_entry, data=value, name=key,
filters=filters, datalength=datalength, append=append)
def _save_numpy(file, location, data, name, filters, datalength, append):
    if data.ndim == 0:
        # zero-dimensional arrays cannot go into an extendable array node,
        # so store them as an attribute instead
        setattr(location._v_attrs, name, data)
        return
try:
node = getattr(location, name)
if node._v_title == 'nonetype':
node._f_remove()
_create_array(file, location, data, name, filters, datalength, append)
elif _append_check(node, name, append):
node.append(data)
        # else:
        #     WARNING numpy arrays are not updated, whereas attrs are
except tables.exceptions.NoSuchNodeError:
_create_array(file, location, data, name, filters, datalength, append)
def _create_array(file, location, data, name, filters, datalength, append):
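    # Data is always written as an extendable array (EArray) whose first
    # dimension starts empty and is filled by append(); this lets later save()
    # calls grow the same dataset.  Compression filters are only attached when
    # the dataset is flagged for appending and is expected to be reasonably long.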
atom = tables.Atom.from_dtype(data.dtype)
shape = list(data.shape)
num_rows = shape[0]
eshape = shape.copy()
eshape[0] = 0
    if filters is not None and datalength is not None and datalength > 300 and name in append:
node = file.create_earray(location, name, atom=atom,
shape=eshape,
expectedrows=datalength,
filters=filters, chunkshape=None)
else:
node = file.create_earray(location, name, atom=atom,
shape=eshape,
                                  expectedrows=datalength if (name in append and datalength is not None) else num_rows,
chunkshape=shape if shape != [0] else None)
node.append(data)
def _append_check(node, name, append):
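    # Appending applies if the name being saved, or the name of any ancestor
    # group of ``node``, is in the append set; the walk up the tree stops at
    # the root, whose extracted name is the empty string.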
nodename = node
while True:
if name in append:
check = True
break
else:
nodename = nodename._v_parent
name = str(nodename).split()[0].split("/")[-1]
if name == "":
check = False
break
return check
def backupcheck(location, name):
    """Back up selected restart variables under the 'vars' group before they are overwritten."""
if location._v_name.split('/')[-1] == 'vars' and name in ['RandNoState', 'extra_iter', 'written', 'SCFcount']:
backup(location, name)
def backup(location, name):
    """
    Back up useful variables by renaming them.

    An existing node or attribute called ``name`` under ``location`` is moved
    to the first free numbered name (``name1``, ``name2``, ...) so that the new
    value can be written without losing the old one.

    Parameters
    ----------
    location: node
        Group holding the variable.
    name: str
        Name of the node or attribute to back up.
    """
restart_no = 1
newname = name + "{}"
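    # Try to rename an existing node first; fall back to the attribute of the
    # same name in the except branch below.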
try:
oldvar = getattr(location, name)
while hasattr(location, newname.format(restart_no)):
c.verb.print(newname.format(restart_no))
restart_no += 1
oldvar._f_rename(newname.format(restart_no))
except tables.exceptions.NoSuchNodeError:
with suppress(AttributeError):
oldvar = getattr(location._v_attrs, name)
while hasattr(location._v_attrs, newname.format(restart_no)):
c.verb.print(newname.format(restart_no))
restart_no += 1
setattr(location._v_attrs, newname.format(restart_no), oldvar)
location._f_delattr(name)
def annihilate(filename, location):
    """Delete the node or attribute at ``location`` from ``filename`` and return its data."""
nodedata = load(filename, location)
with tables.open_file(filename, "a") as hdf5:
node, sl, attr_root = get_allattr(hdf5.root, location)
if isinstance(node, tables.attributeset.AttributeSet):
delattr(attr_root, sl[-1])
else:
            node._f_remove(recursive=True)
return nodedata
def _mv(filename, location, newlocation):
with tables.open_file(filename, 'a') as hdf5:
        node, sl, attr_root = get_allattr(hdf5.root, location)
if isinstance(node, tables.attributeset.AttributeSet):
setattr(attr_root, newlocation.rsplit('/', 1)[-1], node)
delattr(attr_root, sl[-1])
else:
node._f_rename(newlocation.rsplit('/', 1)[-1])
def get_allattr(file, location):
    """
    Find the node or attribute at ``location`` under ``file``.

    Returns the node (or attribute value), the split location, and the
    AttributeSet owning the attribute (``None`` when a real node was found).
    """
split_loc = location.rsplit('/', 1)
try:
node = getattr(file, location)
attr_root = None
except tables.exceptions.NoSuchNodeError:
if len(split_loc) > 1:
attr_root = getattr(file, split_loc[0])._v_attrs
node = attr_root[split_loc[-1]]
else:
attr_root = file._v_attrs
node = attr_root[location]
return node, split_loc, attr_root
def load(filename, location=None, chunk=None, keylist=False):
    """
    Load hdf5 datasets back in (hopefully) the same format they were saved in.

    Parameters
    ----------
    filename: string
        File to load.
    location: string, optional
        Location of the data within the file, e.g. "data/movement/position".
        If omitted, the whole file is loaded.
    chunk: tuple of int, optional
        ``(start, stop)`` chunk indices along the main dimension of an array
        node; rows from the start of chunk ``start`` up to the start of chunk
        ``stop`` are returned. Negative values count back from the end and
        ``None`` means the end of the array.
    keylist: bool, optional
        If True, return the shape and chunk shape of each array node instead
        of its contents.
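
    Examples
    --------
    A rough sketch of typical calls (file and node names are illustrative):

    >>> everything = load("run.h5")                    # whole file as a dict
    >>> energy = load("run.h5", "energy")              # a single array node
    >>> info = load("run.h5", "energy", keylist=True)  # {'shape': ..., 'chunkread': ...}
    >>> head = load("run.h5", "energy", chunk=(0, 2))  # rows of the first two chunks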
"""
with tables.open_file(filename, mode='r') as hdf5:
if location is None:
data = _load_type(hdf5, hdf5.root, chunk, keylist)
else:
# replace getattr with get_allattr when properly tested
data = _load_type(hdf5, getattr(hdf5.root, location), chunk, keylist)
return data
def _load_type(hdf5file, location, chunk=None, keylist=False):
if isinstance(location, tables.Group):
store = {}
for loc in location:
            newloc = _load_type(hdf5file, loc, chunk, keylist)
n = loc._v_name
store[n] = newloc
# Attributes overwrite nodes with the same name
for name in location._v_attrs._f_list():
v = location._v_attrs[name]
store[name] = v
return (None if not store else _return_list(store, location._v_title)
if location._v_title.startswith(('tuple', 'list')) else store)
elif isinstance(location, tables.Array):
if keylist:
return {'shape': location.shape, 'chunkread': location.chunkshape}
if chunk is None:
return location[:]
else:
md = location.maindim
high = chunkindex(location.shape[md], location.chunkshape[md], chunk[1])
low = chunkindex(location.shape[md], location.chunkshape[md], chunk[0])
# TODO assumes index 0 is extensible axis
return location[low:high]
else:
c.Error("W Loading {} not yet implemented, sorry".format(type(location)))
def chunkindex(loc_s1, cs_ind1, c1):
    """Turn chunk number ``c1`` into a positive row index, given the array length
    ``loc_s1`` and chunk size ``cs_ind1`` along the main dimension."""
return (None if c1 is None else int(cs_ind1 * c1)
if c1 >= 0 else
int(((loc_s1 // cs_ind1) - abs(c1 + 1 if loc_s1 % cs_ind1 > 0
else c1)) * cs_ind1))
def _return_list(data, lort):
l_data = []
for i in range(len(data)):
l_data += [data[f'a{i}']]
return l_data if lort.startswith('list') else tuple(l_data)