【问题标题】:JAX - Problem in differentiating of functionJAX - 区分功能的问题
【发布时间】:2021-08-01 10:38:56
【问题描述】:

我正在尝试在调用中执行蒙特卡罗模拟,然后在 Python 中计算其与基础资产相关的一阶导数,但它仍然不起作用

from jax import random
from jax import jit, grad, vmap
import jax.numpy as jnp

xi = jnp.linspace(1,1.2,5)
def Simulation(xi):
    K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0
    S = jnp.broadcast_to(xi,(number_sim,len(xi))).T

    mean = -.5 * sigma * sigma * T
    volatility = sigma*jnp.sqrt(T)
    r_numb = random.PRNGKey(10)
    BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))

    product = S*jnp.exp(BM)

    payoff = jnp.maximum(product-K,0)

    result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)

    return result

first_derivative = vmap(grad(Simulation))(xi)

我不知道实现该算法的方式是否是使用“AD 方法”计算导数的最佳方式;这个算法是这样工作的:

  • S = 模拟一个包含所有底层证券的矩阵;对于每一行,我使用“xi = jnp.linspace”生成每个底层,并且在矩阵的每一行内,我有相同的值,次数等于“number_sim”

  • product = 生成 BM(包含正常数的向量)后,我需要将 BM 的每个元素(带 exp)与 S

所以这是对算法的简短解释,我非常感谢任何形式的建议或技巧来管理这个问题,并用 AD 方法计算导数! 提前致谢

【问题讨论】:

  • 您能否编辑您的问题以提供有关您的预期输出的更多详细信息?梯度是相对于标量计算的,并且您似乎想要相对于向量的梯度。你真的对雅可比或逐元素标量梯度感兴趣吗?
  • 嗨@jakevdp,感谢您的回答。我需要关于一组标的资产“xi”的看涨期权的一阶导数,所以我的输出应该是一个包含所有导数的向量。也许是的,我在标量方面错误地开发了代码。 .如果您需要更多信息,请告诉我! :)
  • 我编辑了标题和标签以再次提及 JAX。问一个特定于 JAX 的问题,得到一个特定于 JAX 的答案,然后才删除对 JAX 的提及,这似乎很奇怪。

标签: python numpy montecarlo jax


【解决方案1】:

您的函数似乎映射了一个向量Rᴺ→Rᴺ。在这种情况下,有两个导数概念是有意义的:元素导数(在 JAX 中,您可以通过组合 jax.vmapjax.grad 来计算它)。这将返回一个长度为 N 的导数向量,其中元素 i 包含第 i 个输出相对于 i 的导数第一个输入。

或者,您可以计算雅可比矩阵(使用jax.jacobian),它将返回一个形状[N, N] 矩阵,其中元素i,j 包含i第 em> 个输出相对于第 j 个输入。

您遇到的问题是您的函数是在假设向量输入的情况下编写的(您要求 xi 的长度),这意味着您对 jacobian 感兴趣,但您要求的是元素导数,这需要一个标量值函数。

因此,您有两种可能的方法来解决这个问题,具体取决于您对哪种导数感兴趣。如果您对 jacobian 感兴趣,可以使用编写的函数并使用 jax.jacobian 转换:

from jax import jacobian
print(jacobian(Simulation)(xi))
# [[0.6528027 0.        0.        0.        0.       ]
#  [0.        0.6819291 0.        0.        0.       ]
#  [0.        0.        0.7003516 0.        0.       ]
#  [0.        0.        0.        0.7181915 0.       ]
#  [0.        0.        0.        0.        0.7608434]]

或者,如果您对元素梯度感兴趣,您可以重写您的函数以与标量输入兼容,并像您在示例中所做的那样使用 grad 的 vmap。只需更改两行:

def Simulation_scalar(xi):
    K,T,number_sim,sigma,r,q = 1.,1.,100,0.4,0,0

    # S = jnp.broadcast_to(xi,(number_sim,len(xi))).T
    S = jnp.broadcast_to(xi,(number_sim,) + xi.shape).T

    mean = -.5 * sigma * sigma * T
    volatility = sigma*jnp.sqrt(T)
    r_numb = random.PRNGKey(10)
    BM = mean + volatility * random.normal(r_numb, shape=(number_sim,))

    product = S*jnp.exp(BM)

    payoff = jnp.maximum(product-K,0)

    # result = jnp.average(payoff, axis=1)*jnp.exp(-q*T)
    result = jnp.average(payoff, axis=-1)*jnp.exp(-q*T)

    return result

print(vmap(grad(Simulation_scalar))(xi))
# [0.6528027 0.6819291 0.7003516 0.7181915 0.7608434]

【讨论】:

  • 哇!非常感谢你!!!又是一个小问题(也许是愚蠢的):为什么我会得到以下结果:“[0 , 0 , 0 , 0 , 0]” 用这个命令计算的二阶导数? --> print(vmap(grad(grad(Simulation_scalar)))(xi))
  • 你的函数相对于它的输入是线性的,所以二阶导数为零。
  • 您对二阶导数有什么想法/技巧吗? @jakevdp
  • 如果您要问如何将模型从线性模型更改为非线性模型,这对于 StackOverflow 评论线程来说似乎过于开放了。
  • 当然。以后,请在有关 JAX 的问题中添加jax 标签,这样我很可能会看到它们。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2011-04-28
  • 1970-01-01
相关资源
最近更新 更多