Source code for csdl_alpha.src.operations.maximum

from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
import csdl_alpha as csdl

import numpy as np

class Maximum(Operation):
    '''
    Maximum entries in the input tensor along the specified axes.
    '''
    def __init__(self, x, axes=None, out_shape=None, rho=20.):
        super().__init__(x)
        self.name  = 'maximum'
        out_shapes = (out_shape,)
        self.set_dense_outputs(out_shapes)
        self.axes  = axes
        self.rho   = rho
        in_shape = x.shape

        if axes is not None:
            axes = self.axes  = tuple(np.sort(axes))
            rank = len(in_shape)
            alphabet  = 'abcdefghijklmnopqrstuvwxyz'
            in1_str = alphabet[:axes[0]]
            in2_str = alphabet[axes[0]]
            ones_shape = (in_shape[axes[0]],)
            for i in range(len(axes)-1):
                in1_str += alphabet[axes[i] + 1 : axes[i + 1]]
                in2_str += alphabet[axes[i+1]]
                ones_shape += (in_shape[axes[i+1]],)
            in1_str += alphabet[axes[-1] + 1 : rank]
            self.einsum_str = '{},{}->{}'.format(
                in1_str,
                in2_str,
                alphabet[:rank],
                )
            self.ones_shape = ones_shape

    def compute_inline(self, x):
        rho = self.rho
        axes = self.axes

        if axes is None:
            x_max = np.max(x)
            smooth_max = x_max + 1/rho * np.log(np.sum(np.exp(rho * (x - x_max))))
            return smooth_max
        else:
            ones_shape = self.ones_shape
            axeswise_max = np.max(x, axis=self.axes)
            # print(self.einsum_str, axeswise_max.shape, ones_shape)
            difference = x - np.einsum(
                self.einsum_str,
                axeswise_max,
                np.ones(ones_shape),
                )
            exp = np.exp(rho * difference)
            summation = np.sum(exp, axis=axes)
            smooth_axeswise_max = axeswise_max + 1.0 / rho * np.log(summation)
            return smooth_axeswise_max

    def compute_jax(self, x):
        import jax.numpy as jnp
        rho = jnp.array(self.rho)
        axes = self.axes

        if axes is None:
            x_max = jnp.max(x)
            smooth_max = x_max + 1/rho * jnp.log(jnp.sum(jnp.exp(rho * (x - x_max))))
            return smooth_max
        else:
            ones_shape = self.ones_shape
            axeswise_max = jnp.max(x, axis=self.axes)
            # print(self.einsum_str, axeswise_max.shape, ones_shape)
            difference = x - jnp.einsum(
                self.einsum_str,
                axeswise_max,
                jnp.ones(ones_shape),
                )
            exp = jnp.exp(rho * difference)
            summation = jnp.sum(exp, axis=axes)
            smooth_axeswise_max = axeswise_max + 1.0 / rho * jnp.log(summation)
            return smooth_axeswise_max


    def evaluate_vjp(self, cotangents, x, y):
        if cotangents.check(x):
            if self.axes is None:
                rho = self.rho
                diff = x - y
                exp_x = csdl.exp(rho*diff)
                vjp = cotangents[y] * exp_x / csdl.sum(exp_x)
            else:
                rho = self.rho
                axes = self.axes

                in_str, out_str  = self.einsum_str.split('->')
                in_str, ones_str = in_str.split(',')
                exp_str = in_str + '->' + out_str
                exp_term = csdl.exp(rho*(x-csdl.expand(y, x.shape, exp_str)))
                sum = csdl.sum(exp_term, axes=axes)
                expanded_sum = csdl.expand(sum, out_shape=x.shape, action=exp_str)
                vjp = csdl.expand(cotangents[y] , x.shape, exp_str) * exp_term / expanded_sum

            cotangents.accumulate(x, vjp)

class ElementwiseMaximum(Operation):
    '''
    Elementwise maximum of all the Variables in the arguments.
    '''
    def __init__(self, *args, rho=20.):
        super().__init__(*args)
        self.name = 'elementwise_maximum'
        out_shapes = (args[0].shape,)
        self.rho   = rho
        self.set_dense_outputs(out_shapes)

    def compute_inline(self, *args):
        rho = self.rho
        ew_max = args[0]
        for arg in args[1:]:
            ew_max = np.maximum(ew_max, arg)

        summation = 0.
        for arg in args:
            summation += np.exp(rho * (arg - ew_max))

        smooth_ew_max = (ew_max + 1. / rho * np.log(summation))
        return smooth_ew_max

    def compute_jax(self, *args):
        import jax.numpy as jnp
        rho = jnp.array(self.rho)
        ew_max = args[0]
        for arg in args[1:]:
            ew_max = jnp.maximum(ew_max, arg)

        summation = 0.
        for arg in args:
            summation += jnp.exp(rho * (arg - ew_max))

        smooth_ew_max = (ew_max + 1. / rho * jnp.log(summation))
        return smooth_ew_max

    def evaluate_vjp(self, cotangents, *inputs_and_outputs):
        inputs = inputs_and_outputs[:len(self.inputs)]
        output = inputs_and_outputs[-1]
        rho = self.rho
        # sum  = cotangents[output]/csdl.sum(*[csdl.exp(rho*(arg-output)) for arg in inputs])
        for input_var in inputs:
            if cotangents.check(input_var):
                cotangents.accumulate(input_var, cotangents[output]/csdl.sum(*[csdl.exp(rho*(arg-input_var)) for arg in inputs]))

                # cotangents.accumulate(input_var, csdl.exp(rho*(input_var-output))*sum)

[docs]def maximum(*args, axes=None, rho=20.): ''' Computes the maximum entry in the input tensor if a single argument is provided. Computes the maximum entries along the specified axes if `axes` argument is given. Computes the elementwise maximum of multiple variables of the same shape, if multiple arguments are provided. Axes argument is not allowed in this case. Parameters ---------- *args : tuple of Variable or np.ndarray objects Input tensor/s whose maximum needs to be computed. axes : tuple of int, default=None Axes along which to compute the maximum of the input tensor, if there's only one input tensor. rho : float, default=20. Smoothing parameter for the maximum function. Returns ------- Variable Maximum entry in the input tensor if a single argument is provided. Maximum entries along the specified axes if `axes` argument is given. Elementwise maximum of multiple variables of the same shape, if multiple arguments are provided. Examples -------- >>> recorder = csdl.Recorder(inline = True) >>> recorder.start() >>> x_val = np.arange(6).reshape(2,3) >>> x = csdl.Variable(value = x_val) >>> y1 = csdl.maximum(x) >>> y1.value array([5.]) Maximum of a single tensor variable along a specified axis >>> y2 = csdl.maximum(x, axes=(1,)) >>> y2.value array([2., 5.]) Elementwise maximum of multiple tensor variables >>> y3 = csdl.maximum(x, 2 * np.ones((2,3)), np.ones((2,3))) >>> y3.value array([[2. , 2. , 2.03465736], [3. , 4. , 5. ]]) Note that `y3.value[0,2]` is not exactly `2.0` due to the smoothing term. It can be made closer to `2.0` by increasing the value of the smoothing parameter rho as shown below. >>> y = csdl.maximum(x, 2 * np.ones((2,3)), np.ones((2,3)), rho=200) >>> y.value array([[2. , 2. , 2.00346574], [3. , 4. , 5. ]]) ''' # Multiple Variables to find maximum if axes is not None and len(args) > 1: raise ValueError('Cannot find maximum of multiple Variables along specified axes. \ Use X = max(A,B,...) followed by out=max(X, axes=(...)) instead.') if any(args[i].shape != args[0].shape for i in range(1, len(args))): raise ValueError('All Variables must have the same shape.') # Single Variable to find maximum if axes is not None: if any(np.asarray(axes) > (len(args[0].shape)-1)): raise ValueError('Specified axes cannot be more than the rank of the Variable.') if any(np.asarray(axes) < 0): raise ValueError('Axes cannot have negative entries.') if len(args) == 1: if axes is None: out_shape = (1,) else: out_shape = tuple([x for i, x in enumerate(args[0].shape) if i not in axes]) if len(out_shape) == 0: # raise ValueError('It is inefficient to find the maximum of a tensor Variable along all axes. \ # Use maximum(A) to find the maximum of all tensor entries.') out_shape = (1,) axes = None op = Maximum(validate_and_variablize(args[0]), axes=axes, out_shape=out_shape, rho=rho) else: # axes is None for multiple variables args = [validate_and_variablize(x, raise_on_sparse=False) for x in args] op = ElementwiseMaximum(*args, rho=rho) return op.finalize_and_return_outputs()
class TestMaximum(csdl_tests.CSDLTest): def test_functionality(self,): self.prep() import csdl_alpha as csdl import numpy as np recorder = csdl.build_new_recorder(inline = True) recorder.start() x_val = 3.0*np.arange(6).reshape(2,3) y_val = 2.0*np.ones((2,3)) z_val = np.ones((2,3)) d_val = np.arange(12).reshape(2,3,2) x = csdl.Variable(name = 'x', value = x_val) y = csdl.Variable(name = 'y', value = y_val) z = csdl.Variable(name = 'z', value = z_val) d = csdl.Variable(name = 'd', value = d_val) compare_values = [] # maximum of a single tensor variable s1 = csdl.maximum(x) s1.add_name('s1') t1 = np.array([15.0]) compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')] # maximum of a single tensor variable s1 = csdl.maximum(x, axes=(0,1)) compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')] # maximum of a single tensor constant s2 = csdl.maximum(x_val) s2.add_name('s2') compare_values += [csdl_tests.TestingPair(s2, t1, tag = 's2')] # maximum of a single tensor variable along specified axes s3 = csdl.maximum(x, axes=(1,)) t3 = np.max(x_val, axis=1) s3.add_name('s3') compare_values += [csdl_tests.TestingPair(s3, t3, tag = 's3')] # maximum of a single tensor variable along 2 specified axes s4 = csdl.maximum(d, axes=(0,2)) t4 = np.max(d_val, axis=(0,2)) s4.add_name('s4') compare_values += [csdl_tests.TestingPair(s4, t4, tag = 's4', decimal=8)] # elementwise maximum of multiple tensor variables s5 = csdl.maximum(x, y, z) t5 = np.maximum(x_val, y_val) # s5.add_name('s5') compare_values += [csdl_tests.TestingPair(s5, t5, tag = 's5', decimal=8)] # elementwise maximum of multiple tensor constants s6 = csdl.maximum(x_val, y_val, z_val) # s6.add_name('s6') compare_values += [csdl_tests.TestingPair(s6, t5, tag = 's6', decimal=8)] # maximum of a single tensor constant # compare_values = [] n7 = np.array([10000.0,-10000.0]) s7 = csdl.maximum(n7) s7.add_name('s7') t7 = np.max(n7).flatten() compare_values += [csdl_tests.TestingPair(s7, t7, tag = 's7')] # TODO: maximum of a zero tensor - need to check this # to avoid errors from sum(log(1+1+..)) if there are multiple entries of zero # and zero is the maximum # zeros = np.zeros((2,3)) # s7 = csdl.maximum(zeros, rho=2000) # t7 = np.array([0.0]) # s7.add_name('s7') # compare_values += [csdl_tests.TestingPair(s7, t7, tag = 's7', decimal=3)] self.run_tests(compare_values = compare_values, verify_derivatives=True) def test_example(self,): self.docstest(maximum) if __name__ == '__main__': test = TestMaximum() test.overwrite_backend = 'jax' test.test_functionality() test.test_example()