# Source code for csdl_alpha.src.operations.tensor.tensordot

from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation, ComposedOperation
from csdl_alpha.src.graph.operation import Operation, set_properties 
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize
import csdl_alpha.utils.testing_utils as csdl_tests
import pytest
from csdl_alpha.utils.typing import VariableLike
from csdl_alpha.src.operations.sum import sum as csdl_sum
from csdl_alpha.src.operations.tensor.expand import expand as csdl_expand

class TensorDot(ComposedOperation):
    '''
    Composed operation for the tensor dot product of two variables.

    The constructor precomputes everything `evaluate_composed` needs:
    the shape both inputs are expanded to (`exp_shape`), the two
    expansion actions written as subscript strings (`action1`,
    `action2`), and the axes of the expanded product to sum over
    (`summation_axes`, or None for an outer product).

    Parameters
    ----------
    x : Variable
        First input tensor.
    y : Variable
        Second input tensor.
    axes : tuple of 2 lists, default=None
        Axes of x (`axes[0]`) and of y (`axes[1]`) that are paired and
        summed over. None, or a pair of empty lists, gives the outer
        product.

    Raises
    ------
    ValueError
        If the combined rank of x and y exceeds the number of available
        subscript letters.
    '''

    def __init__(self, x, y, axes=None):
        super().__init__(x,y)
        self.name = 'tensordot'
        alphabet = 'abcdefghijklmnopqrstuvwxyz'
        # Example: 'exp' stands for 'expanded'
        # x.shape = (3,2), y.shape = (2,5), axes = ([1],[0])
        # in1_str = 'ab', in2_str = 'bc', in2_unique_str = 'c'
        # exp_str = 'abc'
        # exp_shape = (3,2,5)
        # action1 = 'ab->abc'
        # action2 = 'bc->abc'
        # summation_axes = [1]

        rank1 = len(x.shape)
        rank2 = len(y.shape)
        # Every axis of x and y is first given its own letter; slicing past
        # the end of the alphabet would silently yield too few labels.
        if rank1 + rank2 > len(alphabet):
            raise ValueError(
                f'tensordot supports a combined rank of at most {len(alphabet)}, '
                f'but got ranks {rank1} and {rank2}.'
            )
        in1_str = alphabet[:rank1]
        in2_str = alphabet[rank1:rank1+rank2]

        # A pair of empty axis lists contracts nothing, i.e. it is an outer
        # product; handle it like `axes=None` instead of summing over no axes.
        no_contraction = axes is None or (len(axes[0]) == 0 and len(axes[1]) == 0)

        if no_contraction:
            exp_str = in1_str + in2_str
            exp_shape = x.shape + y.shape
            self.summation_axes = None

        else:
            x_axes, y_axes = axes[0], axes[1]

            # axes of y that survive the contraction, in their original order
            kept_y_axes = [i for i in range(rank2) if i not in y_axes]
            in2_unique_str = ''.join(in2_str[i] for i in kept_y_axes)
            exp_str   = in1_str + in2_unique_str
            exp_shape = x.shape + tuple(y.shape[i] for i in kept_y_axes)

            # replace subscripts in in2_str with in1_str at common axes locations
            in2_str = ''.join(
                in1_str[x_axes[y_axes.index(i)]] if i in y_axes else in2_str[i]
                for i in range(rank2)
            )
            # x occupies the leading axes of the expanded tensor, so the
            # contracted axes of x are the axes to sum over after expansion.
            self.summation_axes = tuple(x_axes)

        self.exp_shape = exp_shape
        self.action1 = f'{in1_str}->{exp_str}'
        self.action2 = f'{in2_str}->{exp_str}'

    def evaluate_composed(self, x, y):
        '''Build the tensor dot product of x and y from the precomputed expansion data.'''
        return evaluate_tensordot(x, y, 
                                  self.exp_shape, 
                                  self.action1, 
                                  self.action2, 
                                  self.summation_axes)
    
def evaluate_tensordot(x, y, exp_shape, action1, action2, summation_axes):
    '''
    Assemble a tensor dot product from expand, multiply and sum operations.

    Parameters
    ----------
    x, y
        Input tensors.
    exp_shape : tuple
        Common shape both inputs are expanded to.
    action1, action2 : str
        Expansion actions passed to `csdl_expand` for x and y respectively.
    summation_axes : tuple or None
        Axes of the expanded product to sum over. None skips the summation
        and returns the expanded product itself.

    Returns
    -------
    The elementwise product of the two expanded tensors, summed over
    `summation_axes` when they are given.
    '''
    product = csdl_expand(x, exp_shape, action=action1) * csdl_expand(y, exp_shape, action=action2)

    if summation_axes is None:
        return product

    # When every axis of the expanded tensor is summed (inner product),
    # the sum is requested without an `axes` argument; the original author
    # notes this as the more efficient form.
    if len(summation_axes) == len(exp_shape):
        return csdl_sum(product)

    return csdl_sum(product, axes=summation_axes)

def _check_tensordot_axes(x_shape, y_shape, axes):
    '''Raise ValueError unless `axes` is a valid contraction specification for the two shapes.'''
    if not isinstance(axes, tuple):
        raise ValueError('`axes` must be a "tuple" of two lists.')
    if len(axes) != 2:
        raise ValueError('`axes` must be a tuple of "two" lists.')
    if not isinstance(axes[0], list) or not isinstance(axes[1], list):
        raise ValueError('`axes` must be a tuple of two "lists".')

    x_axes, y_axes = axes
    if len(x_axes) != len(y_axes):
        raise ValueError('Each list in `axes` must have the same length.')
    if len(x_axes) != len(set(x_axes)) or len(y_axes) != len(set(y_axes)):
        raise ValueError('Each list in `axes` must have unique elements.')
    if not all(isinstance(i, int) for i in x_axes + y_axes):
        raise ValueError('Each element in the lists of `axes` must be an integer.')
    if not all(i >= 0 for i in x_axes + y_axes):
        raise ValueError('Each element in the lists of `axes` must be non-negative.')
    # Checked before indexing the shapes so that an out-of-range axis is
    # reported as an invalid argument rather than as a bare IndexError.
    if any(i >= len(x_shape) for i in x_axes) or any(j >= len(y_shape) for j in y_axes):
        raise ValueError('Each element in the lists of `axes` must be less than '
                         'the number of dimensions of the corresponding tensor.')
    if not all(x_shape[i] == y_shape[j] for i, j in zip(x_axes, y_axes)):
        raise ValueError('Each corresponding pair of axes in the '
                         '2 lists of `axes` specified must have equal lengths.')


def tensordot(x: VariableLike, y: VariableLike, axes=None) -> Variable:
    '''
    Computes the tensor dot product of two tensors x and y along the specified axes.
    The axes argument is a tuple of two lists, where the corresponding axes of
    x and y to multiply and sum over are specified.
    If `axes` is specified, the resulting tensor will have shape equal to the
    concatenation of the shapes of x and y, with the axes specified removed.
    For example, if x has shape (3,2) and y has shape (2,5), and axes = ([1],[0]),
    the result will have shape (3,5).

    The tensor dot product is a generalization of the inner and outer product operations.
    If no axes is specified, the resulting tensor is the outer product of x and y
    having shape (x.shape + y.shape).
    If x and y have same shape, and the axes is set to
    ([0,1,...,rank_x-1], [0,1,...,rank_y-1]), the result is the scalar
    inner product of x and y.
    Note that the rank_x = rank_y = len(x.shape) = len(y.shape).

    Parameters
    ----------
    x : VariableLike
        First input tensor.
    y : VariableLike
        Second input tensor.
    axes : tuple of 2 lists, default=None
        Axes along which to compute the tensor dot product of the input tensors.
        If not specified, the outer product of x and y is computed.
        If specified, the axes must be a tuple of 2 lists.
        The axes must be unique within each list.
        The axes must be non-negative integers within each list.
        Each list in the tuple must have the same length.
        Each corresponding pair of axes for x and y in the 2 lists
        specified must have equal lengths.

    Returns
    -------
    Variable
        Tensor dot product of x and y.

    Raises
    ------
    ValueError
        If `axes` is given and is not a tuple of two equally long lists of
        unique, non-negative, in-range integers, or if a paired axis of x
        and y differs in length.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([1, 2, 3]))
    >>> y = csdl.Variable(value = np.array([4, 5]))

    Outer product of x and y:

    >>> csdl.tensordot(x, y).value
    array([[ 4.,  5.],
           [ 8., 10.],
           [12., 15.]])

    Outer product of x and z:

    >>> z = csdl.Variable(value = np.array([[1, 2], [3, 4]]))
    >>> csdl.tensordot(x, z).value
    array([[[ 1.,  2.],
            [ 3.,  4.]],
    <BLANKLINE>
           [[ 2.,  4.],
            [ 6.,  8.]],
    <BLANKLINE>
           [[ 3.,  6.],
            [ 9., 12.]]])

    Dot product of y and z along one axis (same as matrix product z @ y):

    >>> csdl.tensordot(y, z, axes=([0], [1])).value
    array([14., 32.])

    Inner product of z and t:

    >>> t_np = np.array([[5, 6], [7, 8]])
    >>> csdl.tensordot(z, t_np, axes=([0,1], [0,1])).value
    array([70.])
    '''
    x = variablize(x)
    y = variablize(y)

    if axes is not None:
        _check_tensordot_axes(x.shape, y.shape, axes)

    op = TensorDot(x, y, axes=axes)
    return op.finalize_and_return_outputs()
class TestTensorDot(csdl_tests.CSDLTest):
    '''Tests for `tensordot`: values against numpy, derivative checks, and the docstring examples.'''

    def test_functionality(self,):
        '''Compare outer, single-axis and inner products against `np.tensordot`.'''
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        x_val = np.array([1, 2, 3])
        y_val = np.array([4, 5])
        z_val = np.array([[1, 2], [3, 4]])
        t_val = np.array([[5, 6], [7, 8]])

        x = csdl.Variable(value = x_val)
        y = csdl.Variable(value = y_val)
        z = csdl.Variable(value = z_val)

        compare_values = []
        # Outer product of x and y:
        compare_values += [csdl_tests.TestingPair(csdl.tensordot(x, y), np.tensordot(x_val, y_val, axes=0))]
        # Outer product of x and z (first argument passed as a raw array):
        compare_values += [csdl_tests.TestingPair(csdl.tensordot(x_val, z), np.tensordot(x_val, z_val, axes=0))]

        # Dot product of y and z along one axis:
        s1 = csdl.tensordot(y, z, axes=([0], [1]))
        t1 = np.tensordot(y_val, z_val, axes=([0], [1]))
        compare_values += [csdl_tests.TestingPair(s1, t1, tag='s1')]

        # Inner product of z and t:
        s2 = csdl.tensordot(z, t_val, axes=([0,1], [0,1]))
        t2 = np.tensordot(z_val, t_val, axes=([0,1], [0,1])).flatten()
        compare_values += [csdl_tests.TestingPair(s2, t2)]

        # Outer product of two scalars:
        s3 = csdl.tensordot(1.0, 2.0)
        compare_values += [csdl_tests.TestingPair(s3, np.ones((1,1))*2.0)]

        self.run_tests(compare_values = compare_values,verify_derivatives=True)

    def test_tensordot_complex(self,):
        '''Use tensordot inside a larger quadrature-style computation and check value and derivatives.'''
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        def arange_np(*shape, ones = False):
            # Returns a csdl Variable and the numpy array it was built from.
            shape = tuple(shape)
            if not ones:
                val = np.arange(np.prod(shape)).reshape(shape)/np.prod(shape)*20.0+20.0
            else:
                val = np.ones(shape)*2.5
            return csdl.Variable(value=val), val

        grid_n = 4
        grid_n1 = grid_n - 1
        num_physical_dimensions = 1
        quadrature_order = 2

        grid_values, grid_values_np = arange_np(grid_n, grid_n, 3, ones=True)
        grid_values = grid_values*0.25
        quadrature_values, quadrature_values_np = arange_np(grid_n1, grid_n1, quadrature_order**2, 1, ones=False)
        quadrature_values = quadrature_values.T().reshape(quadrature_values.shape)
        quadrature_coord_weights, quadrature_coord_weights_np = arange_np(quadrature_order**2,)

        # Manual toggle: the loop branch is a reference formulation of the
        # vectorized branch below it.
        run_with_loop = False
        if run_with_loop:
            # compute the integral
            values = csdl.Variable(value=np.zeros((grid_n-1, grid_n-1, num_physical_dimensions)))
            for i in csdl.frange(grid_n-1):
                for j in csdl.frange(grid_n-1):
                    for k in csdl.frange(quadrature_order**2):
                        values = values.set(csdl.slice[i,j], values[i,j] + quadrature_values[i,j,k]*quadrature_coord_weights[k])

            # compute areas of the quadrilaterals
            output = csdl.Variable(value=np.zeros((grid_n-1, grid_n-1)))
            for i in csdl.frange(grid_n-1):
                for j in csdl.frange(grid_n-1):
                    area_1 = csdl.norm(csdl.cross(grid_values[i+1,j]-grid_values[i,j], grid_values[i,j+1]-grid_values[i,j]) + 1e-2)/2
                    area_2 = csdl.norm(csdl.cross(grid_values[i,j+1]-grid_values[i+1,j+1], grid_values[i+1,j]-grid_values[i+1,j+1]) + 1e-2)/2
                    output = output.set(csdl.slice[i,j], (area_1+area_2)*values[i,j])
        else:
            values = csdl.tensordot(quadrature_values, quadrature_coord_weights, axes = ([2],[0]))
            area_1 = csdl.norm(csdl.cross(grid_values[1:,:-1]-grid_values[:-1,:-1], grid_values[:-1,1:]-grid_values[:-1,:-1], axis=2) + 1e-2, axes=(2,))/2
            area_2 = csdl.norm(csdl.cross(grid_values[:-1,1:]-grid_values[1:,1:], grid_values[1:,:-1]-grid_values[1:,1:], axis=2) + 1e-2, axes=(2,))/2
            output = (area_1+area_2)*values.reshape(area_1.shape)

        # Defined before the branch because both the checking path and the
        # timing path below read it.
        out_average = csdl.average(output)

        # Manual toggle: set to True to time the jax interface instead of
        # checking values and derivatives.
        run_timing = False
        if not run_timing:
            print(out_average.value)
            compare_values = []
            compare_values += [csdl_tests.TestingPair(out_average, np.ones((1,))*56.65516808, decimal=7)]
            self.run_tests(compare_values = compare_values, verify_derivatives=True, step_size=1e-8)
        else:
            ins = [grid_values, quadrature_values, quadrature_coord_weights]
            jax_interface = csdl.jax.create_jax_interface(
                inputs = ins,
                outputs = [out_average] + [csdl.derivative(out_average,ins, as_block=True)]
            )

            import time
            start = time.time()
            outputs = jax_interface({var:var.value for var in ins})
            end1 = time.time()
            outputs = jax_interface({var:var.value for var in ins})
            end2 = time.time()
            print(np.linalg.norm(outputs[out_average]))
            print('Time taken for first run:', end1-start)
            print('Time taken for second run:', end2-end1)

    def test_docstring(self):
        '''Run the examples embedded in the `tensordot` docstring.'''
        self.docstest(tensordot)


if __name__ == '__main__':
    test = TestTensorDot()
    test.overwrite_backend = 'jax'
    test.test_functionality()
    test.test_docstring()
    test.test_tensordot_complex()