Source code for csdl_alpha.src.operations.log

from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation
from csdl_alpha.src.graph.operation import Operation, set_properties 
import numpy as np
from csdl_alpha.src.operations import add
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests

class Log(ElementwiseOperation):
    '''
    Elementwise logarithm of a tensor `x` taken with respect to a base `y`.

    The value is evaluated through the change-of-base identity
    log_y(x) = ln(x) / ln(y).
    '''

    def __init__(self, x, y):
        super().__init__(x, y)
        self.name = 'log'

    def compute_inline(self, x, y):
        '''
        Evaluate ln(x) / ln(y) with NumPy.
        '''
        log_of_x = np.log(x)
        log_of_base = np.log(y)
        return log_of_x / log_of_base

    def compute_jax(self, x, y):
        '''
        Evaluate ln(x) / ln(y) with jax.numpy.
        '''
        # jax is imported lazily, only when this method is actually called.
        import jax.numpy as jnp
        log_of_x = jnp.log(x)
        log_of_base = jnp.log(y)
        return log_of_x / log_of_base

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate the vector-Jacobian products of z = log_y(x) into `cotangents`.

        Each input only receives a contribution when `cotangents.check`
        returns a truthy value for it.
        '''
        import csdl_alpha as csdl
        upstream = cotangents[z]

        if cotangents.check(x):
            # dz/dx = 1 / (x * ln(y))
            x_contribution = upstream / (x * csdl.log(y))
            cotangents.accumulate(x, x_contribution)

        if cotangents.check(y):
            # dz/dy = -ln(x) / (y * ln(y)^2)
            numerator = -upstream * csdl.log(x)
            denominator = y * csdl.log(y) ** 2
            cotangents.accumulate(y, numerator / denominator)

# We need a broadcast log even when the methods are exactly the same because Broadcast cannot inherit from ElementwiseOperation
# TODO: Avoid code duplication
class LeftBroadcastLog(Operation):
    '''
    Logarithm in which the first input `x` is broadcast to the shape of the
    second input (the base) `y`.

    The single output is declared with the shape of `y`.
    '''

    def __init__(self, x, y):
        super().__init__(x, y)
        self.name = 'left_broadcast_log'
        # One dense output whose shape follows the base `y`.
        self.set_dense_outputs((y.shape,))

    def compute_inline(self, x, y):
        '''
        Evaluate ln(x) / ln(y) with NumPy.
        '''
        log_of_x = np.log(x)
        log_of_base = np.log(y)
        return log_of_x / log_of_base

    def compute_jax(self, x, y):
        '''
        Evaluate ln(x) / ln(y) with jax.numpy.
        '''
        # jax is imported lazily, only when this method is actually called.
        import jax.numpy as jnp
        log_of_x = jnp.log(x)
        log_of_base = jnp.log(y)
        return log_of_x / log_of_base

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate the vector-Jacobian products of z = log_y(x) into `cotangents`.

        Each input only receives a contribution when `cotangents.check`
        returns a truthy value for it.
        '''
        import csdl_alpha as csdl
        upstream = cotangents[z]

        if cotangents.check(x):
            # dz/dx = 1 / (x * ln(y)); the entries are summed because the
            # broadcast `x` feeds every entry of the output.
            per_entry = upstream / (x * csdl.log(y))
            cotangents.accumulate(x, csdl.sum(per_entry))

        if cotangents.check(y):
            # dz/dy = -ln(x) / (y * ln(y)^2)
            numerator = -upstream * csdl.log(x)
            denominator = y * csdl.log(y) ** 2
            cotangents.accumulate(y, numerator / denominator)

class RightBroadcastLog(Operation):
    '''
    Logarithm in which the second input (the base) `y` is broadcast to the
    shape of the first input `x`.

    The single output is declared with the shape of `x`.
    '''

    def __init__(self, x, y):
        super().__init__(x, y)
        self.name = 'right_broadcast_log'
        # One dense output whose shape follows the argument `x`.
        self.set_dense_outputs((x.shape,))

    def compute_inline(self, x, y):
        '''
        Evaluate ln(x) / ln(y) with NumPy.
        '''
        log_of_x = np.log(x)
        log_of_base = np.log(y)
        return log_of_x / log_of_base

    def compute_jax(self, x, y):
        '''
        Evaluate ln(x) / ln(y) with jax.numpy.
        '''
        # jax is imported lazily, only when this method is actually called.
        import jax.numpy as jnp
        log_of_x = jnp.log(x)
        log_of_base = jnp.log(y)
        return log_of_x / log_of_base

    def evaluate_vjp(self, cotangents, x, y, z):
        '''
        Accumulate the vector-Jacobian products of z = log_y(x) into `cotangents`.

        Each input only receives a contribution when `cotangents.check`
        returns a truthy value for it.
        '''
        import csdl_alpha as csdl
        upstream = cotangents[z]

        if cotangents.check(x):
            # dz/dx = 1 / (x * ln(y))
            x_contribution = upstream / (x * csdl.log(y))
            cotangents.accumulate(x, x_contribution)

        if cotangents.check(y):
            # dz/dy = -ln(x) / (y * ln(y)^2); the entries are summed because
            # the broadcast base `y` feeds every entry of the output.
            per_entry = upstream * csdl.log(x) / (y * csdl.log(y) ** 2)
            cotangents.accumulate(y, -csdl.sum(per_entry))

def log(x, base=None):
    '''
    Computes the logarithm of every entry of the input tensor.

    Without a `base` argument the natural logarithm is computed. Otherwise
    the logarithm of every entry is taken with respect to the given base.
    When the two inputs have different shapes and one of them holds a single
    entry, that input is broadcast to the shape of the other one.

    Parameters
    ----------
    x : Variable, np.ndarray, float or int
        Input tensor whose logarithm needs to be computed.
    base : Variable, np.ndarray, float or int, default=None
        Base of the logarithm. When left as None, `np.e` is used, which
        gives the natural logarithm.

    Returns
    -------
    Variable
        Logarithm of the first input with base as the second input.

    Raises
    ------
    ValueError
        If the shapes differ and neither input has exactly one entry.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0]))
    >>> y1 = csdl.log(x)
    >>> y1.value
    array([0.        , 0.69314718, 1.09861229])

    Logarithm with a specified base

    >>> y2 = csdl.log(x, 2)
    >>> y2.value
    array([0.       , 1.       , 1.5849625])

    Logarithm with a specified tensor variable base

    >>> b = csdl.Variable(value = 2.0 * np.ones((3,)))
    >>> y3 = csdl.log(x, b)
    >>> y3.value
    array([0.       , 1.       , 1.5849625])
    '''
    argument = validate_and_variablize(x)
    # A missing base falls back to e, i.e. the natural logarithm.
    base_var = validate_and_variablize(np.e if base is None else base)

    # Matching shapes: plain elementwise operation.
    if argument.shape == base_var.shape:
        return Log(argument, base_var).finalize_and_return_outputs()

    # Single-entry argument: broadcast it against the base.
    if argument.size == 1:
        operation = LeftBroadcastLog(argument.flatten(), base_var)
        return operation.finalize_and_return_outputs()

    # Single-entry base: broadcast it against the argument.
    if base_var.size == 1:
        operation = RightBroadcastLog(argument, base_var.flatten())
        return operation.finalize_and_return_outputs()

    raise ValueError('Shapes not compatible for log operation.')
class TestLog(csdl_tests.CSDLTest):
    '''
    Tests for `csdl.log`: scalar, tensor and broadcast inputs, given either
    as variables or as constants, plus the docstring examples.
    '''

    def test_functionality(self,):
        '''
        Compare `csdl.log` against NumPy reference values and request
        derivative verification from `run_tests`.
        '''
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        x_val = 3.0
        y_val = 2.0
        x = csdl.Variable(name = 'x', value = x_val)
        y = csdl.Variable(name = 'y', value = y_val)

        compare_values = []

        # log of a scalar variable
        s1 = csdl.log(x)
        t1 = np.array([np.log(x_val)])
        compare_values += [csdl_tests.TestingPair(s1, t1, tag = 's1')]

        # log of a scalar constant
        s2 = csdl.log(3.0)
        compare_values += [csdl_tests.TestingPair(s2, t1, tag = 's2')]

        # log of a scalar variable with scalar variable base
        s3 = csdl.log(x, y)
        t3 = np.array([np.log(x_val) / np.log(y_val)])
        compare_values += [csdl_tests.TestingPair(s3, t3, tag = 's3')]

        # log of a scalar variable with scalar constant base
        s4 = csdl.log(x, 2.0)
        compare_values += [csdl_tests.TestingPair(s4, t3, tag = 's4')]

        # log of a scalar constant with scalar constant base
        s5 = csdl.log(3.0, 2.0)
        compare_values += [csdl_tests.TestingPair(s5, t3, tag = 's5')]

        z_val = 2.0*np.ones((3,2))
        z = csdl.Variable(name = 'z', value = z_val)

        # log of a tensor variable with tensor constant base
        s6 = csdl.log(z, 3.0*np.ones((3,2)))
        t6 = np.log(z_val) / np.log(3.0)
        compare_values += [csdl_tests.TestingPair(s6, t6, tag = 's6')]

        # log of a tensor constant with a tensor variable base
        s7 = csdl.log(3.0*np.ones((3,2)), z)
        t7 = np.log(3.0) / np.log(z_val)
        compare_values += [csdl_tests.TestingPair(s7, t7, tag = 's7')]

        # log of a scalar constant with a tensor variable base
        # (same expected values as the previous case)
        s8 = csdl.log(3.0, z)
        compare_values += [csdl_tests.TestingPair(s8, t7, tag = 's8')]

        # natural log of a small scalar constant, compared with decimal = 3.
        # It gets its own tag so a failure here is not reported under 's8'.
        s9 = csdl.log(0.001)
        t9 = np.log(0.001).flatten()
        compare_values += [csdl_tests.TestingPair(s9, t9, tag = 's9', decimal = 3)]

        self.run_tests(compare_values = compare_values, verify_derivatives=True)

    def test_example(self,):
        '''
        Run the examples embedded in the docstring of `log`.
        '''
        self.docstest(log)

if __name__ == '__main__':
    test = TestLog()
    test.test_functionality()
    test.test_example()