Axon 的你好世界
我使用 LiveBook 对 Axon 进行了各种尝试,但都没有成功,我也遇到了困难。
我尝试拟合最简单的 y = ax + b。
加工流程
创建 100 个 {x,y} 对 y = 2 * x + 0.5 + 随机数
↓
拟合 train_model()
↓
显示结果
在 LiveBook 中执行的代码
设置
Mix.install([
{:nx, "~> 0.3.0"},
{:kino_vega_lite, "~> 0.1.3"},
{:axon, "~> 0.2"},
])
身体
defmodule Basic do
require Axon
def build_model(input_shape) do
inp1 = Axon.input("x", shape: input_shape)
inp1
|> Axon.dense(1)
end
defp batch do
x = Nx.tensor(for _ <-1..100, do: [:rand.uniform()])
y = Nx.multiply(x, 2)
|> Nx.add(0.5)
|> Nx.add(Nx.tensor(for _ <-1..100, do: [:rand.uniform()])|>Nx.multiply(0.5))
{x, y}
end
defp train_model(model, data, epochs) do
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.run(data, %{}, epochs: epochs, iterations: 1000)
end
def run() do
model = build_model({nil})
data100 = batch()
data = Stream.repeatedly(fn -> data100 end)
# data = Stream.repeatedly(&batch/0)
model_state = train_model(model, data, 10) |> IO.inspect(label: "model_state")
result = Axon.predict(model, model_state, %{"x" => Nx.tensor([[0, 1]])|>Nx.transpose() })
{data100, result}
end
end
显示结果
{{x,y},predict_y} = Basic.run()
VegaLite.new(width: 600, height: 600)
|> VegaLite.layers([
VegaLite.new()
|> VegaLite.data_from_values(x: x |> Nx.to_flat_list(), y: y |> Nx.to_flat_list())
|> VegaLite.mark(:point, tooltip: true)
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "y", type: :quantitative),
VegaLite.new()
|> VegaLite.data_from_values(x: [0, 1], y: Nx.to_flat_list(predict_y))
|> VegaLite.mark(:line)
|> VegaLite.encode_field(:x, "x", type: :quantitative)
|> VegaLite.encode_field(:y, "y", type: :quantitative)
])
执行结果
model_state 的值
model_state: %{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.6768180131912231]
>,
"kernel" => #Nx.Tensor<
f32[1][1]
[
[2.0908796787261963]
]
>
}
}
由于随机数变化,截距有一些误差,但得到了几乎正确的值。
最后我紧紧抓住了入口。
参考
https://hexdocs.pm/axon/0.2.0/multi_input_example.html#everything-together
原创声明:本文系作者授权爱码网发表,未经许可,不得转载;
原文地址:https://www.likecs.com/show-308626920.html