6. JAX

CSDL offers an experimental interface to JAX, a powerful, high-performance numerical computing library. See the official JAX documentation to learn more about JAX and its capabilities.

We use JAX's just-in-time (JIT) compilation, through the experimental JaxSimulator class, to evaluate CSDL models efficiently. First, define the CSDL operations as usual:

import csdl_alpha as csdl
import numpy as np

recorder = csdl.Recorder()
recorder.start()

# Define the Rosenbrock function (applied elementwise)
size = 5
x1 = csdl.Variable(name = "x1", value = np.arange(size)/size)
x2 = csdl.Variable(name = "x2", value = np.arange(size)/size+1.0)
f = (1 - x1)**2 + 100 * (x2 - x1**2)**2
f.add_name("f")

recorder.stop()
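For intuition, the minimal sketch below writes the same Rosenbrock expression directly in JAX. It does not use CSDL and is not meant to show how JaxSimulator is implemented internally; it only illustrates the two JAX features involved: jax.jit compiles the function on its first call, and jax.jacfwd builds its forward-mode Jacobian. The function and variable names are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def rosenbrock(x1, x2):
    # Same elementwise expression as the CSDL model above (plain JAX, not CSDL)
    return (1 - x1) ** 2 + 100 * (x2 - x1**2) ** 2


size = 5
x1_vals = jnp.arange(size) / size
x2_vals = jnp.arange(size) / size + 1.0

f_vals = rosenbrock(x1_vals, x2_vals)  # first call traces and compiles with XLA

# Forward-mode Jacobians with respect to each argument (5x5 and diagonal here)
df_dx1 = jax.jit(jax.jacfwd(rosenbrock, argnums=0))(x1_vals, x2_vals)
df_dx2 = jax.jit(jax.jacfwd(rosenbrock, argnums=1))(x1_vals, x2_vals)
print(f_vals)
```

JAX uses 32-bit floats by default, so these values may differ from NumPy-backed results in the last few digits.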

Instantiate the JaxSimulator object and specify the model's inputs and outputs. Any design variables, objectives, and constraints are automatically set as inputs or outputs, so only other variables need to be listed in additional_inputs and additional_outputs.

jax_sim = csdl.experimental.JaxSimulator(
    recorder = recorder,
    additional_inputs = [x1, x2],
    additional_outputs = f,
)

We can then use the run and compute_totals methods to evaluate the model and compute its derivatives, respectively. The compute_totals method differentiates every output (objectives, constraints, and the additional outputs specified above) with respect to every input (design variables and the additional inputs specified above). The first call to each method compiles it, as the "compiling ..." lines in the output show.

jax_sim.run()
print('f:\n', jax_sim[f], '\n')

derivatives = jax_sim.compute_totals()
print('df_dx1:\n', derivatives[f,x1])
print('df_dx2:\n', derivatives[f,x2])
compiling 'run' function ...
f:
 [101.   135.2  154.12 153.92 134.6 ] 

compiling 'compute_totals' function ...
df_dx1:
 [[  -2.    -0.    -0.    -0.    -0. ]
 [  -0.   -94.4   -0.    -0.    -0. ]
 [  -0.    -0.  -199.6   -0.    -0. ]
 [  -0.    -0.    -0.  -298.4   -0. ]
 [  -0.    -0.    -0.    -0.  -371.6]]
df_dx2:
 [[200.   0.   0.   0.   0.]
 [  0. 232.   0.   0.   0.]
 [  0.   0. 248.   0.   0.]
 [  0.   0.   0. 248.   0.]
 [  0.   0.   0.   0. 232.]]
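Because f is computed entry by entry, each Jacobian is diagonal, and its entries follow from the analytic partial derivatives ∂f/∂x1 = -2(1 - x1) - 400 x1 (x2 - x1²) and ∂f/∂x2 = 200 (x2 - x1²). As a hand check, the NumPy sketch below (independent of CSDL; the helper rosenbrock_partials is hypothetical) reproduces the diagonals printed above from these formulas:

```python
import numpy as np


def rosenbrock_partials(x1, x2):
    """Hypothetical helper: analytic Jacobians of the elementwise Rosenbrock f.

    f = (1 - x1)**2 + 100*(x2 - x1**2)**2 acts entry by entry, so each
    Jacobian is diagonal; full matrices are returned to match the printout.
    """
    df_dx1 = -2.0 * (1.0 - x1) - 400.0 * x1 * (x2 - x1**2)
    df_dx2 = 200.0 * (x2 - x1**2)
    return np.diag(df_dx1), np.diag(df_dx2)


size = 5
x1 = np.arange(size) / size
x2 = np.arange(size) / size + 1.0
J1, J2 = rosenbrock_partials(x1, x2)
print(np.diag(J1))
print(np.diag(J2))
```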

Verify the derivatives against finite differences with the check_totals method:

checks = jax_sim.check_totals()
Derivative Verification Results
-------------------------------
ofs (1)   wrts (2)   norm                fd norm             error                    rel error               tags         
--------------------------------------------------------------------------------------------------------------------
f         x1         525.2472179840651   525.2473677413986   0.00036129314409274804   6.87853316897778e-07    (5,),(5,),   
f         x2         520.2460956124514   520.2463185424366   0.00022356574331855008   4.297305629089498e-07   (5,),(5,),   
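The norm column appears to be the Frobenius norm of each Jacobian (for df_dx2 above, sqrt(200² + 232² + 248² + 248² + 232²) ≈ 520.246), and rel error is error divided by norm. The NumPy sketch below illustrates the idea behind such a check using a forward-difference approximation; fd_jacobian is a hypothetical helper, and the step h = 1e-6 is an assumption rather than a documented check_totals setting.

```python
import numpy as np


def rosenbrock(x1, x2):
    # Elementwise Rosenbrock expression, matching the CSDL model
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


def fd_jacobian(func, args, wrt, h=1e-6):
    """Hypothetical helper: forward-difference Jacobian of func(*args)
    with respect to args[wrt], built one column at a time."""
    f0 = func(*args)
    n_in = args[wrt].size
    jac = np.zeros((f0.size, n_in))
    for j in range(n_in):
        perturbed = [a.copy() for a in args]
        perturbed[wrt][j] += h  # perturb a single input entry by the assumed step h
        jac[:, j] = (func(*perturbed) - f0) / h
    return jac


size = 5
x1 = np.arange(size) / size
x2 = np.arange(size) / size + 1.0

J_exact = np.diag(200.0 * (x2 - x1**2))          # analytic df/dx2 (diagonal)
J_fd = fd_jacobian(rosenbrock, [x1, x2], wrt=1)

norm = np.linalg.norm(J_exact)                    # Frobenius norm for a 2-D array
error = np.linalg.norm(J_exact - J_fd)
rel_error = error / norm
print(norm, error, rel_error)
```

For df_dx2 the forward-difference error is 100h per diagonal entry (up to rounding), so the sketch's error is about sqrt(5) · 1e-4 ≈ 2.24e-4, close to the value in the table. That agreement is suggestive but does not confirm how check_totals is implemented.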

Change input values with the __setitem__ syntax on the simulator (sim[<variable>] = <np.ndarray>), then call run again to update the outputs (and compute_totals to update the derivatives):

# Modify the input values
jax_sim[x1] = jax_sim[x1] + 1.0

# Re-run the simulation to update output values
jax_sim.run()
print('f:\n', jax_sim[f], '\n')

derivatives = jax_sim.compute_totals()
print('df_dx1:\n', derivatives[f,x1])
print('df_dx2:\n', derivatives[f,x2])
f:
 [  0.     5.8   31.52  92.52 208.  ] 

df_dx1:
 [[  -0.     0.     0.     0.     0. ]
 [  -0.   115.6    0.     0.     0. ]
 [  -0.     0.   314.4    0.     0. ]
 [  -0.     0.     0.   615.6    0. ]
 [  -0.     0.     0.     0.  1038.4]]
df_dx2:
 [[   0.   -0.   -0.   -0.   -0.]
 [   0.  -48.   -0.   -0.   -0.]
 [   0.   -0. -112.   -0.   -0.]
 [   0.   -0.   -0. -192.   -0.]
 [   0.   -0.   -0.   -0. -288.]]
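No new "compiling ..." messages appear in the output above after the inputs change. That is consistent with how jax.jit behaves: a compiled function is cached and reused as long as the argument shapes and dtypes are unchanged. The plain-JAX sketch below (independent of CSDL; all names are illustrative) counts traces with a Python side effect, which runs only while JAX traces the function, not on cached calls:

```python
import jax
import jax.numpy as jnp

n_traces = 0  # how many times JAX traces (and therefore compiles) the function


@jax.jit
def rosenbrock(x1, x2):
    global n_traces
    n_traces += 1  # Python side effect: executes only during tracing
    return (1 - x1) ** 2 + 100 * (x2 - x1**2) ** 2


x1 = jnp.arange(5) / 5
x2 = jnp.arange(5) / 5 + 1.0

f_before = rosenbrock(x1, x2)       # traces and compiles
f_after = rosenbrock(x1 + 1.0, x2)  # same shapes/dtypes: cached executable is reused
print(n_traces)
```

Changing an input's shape (for example, a different size) would make JAX trace and compile the function again.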