Axon 的你好世界

我使用 LiveBook 对 Axon 进行了各种尝试,但都没有成功,我也遇到了困难。

我尝试拟合最简单的 y = ax + b。

加工流程

创建 100 个 {x,y} 对 y = 2 * x + 0.5 + 随机数

拟合 train_model()

显示结果

在 LiveBook 中执行的代码

设置

Mix.install([
  {:nx, "~> 0.3.0"},
  {:kino_vega_lite, "~> 0.1.3"},
  {:axon, "~> 0.2"},
])

身体

defmodule Basic do
  require Axon

  def build_model(input_shape) do
    inp1 = Axon.input("x", shape: input_shape)

    inp1
    |> Axon.dense(1)
  end

  defp batch do
    x = Nx.tensor(for _ <-1..100, do: [:rand.uniform()])
    y = Nx.multiply(x, 2)
     |> Nx.add(0.5)
     |> Nx.add(Nx.tensor(for _ <-1..100, do: [:rand.uniform()])|>Nx.multiply(0.5))
    {x, y}
  end

  defp train_model(model, data, epochs) do
    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
  end

  def run() do
    model = build_model({nil})
    data100 = batch()
    data = Stream.repeatedly(fn -> data100 end)
    # data = Stream.repeatedly(&batch/0)
    model_state = train_model(model, data, 10) |> IO.inspect(label: "model_state")

    result = Axon.predict(model, model_state, %{"x" => Nx.tensor([[0, 1]])|>Nx.transpose()   })
    {data100, result}
  end
end

显示结果

{{x,y},predict_y} = Basic.run()

VegaLite.new(width: 600, height: 600)
|> VegaLite.layers([
  VegaLite.new()
  |> VegaLite.data_from_values(x: x |> Nx.to_flat_list(), y: y |> Nx.to_flat_list())
  |> VegaLite.mark(:point, tooltip: true)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative),
  VegaLite.new()
  |> VegaLite.data_from_values(x: [0, 1], y: Nx.to_flat_list(predict_y))
  |> VegaLite.mark(:line)
  |> VegaLite.encode_field(:x, "x", type: :quantitative)
  |> VegaLite.encode_field(:y, "y", type: :quantitative)
])

执行结果

Axonを使って線形回帰のパラメータ(傾きと切片)の値を求めてみよう

model_state 的值

model_state: %{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.6768180131912231]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][1]
      [
        [2.0908796787261963]
      ]
    >
  }
}

由于随机数变化,截距有一些误差,但得到了几乎正确的值。

最后我紧紧抓住了入口。

参考
https://hexdocs.pm/axon/0.2.0/multi_input_example.html#everything-together


原创声明:本文系作者授权爱码网发表,未经许可,不得转载;

原文地址:https://www.likecs.com/show-308626920.html

相关文章: