
import tables
import numpy as np
from contextlib import suppress

from mango.constants import c

types = (int, float, bool, str, np.bool_,
         np.int8, np.int16, np.int32, np.int64,
         np.uint8, np.uint16, np.uint32, np.uint64,
         np.float16, np.float32, np.float64,
         np.complex64, np.complex128)


def save(filename, data, compression=None, datalength=None, append=frozenset(), name="data"):
    """
    Allow incremental saving of data to hdf5 files.

    Wrapping all information to be stored in a dictionary gives much more
    customisation of storage names and avoids overwriting data.

    Parameters
    ----------
    filename: string
        filename for the hdf5 file
    data: nearly anything
        data to be stored; currently works for most scalar types, standard
        numpy arrays, tuples, lists and dictionaries.
    compression: tuple or list
        compression is set by library name and compression level,
        e.g. ("blosc", 9); see the manual for your version of
        `pytables <http://www.pytables.org/usersguide/optimization.html?#compression-issues>`_
        for the full list.
    datalength: int
        the length of the array being stored (or a rough guess);
        specify it to improve file I/O speeds, especially when appending
    append: set
        for anything to be saved incrementally, a set of variable names must
        be provided. The append set applies to any entry and all sub-entries,
        e.g. append=set(["energy"]) applies to both ./energy and ./energy/\*.
        Attributes, such as single number values not stored in arrays, are
        always overwritten.
    name: string
        name of the stored node; not used for dictionaries. Defaults to "data".

    """

    if compression is not None:
        filters = tables.Filters(complib=compression[0], complevel=compression[1],
                                 shuffle=True,
                                 bitshuffle=compression[0] == 'blosc')
    else:
        filters = None

    with tables.open_file(filename, "a") as hdf5:
        location = hdf5.root
        if isinstance(data, dict):
            for key, value in data.items():
                _save_type(hdf5, location, value, key, filters, datalength, append)
        else:
            _save_type(hdf5, location, data, name, filters, datalength, append)

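# Example usage (an illustrative sketch; the file name "run.h5" is arbitrary):
#
#   data = {"energy": np.zeros((10, 3)), "step": 0}
#   save("run.h5", data, compression=("blosc", 9),
#        datalength=1000, append={"energy"})
#
#   # A later call with append={"energy"} extends /energy in place, while the
#   # scalar attribute "step" is simply overwritten:
#   save("run.h5", {"energy": np.ones((10, 3)), "step": 1},
#        datalength=1000, append={"energy"})
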

def _save_type(file, location, data, name, filters, datalength, append):
    backupcheck(location, name)

    if data is None:
        # None is stored as an empty group tagged "nonetype".
        with suppress(tables.exceptions.NodeError):
            file.create_group(location, name, "nonetype")
    elif isinstance(data, (dict, list, tuple)):
        _save_ltd(file, location=location, data=data, name=name,
                  filters=filters, datalength=datalength, append=append,
                  ltd=str(type(data)).split("'")[1])
    elif isinstance(data, np.ndarray):
        _save_numpy(file, location=location, data=data, name=name,
                    filters=filters, datalength=datalength, append=append)
    elif isinstance(data, types):
        # Scalars are stored as node attributes and are always overwritten.
        setattr(location._v_attrs, name, data)
    else:
        c.Error("W Saving {} ({}) not yet implemented, sorry".format(name, type(data)))


def _save_ltd(file, location, data, name, filters, datalength, append, ltd='list'):
    if 'dict' in ltd and data == {}:
        return
    try:
        new_entry = getattr(location, name)
        if new_entry._v_title == "nonetype":
            # Upgrade a None placeholder written earlier to the real container type.
            new_entry._v_title = ltd
    except tables.exceptions.NoSuchNodeError:
        new_entry = file.create_group(location, name, ltd)

    # Tuples and lists are stored as groups with members named a0, a1, ...;
    # dictionaries keep their own keys.
    for key, value in enumerate(data) if ltd in ['tuple', 'list'] else data.items():
        key = f"a{key}" if ltd in ['tuple', 'list'] else key
        _save_type(file, location=new_entry, data=value, name=key,
                   filters=filters, datalength=datalength, append=append)

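# Sketch of the resulting layout: saving a list such as
#
#   save("run.h5", {"pos": [np.zeros(3), np.ones(3)]})
#
# produces a group whose title records the original container type:
#
#   /pos      group, title "list"
#   /pos/a0   array of shape (3,)
#   /pos/a1   array of shape (3,)
#
# which load() reassembles into a list via _return_list().
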

def _save_numpy(file, location, data, name, filters, datalength, append):
    if np.isscalar(data):
        setattr(location._v_attrs, name, data)
        return

    try:
        node = getattr(location, name)
        if node._v_title == 'nonetype':
            # Replace a None placeholder group with a real array.
            node._f_remove()
            _create_array(file, location, data, name, filters, datalength, append)
        elif _append_check(node, name, append):
            node.append(data)
        # else: WARNING existing numpy arrays are not updated, whereas attrs are
    except tables.exceptions.NoSuchNodeError:
        _create_array(file, location, data, name, filters, datalength, append)


def _create_array(file, location, data, name, filters, datalength, append):
    atom = tables.Atom.from_dtype(data.dtype)
    shape = list(data.shape)
    num_rows = shape[0]
    # Extensible arrays are created empty along the first (extensible) axis
    # and filled by append() below.
    eshape = shape.copy()
    eshape[0] = 0
    if filters is not None and datalength is not None and datalength > 300 and name in append:
        node = file.create_earray(location, name, atom=atom,
                                  shape=eshape,
                                  expectedrows=datalength,
                                  filters=filters, chunkshape=None)
    else:
        node = file.create_earray(location, name, atom=atom,
                                  shape=eshape,
                                  expectedrows=datalength if datalength is not None and name in append else num_rows,
                                  chunkshape=shape if shape != [0] else None)

    node.append(data)


def _append_check(node, name, append):
    """Return True if the node or any of its ancestors is named in the append set."""
    nodename = node
    while True:
        if name in append:
            check = True
            break
        else:
            # Walk up to the parent group and test its name next.
            nodename = nodename._v_parent
            name = nodename._v_pathname.rsplit('/', 1)[-1]
            if name == "":
                # Reached the root without finding a match.
                check = False
                break
    return check

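# Illustrative behaviour: because the check walks up through parent groups,
# append=set(["energy"]) makes _append_check return True for the node
# /energy itself and for any child such as /energy/a0, while a node with no
# ancestor named in the set fails the check once the walk reaches the root.
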

def backupcheck(location, name):
    """Back up selected variables under a 'vars' group before they are overwritten."""
    if location._v_name.split('/')[-1] == 'vars' and name in ['RandNoState', 'extra_iter', 'written', 'SCFcount']:
        backup(location, name)


def backup(location, name):
    """
    Back up useful variables by renaming them with a numeric suffix.

    Parameters
    ----------
    location: node
        group holding the node or attribute
    name: str
        name of the node or attribute to back up

    """
    restart_no = 1
    newname = name + "{}"
    try:
        oldvar = getattr(location, name)
        # Find the first free suffix, e.g. RandNoState1, RandNoState2, ...
        while hasattr(location, newname.format(restart_no)):
            c.verb.print(newname.format(restart_no))
            restart_no += 1
        oldvar._f_rename(newname.format(restart_no))
    except tables.exceptions.NoSuchNodeError:
        with suppress(AttributeError):
            oldvar = getattr(location._v_attrs, name)
            while hasattr(location._v_attrs, newname.format(restart_no)):
                c.verb.print(newname.format(restart_no))
                restart_no += 1
            setattr(location._v_attrs, newname.format(restart_no), oldvar)
            location._f_delattr(name)


def annihilate(filename, location):
    """Remove a node or attribute from the file and return its data."""
    nodedata = load(filename, location)
    with tables.open_file(filename, "a") as hdf5:
        node, sl, attr_root = get_allattr(hdf5.root, location)
        if isinstance(node, tables.attributeset.AttributeSet):
            delattr(attr_root, sl[-1])
        else:
            node._f_remove()
    return nodedata


def _mv(filename, location, newlocation):
    with tables.open_file(filename, 'a') as hdf5:
        node, sl, attr_root = get_allattr(hdf5, location)
        if isinstance(node, tables.attributeset.AttributeSet):
            setattr(attr_root, newlocation.rsplit('/', 1)[-1], node)
            delattr(attr_root, sl[-1])
        else:
            node._f_rename(newlocation.rsplit('/', 1)[-1])


def get_allattr(file, location):
    """Resolve a path to a node or, failing that, to an attribute."""
    split_loc = location.rsplit('/', 1)
    try:
        node = getattr(file, location)
        attr_root = None
    except tables.exceptions.NoSuchNodeError:
        if len(split_loc) > 1:
            attr_root = getattr(file, split_loc[0])._v_attrs
            node = attr_root[split_loc[-1]]
        else:
            attr_root = file._v_attrs
            node = attr_root[location]
    return node, split_loc, attr_root


def load(filename, location=None, chunk=None, keylist=False):
    """
    Load hdf5 datasets in (hopefully) the same format as they were saved.

    Parameters
    ----------
    filename: string
        file to load
    location:
        location of the data within the file, e.g. "data/movement/position";
        if None, the whole file is loaded
    chunk: tuple or list
        (start, stop) chunk numbers along the main dimension, used to read
        a slice of an array rather than the whole thing
    keylist: bool
        if True, return the shape and chunkshape of arrays instead of
        their contents

    """
    with tables.open_file(filename, mode='r') as hdf5:
        if location is None:
            data = _load_type(hdf5, hdf5.root, chunk, keylist)
        else:
            # replace getattr with get_allattr when properly tested
            data = _load_type(hdf5, getattr(hdf5.root, location), chunk, keylist)
    return data

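# Example usage (an illustrative sketch, matching the save() example above):
#
#   everything = load("run.h5")                      # whole file as a dict
#   energy = load("run.h5", "energy")                # one array in full
#   info = load("run.h5", "energy", keylist=True)    # {'shape': ..., 'chunkread': ...}
#   start = load("run.h5", "energy", chunk=(0, 1))   # rows of the first chunk only
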

def _load_type(file, location, chunk=None, keylist=False):
    if isinstance(location, tables.Group):
        store = {}
        for loc in location:
            store[loc._v_name] = _load_type(file, loc, chunk, keylist)

        # Attributes overwrite nodes with the same name
        for name in location._v_attrs._f_list():
            store[name] = location._v_attrs[name]

        if not store:
            return None
        if location._v_title.startswith(('tuple', 'list')):
            # Rebuild lists/tuples saved as groups with members a0, a1, ...
            return _return_list(store, location._v_title)
        return store

    elif isinstance(location, tables.Array):
        if keylist:
            return {'shape': location.shape, 'chunkread': location.chunkshape}
        if chunk is None:
            return location[:]
        else:
            md = location.maindim
            # TODO assumes index 0 is the extensible axis
            low = chunkindex(location.shape[md], location.chunkshape[md], chunk[0])
            high = chunkindex(location.shape[md], location.chunkshape[md], chunk[1])
            return location[low:high]

    else:
        c.Error("W Loading {} not yet implemented, sorry".format(type(location)))


def chunkindex(loc_s1, cs_ind1, c1):
    """Convert a (possibly negative) chunk number into a positive array index."""
    if c1 is None:
        return None
    if c1 >= 0:
        return int(cs_ind1 * c1)
    # Negative chunk numbers count back from the end of the array; a partial
    # final chunk counts as one extra chunk.
    nchunks = loc_s1 // cs_ind1
    return int((nchunks - abs(c1 + 1 if loc_s1 % cs_ind1 > 0 else c1)) * cs_ind1)

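# Worked example (illustrative): for an array of 10 rows stored with a
# chunkshape of 4 along the main dimension:
#
#   chunkindex(10, 4, 0)   # ->  0, start of the first chunk
#   chunkindex(10, 4, 1)   # ->  4, start of the second chunk
#   chunkindex(10, 4, -1)  # ->  8, start of the partial final chunk
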

def _return_list(data, lort):
    """Rebuild a list or tuple from a dict with members named a0, a1, ..."""
    l_data = [data[f'a{i}'] for i in range(len(data))]
    return l_data if lort.startswith('list') else tuple(l_data)