【发布时间】:2023-02-21 01:32:08
【问题描述】:
所以我的意思是你有一个分类特征 $X$(假设你已经把它变成了整数)并说你想使用特征 $A$ 将它嵌入到某个维度中,其中 $A$ 是 arity x n_embed。
这样做的通常方法是什么?使用 for 循环和 vmap 是否正确?我不想要像 jax.nn 这样更高效的东西
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
例如考虑高 arity 和低嵌入 dim。
是这里的 flax.linen 实现中的 jnp.take 吗? https://github.com/google/flax/blob/main/flax/linen/linear.py#L624
【问题讨论】:
-
你能澄清你对
using a for loop and vmap的意思吗? -
@GeoffreyNegiar 我的意思是不使用 jnp.take 你会从字面上遍历索引。但我现在认为 take 是正确的方法,看起来这就是使用 jax 的各种库在其实现中所做的事情。
标签: jax