# Source code for csdl_alpha.src.operations.power

from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation
from csdl_alpha.src.graph.operation import Operation, set_properties 
import numpy as np
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.utils.typing import VariableLike
from csdl_alpha.src.graph.variable import Variable

class Power(ElementwiseOperation):
    '''
    Elementwise power of a tensor.

    Raises the first input ``x`` to the power of the second input ``y``
    elementwise: ``z = x ** y``. The ``power`` function in this module
    dispatches here when ``x.shape == y.shape``.
    '''

    def __init__(self, x, y):
        super().__init__(x, y)
        self.name = 'power'

    def compute_inline(self, x, y):
        return x ** y

    def compute_jax(self, x, y):
        # ``**`` on JAX arrays goes through JAX's operator overloading, so no
        # explicit jax.numpy call (or import) is required here.
        return x ** y

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate vector-Jacobian products for ``z = x ** y``.

        dz/dx = y * x**(y-1)
        dz/dy = x**y * log(x) = z * log(x)

        NOTE(review): log(x) is undefined for x <= 0, so the derivative with
        respect to ``y`` is presumably NaN there.
        '''
        if cotangents.check(x):
            cotangents.accumulate(x, cotangents[z]*y*x**(y-1))
        if cotangents.check(y):
            import csdl_alpha as csdl
            # Reuse the forward output z instead of recomputing x**y.
            cotangents.accumulate(y, cotangents[z]*z*csdl.log(x))

class LeftBroadcastPower(Operation):
    '''
    Power operation whose base ``x`` is broadcast against the exponent ``y``.

    Intended for a base of shape ``(1,)`` (the ``power`` function dispatches
    here in that case). The single output takes the shape of ``y``.
    '''

    def __init__(self, x, y):
        super().__init__(x, y)
        self.name = 'left_broadcast_power'
        # The result has the exponent's shape.
        self.set_dense_outputs((y.shape,))

    def compute_inline(self, x, y):
        return x ** y

    def compute_jax(self, x, y):
        import jax.numpy as jnp
        return x ** y

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate vector-Jacobian products for ``z = x ** y`` with ``x``
        broadcast over ``y``.
        '''
        import csdl_alpha as csdl

        if cotangents.check(x):
            # Elementwise partial y * x**(y-1); x was broadcast, so reduce
            # the contributions from every output element.
            d_base = cotangents[z]*y*x**(y-1)
            cotangents.accumulate(x, csdl.sum(d_base))

        if cotangents.check(y):
            # dz/dy = x**y * log(x) = z * log(x), already in y's shape.
            d_exponent = cotangents[z]*z*csdl.log(x)
            cotangents.accumulate(y, d_exponent)

class RightBroadcastPower(Operation):
    '''
    Power operation whose exponent ``y`` is broadcast against the base ``x``.

    Intended for an exponent of shape ``(1,)`` (the ``power`` function
    dispatches here in that case). The single output takes the shape of ``x``.
    '''

    def __init__(self, x, y):
        super().__init__(x, y)
        self.name = 'right_broadcast_power'
        # The result has the base's shape.
        self.set_dense_outputs((x.shape,))

    def compute_inline(self, x, y):
        return x ** y

    def compute_jax(self, x, y):
        import jax.numpy as jnp
        return x ** y

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate vector-Jacobian products for ``z = x ** y`` with ``y``
        broadcast over ``x``.
        '''
        import csdl_alpha as csdl

        if cotangents.check(x):
            # Elementwise partial y * x**(y-1), already in x's shape.
            d_base = cotangents[z]*y*x**(y-1)
            cotangents.accumulate(x, d_base)

        if cotangents.check(y):
            # dz/dy = x**y * log(x) = z * log(x); y was broadcast, so reduce.
            d_exponent = csdl.sum(cotangents[z]*z*csdl.log(x))
            cotangents.accumulate(y, d_exponent)

def power(x: VariableLike, y: VariableLike) -> Variable:
    '''
    Computes the power of the first input with exponent as the second input.
    If one of the inputs is a scalar, it is broadcasted to the shape of the
    other input.

    Parameters
    ----------
    x : Variable, np.ndarray, float or int
        Input tensor whose power needs to be computed.
    y : Variable, np.ndarray, float or int
        Power to which the first input tensor needs to be raised.

    Returns
    -------
    Variable
        Power of the first input with exponent as the second input.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different shapes and neither has shape ``(1,)``.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0]))
    >>> y1 = csdl.power(x, 2)
    >>> y1.value
    array([1., 4., 9.])
    >>> y2 = x ** 2
    >>> y2.value
    array([1., 4., 9.])

    Power raised to a tensor variable exponent

    >>> z = csdl.Variable(value = 3.0 * np.ones((3,)))
    >>> y2 = x ** z
    >>> y2.value
    array([ 1.,  8., 27.])
    '''
    x = validate_and_variablize(x)
    y = validate_and_variablize(y)

    if x.shape == y.shape:
        op = Power(x, y)
    elif x.shape == (1,):
        # Single-element base, broadcast to the exponent's shape.
        op = LeftBroadcastPower(x, y)
    elif y.shape == (1,):
        # Single-element exponent, broadcast to the base's shape.
        op = RightBroadcastPower(x, y)
    else:
        raise ValueError(
            'Shapes not compatible for the power operation. '
            f'Got x.shape={x.shape} and y.shape={y.shape}.'
        )
    return op.finalize_and_return_outputs()
class TestPower(csdl_tests.CSDLTest):
    '''Checks csdl.power values and derivatives against numpy references.'''

    def test_functionality(self,):
        self.prep()

        import csdl_alpha as csdl

        x_val = np.arange(6).reshape(2,3)+1.1
        y_val = 2.0
        z_val = 2.0*np.ones((2,3))
        x = csdl.Variable(name = 'x', value = x_val)
        y = csdl.Variable(name = 'y', value = y_val)
        z = csdl.Variable(name = 'z', value = z_val)

        compare_values = []
        # power of a tensor variable to a tensor variable
        y_tensor = csdl.Variable(value = x_val+1.0)
        s1 = csdl.power(x, y_tensor)
        compare_values += [csdl_tests.TestingPair(s1, x_val**(x_val+1.0), tag = 's0', decimal=9)]

        # power of a tensor variable to a negative tensor variable
        y_tensor = csdl.Variable(value = -x_val)
        s1 = csdl.power(x, y_tensor)
        compare_values += [csdl_tests.TestingPair(s1, x_val**(-x_val), tag = 's0', decimal=9)]

        # If x is negative, things get strange, y must be integers
        # slight difference (1e-10 error vs 1e-11) between JAX and numpy here for some reason
        y_tensor = csdl.Variable(name = 'int_tensor', value = np.arange(6).reshape(2,3)+3.0)
        s1 = csdl.power(-x, y_tensor)
        compare_values += [csdl_tests.TestingPair(s1, (-x_val)**(np.arange(6).reshape(2,3)+3.0), tag = 's0', decimal=9)]

        # power of a scalar variable to a tensor variable
        x_scalar = csdl.Variable(value = 3.0)
        y_tensor = csdl.Variable(value = x_val+1.0)
        s1 = csdl.power(x_scalar, y_tensor)
        compare_values += [csdl_tests.TestingPair(s1, 3.0**(x_val+1.0), tag = 's0')]

        # power of a scalar variable to a negative tensor variable
        y_tensor = csdl.Variable(value = -x_val)
        s1 = csdl.power(x_scalar, y_tensor)
        compare_values += [csdl_tests.TestingPair(s1, 3.0**(-x_val), tag = 's0')]

        # If x is negative, things get strange, y must be integers
        y_tensor = csdl.Variable(name = 'int_tensor', value = np.arange(6).reshape(2,3)+3.0)
        s1 = csdl.power(-x_scalar, y_tensor)
        compare_values += [csdl_tests.TestingPair(s1, (-3.0)**(np.arange(6).reshape(2,3)+3.0), tag = 's0')]

        # power of a tensor variable to a scalar variable
        s1 = csdl.power(x, y)
        compare_values += [csdl_tests.TestingPair(s1, x_val**(y_val), tag = 's0')]

        # power of a tensor variable to a negative scalar variable
        s1 = csdl.power(x, -y)
        compare_values += [csdl_tests.TestingPair(s1, x_val**(-y_val), tag = 's0')]

        # If x is negative, things get strange, y must be integers
        s1 = csdl.power(-x, y)
        compare_values += [csdl_tests.TestingPair(s1, (-x_val)**(y_val), tag = 's0')]

        # power of a tensor variable to a scalar variable
        s1 = csdl.power(x, y)
        t1 = x_val ** y_val
        compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')]

        # power of a tensor variable to a tensor constant
        s2 = csdl.power(x, z_val)
        compare_values += [csdl_tests.TestingPair(s2, t1, tag = 's2')]

        # power of a scalar constant to a tensor variable
        s3 = csdl.power(3.0, x)
        t3 = 3.0 ** x_val
        compare_values += [csdl_tests.TestingPair(s3, t3, tag = 's3')]

        self.run_tests(compare_values = compare_values, verify_derivatives=True)

    def test_example(self,):
        self.docstest(power)


if __name__ == '__main__':
    test = TestPower()
    # Pick the backend to run against ('inline' or 'jax').
    # test.overwrite_backend = 'inline'
    test.overwrite_backend = 'jax'
    test.test_functionality()
    # test.test_example()