【问题标题】:Convert npz jax weights into keras h5 weights将 npz jax 权重转换为 keras h5 权重
【发布时间】:2021-02-15 19:09:55
【问题描述】:

有没有办法将 jax npz 预训练的权重转换为 kers/tf.keras h5 格式的权重?

在网上找不到任何东西。

谢谢

【问题讨论】:

    标签: tensorflow keras tf.keras jax


    【解决方案1】:

    npz 格式转换为h5 格式最直接的方法是将数据加载到内存中,然后重写。

    这是一个简单的例子:

    import jax.numpy as jnp
    from jax import random
    import h5py
    
    # Create some random weights
    key = random.PRNGKey(1701)
    weights = random.normal(key, shape=(100,))
    
    # Save to an npz file
    jnp.savez('weights.npz', weights=weights)
    
    # Load the npz and convert to h5
    data = jnp.load('weights.npz')
    with h5py.File('weights.h5', 'w') as hf:
        hf.create_dataset('weights', data=data['weights'])
    

    请注意,这将取决于 npz 文件的内容和生成的 h5 文件的所需结构。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-10-17
      • 1970-01-01
      • 1970-01-01
      • 2018-07-13
      • 2020-05-24
      • 2018-02-15
      • 1970-01-01
      • 2019-08-04
      相关资源
      最近更新 更多