# Source code for csdl_alpha.src.operations.sqrt

from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
import csdl_alpha as csdl
from typing import Union

import numpy as np

@set_properties(linear=False)
class Sqrt(ElementwiseOperation):
    """
    The elementwise square roots of the input tensor.

    Declared with ``linear=False`` through ``set_properties`` and named
    ``'sqrt'``.
    """

    def __init__(self, x):
        """
        Build the operation for a single input.

        Parameters
        ----------
        x
            Input forwarded unchanged to the ``ElementwiseOperation``
            constructor.
        """
        super().__init__(x)
        self.name = 'sqrt'

    def compute_inline(self, x):
        """Return ``np.sqrt(x)``, the NumPy elementwise square root of ``x``."""
        return np.sqrt(x)

    def compute_jax(self, x):
        """Return ``jnp.sqrt(x)``, the JAX elementwise square root of ``x``."""
        # Imported inside the method, so jax is only loaded when this is called.
        import jax.numpy as jnp
        return jnp.sqrt(x)

    def evaluate_vjp(self, cotangents, x, y):
        """
        Accumulate the vector-Jacobian product contribution for ``x``.

        Parameters
        ----------
        cotangents
            Object supporting ``check(x)``, ``accumulate(x, value)`` and
            indexing by ``y``.
        x
            Input of the operation.
        y
            Presumably the output of this operation, i.e. ``sqrt(x)``
            (TODO confirm against the base-class contract).

        Nothing is accumulated when ``cotangents.check(x)`` is falsy.
        """
        if cotangents.check(x):
            # With y = sqrt(x): dy/dx = 1 / (2 * sqrt(x)) = 1 / (2 * y),
            # so the output is reused instead of recomputing the root.
            cotangents.accumulate(x, cotangents[y] / (2*y))

def sqrt(x: Union[Variable, np.ndarray, float, int]) -> Variable:
    """
    The elementwise square roots of the input tensor.

    Parameters
    ----------
    x : Variable, np.ndarray, float, or int
        Input tensor to take the square root of.

    Returns
    -------
    Variable
        Elementwise square roots of the input tensor.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0]))
    >>> y = csdl.sqrt(x)
    >>> y.value
    array([1.        , 1.41421356, 1.73205081])
    """
    # Non-Variable inputs are passed through validate_and_variablize before
    # the operation is built.
    op = Sqrt(validate_and_variablize(x))
    return op.finalize_and_return_outputs()
class TestSqrt(csdl_tests.CSDLTest):
    """Tests for ``sqrt``: value/derivative comparisons and the docstring example."""

    def test_functionality(self,):
        """Compare ``csdl.sqrt`` with ``np.sqrt`` for Variable, array and float inputs."""
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        rec = csdl.build_new_recorder(inline = True)
        rec.start()

        # The +0.05 offset makes every entry strictly positive (arange starts
        # at 0) -- presumably to keep the derivative check away from x = 0.
        base = 3.0*np.arange(6).reshape(2,3)+0.05
        x = csdl.Variable(name = 'x', value = base)
        pairs = []

        # Tensor Variable input.
        from_variable = csdl.sqrt(x)
        expected = np.sqrt(base)
        pairs.append(csdl_tests.TestingPair(from_variable, expected, tag = 's1'))

        # Raw NumPy array input; same expected values as above.
        from_array = csdl.sqrt(base)
        pairs.append(csdl_tests.TestingPair(from_array, expected, tag = 's2'))

        # Python float input; the expected value is a length-1 array.
        from_float = csdl.sqrt(2.0)
        expected_scalar = np.array([np.sqrt(2.0)])
        pairs.append(csdl_tests.TestingPair(from_float, expected_scalar, tag = 's3'))

        self.run_tests(compare_values = pairs, verify_derivatives=True)

    def test_example(self,):
        """Run the example embedded in the ``sqrt`` docstring."""
        self.docstest(sqrt)


if __name__ == '__main__':
    # Direct execution runs both tests.
    case = TestSqrt()
    case.test_functionality()
    case.test_example()