Source code for WaveBlocksND.TensorProductGrid

r"""The WaveBlocks Project

This file contains a class for representing
dense regular tensor product grids.

@author: R. Bourquin
@copyright: Copyright (C) 2012, 2014 R. Bourquin
@license: Modified BSD License
"""

import operator
from numpy import array, atleast_1d, complexfloating, diff, floating, hstack, mgrid, ogrid, squeeze

from WaveBlocksND.DenseGrid import DenseGrid
from functools import reduce

__all__ = ["TensorProductGrid"]


[docs]class TensorProductGrid(DenseGrid): r"""This class represents a dense tensor product grid. It can have an arbitrary dimension :math:`D`. The grid nodes are enclosed in a hypercubic bounding box. This box can have different limits :math:`min_i`, :math:`max_i` along each axis :math:`x_i`. In each of these intervals we place :math:`N_i` grid nodes. Note that the point :math:`max_i` is not part of the grid. The grid interior is build as the tensor product of all the grid nodes along all the axes. """ def __init__(self, limits, number_nodes): r"""Construct a tensor product grid instance. :param limits: The grid domain limits along each axis. :type limits: A list of two-element tuples. :param number_nodes: The number of grid nodes along each axis. :type number_nodes: A list of positive integers. :return: A :py:class:`TensorProductGrid` instance. """ assert len(limits) == len(number_nodes) # Regular grid spacing self._is_regular = True # The dimension of the grid self._dimension = len(limits) # The number of grid nodes along each axis self._number_nodes = number_nodes # format: [N_1, ..., N_D] # The limits of the bounding box of the grid self._limits = [array(limit) for limit in limits] # format: [(min_0,max_0), ..., (min_D,max_D)] # The extensions (edge length) of the bounding box self._extensions = hstack([abs(diff(limit)) for limit in self._limits]) # Compute the grid spacings along each axis self._meshwidths = self._extensions / squeeze(array(self._number_nodes, dtype=floating)) # format: [h_1, ..., h_D] # Cached values self._gridaxes = None self._gridnodes = None
[docs] def get_limits(self, axes=None): r"""Returns the limits of the bounding box. :param axes: The axes for which we want to get the limits. :type axes: A single integer or a list of integers. If set to ``None`` (default) we return the limits for all axes. :return: A list of :math:`(min_i, max_i)` ndarrays. """ if axes is None: axes = range(self._dimension) return [self._limits[i] for i in atleast_1d(axes)]
[docs] def get_extensions(self, axes=None): r"""Returns the extensions (length of the edges) of the bounding box. :param axes: The axes for which we want to get the extensions. :type axes: A single integer or a list of integers. If set to ``None`` (default) we return the extensions for all axes. :return: A list of :math:`|max_i-min_i|` values. """ if axes is None: axes = range(self._dimension) return [self._extensions[i] for i in atleast_1d(axes)]
[docs] def get_meshwidths(self, axes=None): r"""Returns the meshwidths of the grid. :param axes: The axes for which we want to get the meshwidths. :type axes: A single integer or a list of integers. If set to ``None`` (default) we return the data for all axes. :return: A list of :math:`h_i` values or a single value. """ if axes is None: axes = range(self._dimension) return [self._meshwidths[i] for i in atleast_1d(axes)]
[docs] def get_number_nodes(self, axes=None, overall=False): r"""Returns the number of grid nodes along a set of axes. :param axes: The axes for which we want to get the number of nodes. :type axes: A single integer or a list of integers. If set to ``None`` (default) we return the data for all axes. :param overall: Compute the product :math:`\prod_i^D N_i` of the number :math:`N_i` of grid nodes along each axis :math:`i` specified. :type overall: Boolean, default is ``False`` :return: A list of :math:`N_i` values or a single value :math:`N`. """ if axes is None: axes = range(self._dimension) values = [self._number_nodes[i] for i in atleast_1d(axes)] if overall is True: return reduce(operator.mul, values) else: return values
def _build_slicers(self): # Helper routine to build the necessary slicing # objects used for constructing the grid nodes. slicers = [slice(lims[0], lims[1], step) for lims, step in zip(self._limits, self._meshwidths)] return slicers def _compute_grid_axes(self): # Helper routine which computes the one-dimensional # grids along all axes and caches the result. Each # grid is a D-dimensional ndarray of correct shape. S = self._build_slicers() self._gridaxes = [array(ax, dtype=complexfloating) for ax in ogrid[S]] def _compute_grid_full(self): # Helper routine which computes the full set of # tensor product grid nodes and caches the result. # The result is a (D, product(N_i)) shaped ndarray. S = self._build_slicers() # TODO: Code is 4x slower withOUT complex floating grid = array(mgrid[S], dtype=complexfloating) self._gridnodes = grid.reshape((self._dimension, self.get_number_nodes(overall=True)))
[docs] def get_axes(self, axes=None): r"""Returns the one-dimensional grids along the axes. :param axes: The axes for which we want to get the grid. :type axes: A single integer or a list of integers. If set to ``None`` (default) we return the data for all axes. :return: A list of ndarrays, each having a shape of :math:`(1,...,N_i,...,1)`. We return a list even if it contains just a single element. """ if self._gridaxes is None: self._compute_grid_axes() if axes is None: axes = range(self._dimension) axes = atleast_1d(axes) return [self._gridaxes[i] for i in axes]
[docs] def get_nodes(self, flat=True, split=False): r"""Returns all grid nodes of the full tensor product grid. :param flat: Whether to return the grid with a `hypercubic` :math:`(D, N_1, ..., N_D)` shape or a `flat` :math:`(D, \prod_i^D N_i)` shape. :type flat: Boolean, default is ``True``. :param split: Whether to return the different components, one for each dimension inside a single ndarray or a list with ndarrays, with one item per dimension. :type split: Boolean, default is ``False``. :return: Depends of the optional arguments. """ if self._gridnodes is None: self._compute_grid_full() # All operations here only return views to the original grid data if flat is True: if split is False: return self._gridnodes else: return tuple([self._gridnodes[i, :] for i in range(self._dimension)]) else: if split is False: return self._gridnodes.reshape([self._dimension] + self.get_number_nodes()) else: return tuple([self._gridnodes[i, :].reshape(self.get_number_nodes()) for i in range(self._dimension)])