Support indexing directly on Schem object

Allows making sub-schems easily
Slice indexing is supported too
master
Gaël de Sailly 2019-01-27 12:18:01 +01:00
parent 5af8f419b2
commit 7787d1e171
1 changed files with 32 additions and 9 deletions

View File

@ -9,15 +9,28 @@ from io import BytesIO
bulk_dtype = np.dtype([("node", ">u2"), ("prob", "u1"), ("force", "?"), ("param2", "u1")])
class Schem:
def __init__(self, *args):
if isinstance(args[0], str):
self.load(args[0])
else:
shape = tuple(args[:3])
self.version = 4
self.yprobs = np.zeros(shape[1], dtype="u1")
self.nodes = ["air"]
self.data = np.zeros(shape, dtype=bulk_dtype)
def __init__(self, *args, **kwargs):
if len(args) >= 1:
if isinstance(args[0], str):
self.load(args[0])
else:
if isinstance(args[0], tuple):
shape = args
else:
shape = tuple(args[:3])
self.version = 4
self.yprobs = np.zeros(shape[1], dtype="u1")
self.nodes = ["air"]
self.data = np.zeros(shape, dtype=bulk_dtype)
if 'version' in kwargs:
self.version = kwargs['version']
if 'yprobs' in kwargs:
self.yprobs = kwargs['yprobs']
if 'nodes' in kwargs:
self.nodes = kwargs['nodes']
if 'data' in kwargs:
self.data = kwargs['data']
def load(self, filename):
f = open(filename, "rb")
@ -98,3 +111,13 @@ class Schem:
self.nodes = new_nodelist
self.data["node"] = transform_list[self.data["node"]]
def __getitem__(self, slices):
data = self.data[slices].copy()
nodes = self.nodes[:]
if isinstance(slices, tuple) and len(slices) >= 2:
yprobs = self.yprobs[slices[1]].copy()
else:
yprobs = self.yprobs.copy()
return Schem(data=data, nodes=nodes, yprobs=yprobs, version=self.version)