【问题标题】:What is the recommended way to do embeddings in jax?在 jax 中进行嵌入的推荐方法是什么?
【发布时间】:2023-02-21 01:32:08
【问题描述】:

所以我的意思是你有一个分类特征 $X$(假设你已经把它变成了整数)并说你想使用特征 $A$ 将它嵌入到某个维度中,其中 $A$ 是 arity x n_embed。

这样做的通常方法是什么?使用 for 循环和 vmap 是否正确?我不想要像 jax.nn 这样更高效的东西

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding

例如考虑高 arity 和低嵌入 dim。

是这里的 flax.linen 实现中的 jnp.take 吗? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624

【问题讨论】:

  • 你能澄清你对 using a for loop and vmap 的意思吗?
  • @GeoffreyNegiar 我的意思是不使用 jnp.take 你会从字面上遍历索引。但我现在认为 take 是正确的方法,看起来这就是使用 jax 的各种库在其实现中所做的事情。

标签: jax


【解决方案1】:

实际上,在纯 jax 中执行此操作的典型方法是使用 jnp.take。给定形状为(num_embeddings, num_features) 的嵌入的数组A 和形状为(n,) 的整数的分类特征x 然后下面为您提供嵌入查找。

jnp.take(A, x, axis=0)  # shape: (n, num_features)

如果使用 Flax,那么推荐的方法是使用 flax.linen.Embed 模块,并且会达到相同的效果:

import flax.linen as nn

class Model(nn.Module): 
  @nn.compact
  def __call__(self, x):
    emb = nn.Embed(num_embeddings, num_features)(x)  # shape

【讨论】:

  • A[x] 也有效
【解决方案2】:

假设 A 是嵌入表,x 是任意形状的索引。

  1. A[x],类似于jnp.take(A, x, axis=0),但更简单。
  2. vmap-ed A[x],它沿着x 的轴 0 平行。
  3. 嵌套vmap-ed A[x],沿x的所有轴平行。

    以下是供您参考的源代码。

    import jax
    import jax.numpy as jnp
    
    embs = jnp.array([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]], dtype=jnp.float32)
    
    x = jnp.array([[3, 1], [2, 0]], dtype=jnp.int32)
    
    print("
    take
    ", jnp.take(embs, x, axis=0))
    print("
    use []
    ", embs[x])
    print(
        "
    vmap
    ",
        jax.vmap(lambda embs, x: embs[x], in_axes=[None, 0], out_axes=0)(embs, x),
    )
    
    print(
        "
    nested vmap
    ",
        jax.vmap(
            jax.vmap(lambda embs, x: embs[x], in_axes=[None, 0], out_axes=0),
            in_axes=[None, 0],
            out_axes=0,
        )(embs, x),
    )
    

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-04-05
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2014-03-13
    • 2010-10-19
    相关资源
    最近更新 更多