{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JAX\n", "\n", "CSDL offers an experimental interface with [JAX](https://github.com/google/jax), a high-performance numerical computing library. See the official [documentation](https://jax.readthedocs.io/en/latest/) to gain a better understanding of JAX and its capabilities.\n", "\n", "\n", "We leverage JAX's just-in-time compilation to evaluate CSDL models efficiently through the experimental `JaxSimulator` class. We first define CSDL operations as usual:\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import csdl_alpha as csdl\n", "import numpy as np\n", "\n", "recorder = csdl.Recorder()\n", "recorder.start()\n", "\n", "# Define the Rosenbrock function\n", "size = 5\n", "x1 = csdl.Variable(name = \"x1\", value = np.arange(size)/size)\n", "x2 = csdl.Variable(name = \"x2\", value = np.arange(size)/size+1.0)\n", "f = (1 - x1)**2 + 100 * (x2 - x1**2)**2\n", "f.add_name(\"f\")\n", "\n", "recorder.stop()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instantiate the `JaxSimulator` object and specify the inputs and outputs of the model. Note that any design variables, objectives, and constraints are automatically set as inputs/outputs of the model." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "jax_sim = csdl.experimental.JaxSimulator(\n", " recorder = recorder,\n", " additional_inputs = [x1, x2],\n", " additional_outputs = f,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then use the `run` and `compute_totals` methods to evaluate the model and compute its derivatives, respectively. `compute_totals` computes the derivatives of all outputs (any objectives, constraints, and the additional outputs specified above) with respect to all inputs (any design variables and the additional inputs specified above)."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "compiling 'run' function ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-06-25 14:09:44.288720: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "f:\n", " [101. 135.2 154.12 153.92 134.6 ] \n", "\n", "compiling 'compute_totals' function ...\n", "df_dx1:\n", " [[ -2. -0. -0. -0. -0. ]\n", " [ -0. -94.4 -0. -0. -0. ]\n", " [ -0. -0. -199.6 -0. -0. ]\n", " [ -0. -0. -0. -298.4 -0. ]\n", " [ -0. -0. -0. -0. -371.6]]\n", "df_dx2:\n", " [[200. 0. 0. 0. 0.]\n", " [ 0. 232. 0. 0. 0.]\n", " [ 0. 0. 248. 0. 0.]\n", " [ 0. 0. 0. 248. 0.]\n", " [ 0. 0. 0. 0. 232.]]\n" ] } ], "source": [ "jax_sim.run()\n", "print('f:\\n', jax_sim[f], '\\n')\n", "\n", "derivatives = jax_sim.compute_totals()\n", "print('df_dx1:\\n', derivatives[f,x1])\n", "print('df_dx2:\\n', derivatives[f,x2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify the derivatives against finite differences using the `check_totals` method."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Derivative Verification Results\n", "-------------------------------\n", "ofs (1) wrts (2) norm fd norm error rel error tags \n", "--------------------------------------------------------------------------------------------------------------------\n", "f x1 525.2472179840651 525.2473677413986 0.00036129314409274804 6.87853316897778e-07 (5,),(5,), \n", "f x2 520.2460956124514 520.2463185424366 0.00022356574331855008 4.297305629089498e-07 (5,),(5,), \n" ] } ], "source": [ "checks = jax_sim.check_totals()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Change input values with item assignment (`__setitem__`) on the simulator, in the form `sim[variable] = value`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "f:\n", " [ 0. 5.8 31.52 92.52 208. ] \n", "\n", "df_dx1:\n", " [[ -0. 0. 0. 0. 0. ]\n", " [ -0. 115.6 0. 0. 0. ]\n", " [ -0. 0. 314.4 0. 0. ]\n", " [ -0. 0. 0. 615.6 0. ]\n", " [ -0. 0. 0. 0. 1038.4]]\n", "df_dx2:\n", " [[ 0. -0. -0. -0. -0.]\n", " [ 0. -48. -0. -0. -0.]\n", " [ 0. -0. -112. -0. -0.]\n", " [ 0. -0. -0. -192. -0.]\n", " [ 0. -0. -0. -0. 
-288.]]\n" ] } ], "source": [ "# Modify the input values\n", "jax_sim[x1] = jax_sim[x1] + 1.0\n", "\n", "# Re-run the simulation to update output values\n", "jax_sim.run()\n", "print('f:\\n', jax_sim[f], '\\n')\n", "\n", "derivatives = jax_sim.compute_totals()\n", "print('df_dx1:\\n', derivatives[f,x1])\n", "print('df_dx2:\\n', derivatives[f,x2])" ] } ], "metadata": { "kernelspec": { "display_name": "csdl_dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 2 }