ME-NeuralFMUs using Growing Horizon
Tutorial by Johannes Stoljar, Tobias Thummerer
Last edit: 08.11.2023
LICENSE
# Copyright (c) 2021 Tobias Thummerer, Lars Mikelsons, Johannes Stoljar
# Licensed under the MIT license.
# See LICENSE (https://github.com/thummeto/FMIFlux.jl/blob/main/LICENSE) file in the project root for details.
Motivation
The Julia Package FMIFlux.jl is motivated by the application of hybrid modeling. This package enables the user to integrate his simulation model between neural networks (NeuralFMU). For this, the simulation model must be exported as FMU (functional mock-up unit), which corresponds to a widely used standard. The big advantage of hybrid modeling with artificial neural networks is, that effects that are difficult to model (because they might be unknown) can be easily learned by the neural networks. For this purpose, the NeuralFMU is trained with measurement data containing the not modeled physical effect. The final product is a simulation model including the originally not modeled effects. Another big advantage of the NeuralFMU is that it works with little data, because the FMU already contains the characteristic functionality of the simulation and only the missing effects are added.
NeuralFMUs do not need to be as easy as in this example. Basically a NeuralFMU can combine different ANN topologies that manipulate any FMU-input (system state, system inputs, time) and any FMU-output (system state derivative, system outputs, other system variables). However, for this example a NeuralFMU topology as shown in the following picture is used.
NeuralFMU (ME) from [1].
Introduction to the example
In this example, simplified modeling of a one-dimensional spring pendulum (without friction) is compared to a model of the same system that includes a nonlinear friction model. The FMU with the simplified model will be named simpleFMU in the following and the model with the friction will be named fricFMU. At the beginning, the actual state of both simulations is shown, whereby clear deviations can be seen in the graphs. The fricFMU serves as a reference graph. The simpleFMU is then integrated into a NeuralFMU architecture and a training of the entire network is performed. After the training the final state is compared again to the fircFMU. It can be clearly seen that by using the NeuralFMU, learning of the friction process has taken place.
Target group
The example is primarily intended for users who work in the field of first principle and/or hybrid modeling and are further interested in hybrid model building. The example wants to show how simple it is to combine FMUs with machine learning and to illustrate the advantages of this approach.
Other formats
Besides, this Jupyter Notebook there is also a Julia file with the same name, which contains only the code cells and for the documentation there is a Markdown file corresponding to the notebook.
Getting started
Installation prerequisites
Description | Command | |
---|---|---|
1. | Enter Package Manager via | ] |
2. | Install FMI via | add FMI |
3. | Install FMIFlux via | add FMIFlux |
4. | Install FMIZoo via | add FMIZoo |
5. | Install Plots via | add Plots |
6. | Install Random via | add Random |
Code section
To run the example, the previously installed packages must be included.
# imports
using FMI
using FMI.FMIImport: fmi2StringToValueReference, fmi2ValueReference, fmi2Real
using FMIFlux
using FMIFlux.Flux
using FMIZoo
using FMI.DifferentialEquations: Tsit5
using Statistics: mean, std
using Plots
# set seed
import Random
Random.seed!(1234);
[33m[1m┌ [22m[39m[33m[1mWarning: [22m[39mError requiring `Enzyme` from `LinearSolve`
[33m[1m│ [22m[39m exception =
[33m[1m│ [22m[39m LoadError: ArgumentError: Package LinearSolve does not have Enzyme in its dependencies:
[33m[1m│ [22m[39m - You may have a partially installed environment. Try `Pkg.instantiate()`
[33m[1m│ [22m[39m to ensure all packages in the environment are installed.
[33m[1m│ [22m[39m - Or, if you have LinearSolve checked out for development and have
[33m[1m│ [22m[39m added Enzyme as a dependency but haven't updated your primary
[33m[1m│ [22m[39m environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m - Otherwise you may need to report an issue with LinearSolve
[33m[1m│ [22m[39m Stacktrace:
[33m[1m│ [22m[39m [1] [0m[1mmacro expansion[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4mloading.jl:1167[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [2] [0m[1mmacro expansion[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4mlock.jl:223[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [3] [0m[1mrequire[22m[0m[1m([22m[90minto[39m::[0mModule, [90mmod[39m::[0mSymbol[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:1144[24m[39m
[33m[1m│ [22m[39m [4] [0m[1minclude[22m[0m[1m([22m[90mmod[39m::[0mModule, [90m_path[39m::[0mString[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mBase.jl:419[24m[39m
[33m[1m│ [22m[39m [5] [0m[1minclude[22m[0m[1m([22m[90mx[39m::[0mString[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[35mLinearSolve[39m [90mC:\Users\runneradmin\.julia\packages\LinearSolve\qCLK7\src\[39m[90m[4mLinearSolve.jl:1[24m[39m
[33m[1m│ [22m[39m [6] [0m[1mmacro expansion[22m
[33m[1m│ [22m[39m [90m @ [39m[90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mRequires.jl:40[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [7] top-level scope
[33m[1m│ [22m[39m [90m @ [39m[90mC:\Users\runneradmin\.julia\packages\LinearSolve\qCLK7\src\[39m[90m[4minit.jl:16[24m[39m
[33m[1m│ [22m[39m [8] [0m[1meval[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4mboot.jl:368[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [9] [0m[1meval[22m
[33m[1m│ [22m[39m [90m @ [39m[90mC:\Users\runneradmin\.julia\packages\LinearSolve\qCLK7\src\[39m[90m[4mLinearSolve.jl:1[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [10] [0m[1m(::LinearSolve.var"#88#97")[22m[0m[1m([22m[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[35mLinearSolve[39m [90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mrequire.jl:101[24m[39m
[33m[1m│ [22m[39m [11] [0m[1mmacro expansion[22m
[33m[1m│ [22m[39m [90m @ [39m[90m[4mtiming.jl:382[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [12] [0m[1merr[22m[0m[1m([22m[90mf[39m::[0mAny, [90mlistener[39m::[0mModule, [90mmodname[39m::[0mString, [90mfile[39m::[0mString, [90mline[39m::[0mAny[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[36mRequires[39m [90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mrequire.jl:47[24m[39m
[33m[1m│ [22m[39m [13] [0m[1m(::LinearSolve.var"#87#96")[22m[0m[1m([22m[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[35mLinearSolve[39m [90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mrequire.jl:100[24m[39m
[33m[1m│ [22m[39m [14] [0m[1mwithpath[22m[0m[1m([22m[90mf[39m::[0mAny, [90mpath[39m::[0mString[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[36mRequires[39m [90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mrequire.jl:37[24m[39m
[33m[1m│ [22m[39m [15] [0m[1m(::LinearSolve.var"#86#95")[22m[0m[1m([22m[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[35mLinearSolve[39m [90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mrequire.jl:99[24m[39m
[33m[1m│ [22m[39m [16] [0m[1m#invokelatest#2[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4messentials.jl:729[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [17] [0m[1minvokelatest[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4messentials.jl:726[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [18] [0m[1mforeach[22m[0m[1m([22m[90mf[39m::[0mtypeof(Base.invokelatest), [90mitr[39m::[0mVector[90m{Function}[39m[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mabstractarray.jl:2774[24m[39m
[33m[1m│ [22m[39m [19] [0m[1mloadpkg[22m[0m[1m([22m[90mpkg[39m::[0mBase.PkgId[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[36mRequires[39m [90mC:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\[39m[90m[4mrequire.jl:27[24m[39m
[33m[1m│ [22m[39m [20] [0m[1m#invokelatest#2[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4messentials.jl:729[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [21] [0m[1minvokelatest[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4messentials.jl:726[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [22] [0m[1mrun_package_callbacks[22m[0m[1m([22m[90mmodkey[39m::[0mBase.PkgId[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:869[24m[39m
[33m[1m│ [22m[39m [23] [0m[1m_tryrequire_from_serialized[22m[0m[1m([22m[90mmodkey[39m::[0mBase.PkgId, [90mpath[39m::[0mString, [90msourcepath[39m::[0mString, [90mdepmods[39m::[0mVector[90m{Any}[39m[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:944[24m[39m
[33m[1m│ [22m[39m [24] [0m[1m_require_search_from_serialized[22m[0m[1m([22m[90mpkg[39m::[0mBase.PkgId, [90msourcepath[39m::[0mString, [90mbuild_id[39m::[0mUInt64[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:1028[24m[39m
[33m[1m│ [22m[39m [25] [0m[1m_require[22m[0m[1m([22m[90mpkg[39m::[0mBase.PkgId[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:1315[24m[39m
[33m[1m│ [22m[39m [26] [0m[1m_require_prelocked[22m[0m[1m([22m[90muuidkey[39m::[0mBase.PkgId[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:1200[24m[39m
[33m[1m│ [22m[39m [27] [0m[1mmacro expansion[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4mloading.jl:1180[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [28] [0m[1mmacro expansion[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4mlock.jl:223[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [29] [0m[1mrequire[22m[0m[1m([22m[90minto[39m::[0mModule, [90mmod[39m::[0mSymbol[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:1144[24m[39m
[33m[1m│ [22m[39m [30] [0m[1meval[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4mboot.jl:368[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [31] [0m[1minclude_string[22m[0m[1m([22m[90mmapexpr[39m::[0mtypeof(REPL.softscope), [90mmod[39m::[0mModule, [90mcode[39m::[0mString, [90mfilename[39m::[0mString[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[90mBase[39m [90m.\[39m[90m[4mloading.jl:1428[24m[39m
[33m[1m│ [22m[39m [32] [0m[1msoftscope_include_string[22m[0m[1m([22m[90mm[39m::[0mModule, [90mcode[39m::[0mString, [90mfilename[39m::[0mString[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[32mSoftGlobalScope[39m [90mC:\Users\runneradmin\.julia\packages\SoftGlobalScope\u4UzH\src\[39m[90m[4mSoftGlobalScope.jl:65[24m[39m
[33m[1m│ [22m[39m [33] [0m[1mexecute_request[22m[0m[1m([22m[90msocket[39m::[0mZMQ.Socket, [90mmsg[39m::[0mIJulia.Msg[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[33mIJulia[39m [90mC:\Users\runneradmin\.julia\packages\IJulia\Vo51o\src\[39m[90m[4mexecute_request.jl:67[24m[39m
[33m[1m│ [22m[39m [34] [0m[1m#invokelatest#2[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4messentials.jl:729[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [35] [0m[1minvokelatest[22m
[33m[1m│ [22m[39m [90m @ [39m[90m.\[39m[90m[4messentials.jl:726[24m[39m[90m [inlined][39m
[33m[1m│ [22m[39m [36] [0m[1meventloop[22m[0m[1m([22m[90msocket[39m::[0mZMQ.Socket[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[33mIJulia[39m [90mC:\Users\runneradmin\.julia\packages\IJulia\Vo51o\src\[39m[90m[4meventloop.jl:8[24m[39m
[33m[1m│ [22m[39m [37] [0m[1m(::IJulia.var"#15#18")[22m[0m[1m([22m[0m[1m)[22m
[33m[1m│ [22m[39m [90m @ [39m[33mIJulia[39m [90m.\[39m[90m[4mtask.jl:484[24m[39m
[33m[1m│ [22m[39m in expression starting at C:\Users\runneradmin\.julia\packages\LinearSolve\qCLK7\ext\LinearSolveEnzymeExt.jl:1
[33m[1m└ [22m[39m[90m@ Requires C:\Users\runneradmin\.julia\packages\Requires\Z8rfN\src\require.jl:51[39m
After importing the packages, the path to the Functional Mock-up Units (FMUs) is set. The FMU is a model exported meeting the Functional Mock-up Interface (FMI) Standard. The FMI is a free standard (fmi-standard.org) that defines a container and an interface to exchange dynamic models using a combination of XML files, binaries and C code zipped into a single file.
The object-orientated structure of the SpringPendulum1D (simpleFMU) can be seen in the following graphic and corresponds to a simple modeling.
In contrast, the model SpringFrictionPendulum1D (fricFMU) is somewhat more accurate, because it includes a friction component.
Next, the start time and end time of the simulation are set. Finally, a step size is specified to store the results of the simulation at these time steps.
tStart = 0.0
tStep = 0.1
tStop = 5.0
tSave = collect(tStart:tStep:tStop)
51-element Vector{Float64}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
1.0
1.1
1.2
⋮
3.9
4.0
4.1
4.2
4.3
4.4
4.5
4.6
4.7
4.8
4.9
5.0
fricFMU
In the next lines of code the FMU of the fricFMU model from FMIZoo.jl is loaded and the information about the FMU is shown.
fricFMU = fmiLoad("SpringFrictionPendulum1D", "Dymola", "2022x")
fmiInfo(fricFMU)
#################### Begin information for FMU ####################
Model name: SpringFrictionPendulum1D
FMI-Version: 2.0
GUID: {2e178ad3-5e9b-48ec-a7b2-baa5669efc0c}
Generation tool: Dymola Version 2022x (64-bit), 2021-10-08
Generation time: 2022-05-19T06:54:12Z
Var. naming conv.: structured
Event indicators: 24
Inputs: 0
Outputs: 0
States: 2
33554432 ["mass.s"]
33554433 ["mass.v", "mass.v_relfric"]
Supports Co-Simulation: true
Model identifier: SpringFrictionPendulum1D
Get/Set State: true
Serialize State: true
Dir. Derivatives: true
Var. com. steps: true
Input interpol.: true
Max order out. der.: 1
Supports Model-Exchange: true
Model identifier: SpringFrictionPendulum1D
Get/Set State: true
Serialize State: true
Dir. Derivatives: true
##################### End information for FMU #####################
In the function fmiSimulate() the fricFMU is simulated, still specifying the start and end time, the parameters and which variables are recorded. After the simulation is finished the result of the fricFMU can be plotted. This plot also serves as a reference for the other model (simpleFMU).
vrs = ["mass.s", "mass.v", "mass.a", "mass.f"]
solFric = fmiSimulate(fricFMU, (tStart, tStop); recordValues=vrs, saveat=tSave)
plot(solFric)
[34mSimulating CS-FMU ... 0%|█ | ETA: N/A[39m
[34mSimulating CS-FMU ... 100%|██████████████████████████████| Time: 0:00:02[39m
The data from the simulation of the fricFMU, are divided into position and velocity data. These data will be needed later.
posFric = fmi2GetSolutionValue(solFric, "mass.s")
velFric = fmi2GetSolutionValue(solFric, "mass.v")
51-element Vector{Float64}:
0.0
0.432852398300982
0.8401743918610578
1.1702254881462497
1.3861768532456016
1.4649609400224617
1.397962181945595
1.1917483098990418
0.8657325133644009
0.44821918384886916
-0.02200493896693855
-0.380560845401747
-0.7172068753289351
⋮
-0.19353187721088116
0.021605187634145845
0.12911473439606144
0.2315130895115627
0.31667721272388255
0.37417576531479746
0.3964197153211615
0.3795927497483354
0.3235539803194403
0.2317738499958648
0.11061350893737848
-1.0008118292437196e-10
The FMU has two states: The first state is the position of the mass and the second state is the velocity. The initial position of the mass is initialized with $0.5𝑚$. The initial velocity of the mass is initialized with $0\frac{m}{s}$.
x₀ = [posFric[1], velFric[1]]
2-element Vector{Float64}:
0.5
0.0
After extracting the data, the FMU is cleaned-up.
fmiUnload(fricFMU)
SimpleFMU
The following lines load, simulate and plot the simpleFMU just like the fricFMU. The differences between both systems can be clearly seen from the plots. In the plot for the fricFMU it can be seen that the oscillation continues to decrease due to the effect of the friction. If you simulate long enough, the oscillation would come to a standstill in a certain time. The oscillation in the simpleFMU behaves differently, since the friction was not taken into account here. The oscillation in this model would continue to infinity with the same oscillation amplitude. From this observation the desire of an improvement of this model arises.
simpleFMU = fmiLoad("SpringPendulum1D", "Dymola", "2022x"; type=:ME)
fmiInfo(simpleFMU)
vrs = ["mass.s", "mass.v", "mass.a"]
solSimple = fmiSimulate(simpleFMU, (tStart, tStop); recordValues=vrs, saveat=tSave)
plot(solSimple)
#################### Begin information for FMU ####################
Model name: SpringPendulum1D
FMI-Version: 2.0
GUID: {fc15d8c4-758b-48e6-b00e-5bf47b8b14e5}
Generation tool: Dymola Version 2022x (64-bit), 2021-10-08
Generation time: 2022-05-19T06:54:23Z
Var. naming conv.: structured
Event indicators: 0
Inputs: 0
Outputs: 0
States: 2
33554432 ["mass.s"]
33554433 ["mass.v"]
Supports Co-Simulation: true
Model identifier: SpringPendulum1D
Get/Set State: true
Serialize State: true
Dir. Derivatives: true
Var. com. steps: true
Input interpol.: true
Max order out. der.: 1
Supports Model-Exchange: true
Model identifier: SpringPendulum1D
Get/Set State: true
Serialize State: true
Dir. Derivatives: true
##################### End information for FMU #####################
[34mSimulating ME-FMU ... 0%|█ | ETA: N/A[39m
[34mSimulating ME-FMU ... 100%|██████████████████████████████| Time: 0:00:09[39m
The data from the simulation of the simpleFMU, are divided into position and velocity data. These data will be needed later to plot the results.
posSimple = fmi2GetSolutionValue(solSimple, "mass.s")
velSimple = fmi2GetSolutionValue(solSimple, "mass.v")
51-element Vector{Float64}:
0.0
0.5900499045575546
1.1215848686403096
1.541892887990106
1.809292251922929
1.897265115600142
1.797087259817532
1.518693266896391
1.0896913136721704
0.5526252840128629
-0.03924428436005998
-0.6272220176030365
-1.1529984931243986
⋮
-0.4389974278206929
0.15680920523609004
0.7370651557739741
1.244226766096815
1.6279991034408674
1.850323680279787
1.889152693194497
1.7406354911553794
1.4195004423851794
0.9575943302840277
0.4007241257092793
-0.1958856419517666
NeuralFMU
Loss function with growing horizon
In order to train our model, a loss function must be implemented. The solver of the NeuralFMU can calculate the gradient of the loss function. The gradient descent is needed to adjust the weights in the neural network so that the sum of the error is reduced and the model becomes more accurate.
The loss function in this implementation consists of the mean squared error (mse) from the Fric position of the fricFMU simulation (posFric) and the position data of the network (posNet). $ e{mse} = \frac{1}{n} \sum\limits{i=0}^n (posFric[i] - posNet[i])^2 $ A growing horizon is applied, whereby the horizon only goes over the first five values. For this horizon the mse is calculated.
# loss function for training
global horizon = 5
function lossSum(p)
global posFric, neuralFMU, horizon
solution = neuralFMU(x₀, (tSave[1], tSave[horizon]); p=p, saveat=tSave[1:horizon]) # here, the NeuralODE is solved only for the time horizon
posNet = fmi2GetSolutionState(solution, 1; isIndex=true)
FMIFlux.Losses.mse(posFric[1:horizon], posNet)
end
lossSum (generic function with 1 method)
Function for plotting
In this section the function for plotting is defined. The function plotResults()
creates a new figure object. In dieses figure objekt werden dann die aktuellsten Ergebnisse von fricFMU, simpleFMU und neuralFMU gegenübergestellt.
To output the loss in certain time intervals, a callback is implemented as a function in the following. Here a counter is incremented, every twentieth pass the loss function is called and the average error is printed out.
function plotResults()
global neuralFMU
solNeural = neuralFMU(x₀, (tStart, tStop); saveat=tSave)
fig = Plots.plot(xlabel="t [s]", ylabel="mass position [m]", linewidth=2,
xtickfontsize=12, ytickfontsize=12,
xguidefontsize=12, yguidefontsize=12,
legendfontsize=8, legend=:topright)
plot!(fig, solSimple; stateIndices=1:1, values=false, label="SimpleFMU", linewidth=2)
plot!(fig, solFric; valueIndices=1:1, label="FricFMU", linewidth=2)
plot!(fig, solNeural; stateIndices=1:1, label="NeuralFMU", linewidth=2)
fig
end
plotResults (generic function with 1 method)
Callback
To output the loss in certain time intervals, a callback is implemented as a function in the following. Here a counter is incremented, every twentieth pass the loss function is called and the average error is printed out. As soon as a limit value (in this example 0.1
) is undershot, the horizon is extended by the next two values.
# callback function for training
global counter = 0
function callb(p)
global counter, horizon
counter += 1
if counter % 50 == 1
avgLoss = lossSum(p[1])
@info " Loss [$counter] for horizon $horizon : $(round(avgLoss, digits=5))\nAvg displacement in data: $(round(sqrt(avgLoss), digits=5))"
if avgLoss <= 0.01
horizon += 2
horizon = min(length(tSave), horizon)
end
# fig = plotResults()
# println("Figure update.")
# display(fig)
end
end
callb (generic function with 1 method)
Pre- and Postprocessing
In the following functions for pre-processing and post-processing are defined. The function preProc
is normalized the input values to mean of zero and a standard deviation of one.
global meanVal = 0.0
global stdVal = 0.0
function preProc!(data)
global meanVal, stdVal
meanVal = mean(data)
stdVal = std(data)
(data .- meanVal) ./ stdVal
end
preProc! (generic function with 1 method)
For post-processing, the previous normalization is undone by applying the calculation steps in reverse order.
function postProc!(data)
global meanVal, stdVal
(data .* stdVal) .+ meanVal
end
postProc! (generic function with 1 method)
Structure of the NeuralFMU
In the following, the topology of the NeuralFMU is constructed. It consists of an input layer, which then leads into the simpleFMU model. The ME-FMU computes the state derivatives for a given system state. Following the simpleFMU is a dense layer that has exactly as many inputs as the model has states (and therefore state derivatives). The output of this layer consists of 16 output nodes and a tanh activation function. The next layer has 16 input and output nodes with the same activation function. The last layer is again a dense layer with 16 input nodes and the number of states as outputs. Here, it is important that no tanh-activation function follows, because otherwise the pendulums state values would be limited to the interval $[-1;1]$.
# NeuralFMU setup
numStates = fmiGetNumberOfStates(simpleFMU)
additionalVRs = [fmi2StringToValueReference(simpleFMU, "mass.m")]
numAdditionalVRs = length(additionalVRs)
net = Chain(
x -> simpleFMU(x=x, dx_refs=:all, y_refs=additionalVRs),
preProc!,
Dense(numStates+numAdditionalVRs, 16, tanh),
postProc!,
preProc!,
Dense(16, 16, tanh),
postProc!,
preProc!,
Dense(16, numStates),
postProc!,
)
Chain(
var"#1#2"(),
preProc!,
Dense(3 => 16, tanh), [90m# 64 parameters[39m
postProc!,
preProc!,
Dense(16 => 16, tanh), [90m# 272 parameters[39m
postProc!,
preProc!,
Dense(16 => 2), [90m# 34 parameters[39m
postProc!,
) [90m # Total: 6 arrays, [39m370 parameters, 1.820 KiB.
Definition of the NeuralFMU
The instantiation of the ME-NeuralFMU is done as a one-liner. The FMU (simpleFMU), the structure of the network net
, start tStart
and end time tStop
, the numerical solver Tsit5()
and the time steps tSave
for saving are specified.
neuralFMU = ME_NeuralFMU(simpleFMU, net, (tStart, tStop), Tsit5(); saveat=tSave);
Plot before training
Here the state trajectory of the simpleFMU is recorded. Doesn't really look like a pendulum yet, but the system is random initialized by default. In the plots later on, the effect of learning can be seen.
solutionBefore = neuralFMU(x₀)
fmiPlot(solutionBefore)
[33m[1m┌ [22m[39m[33m[1mWarning: [22m[39mNo solver keyword detected for NeuralFMU.
[33m[1m│ [22m[39mContinuous adjoint method is applied, which requires solving backward in time.
[33m[1m│ [22m[39mThis might be not supported by every FMU.
[33m[1m│ [22m[39m(This message is only printed once.)
[33m[1m└ [22m[39m[90m@ FMICore C:\Users\runneradmin\.julia\packages\FMICore\7NIyu\src\printing.jl:38[39m
Training of the NeuralFMU
For the training of the NeuralFMU the parameters are extracted. The known Adam optimizer for minimizing the gradient descent is used as further passing parameters. In addition, the previously defined loss and callback function, as well as the number of epochs are passed.
# train
paramsNet = Flux.params(neuralFMU)
optim = Adam()
FMIFlux.train!(lossSum, neuralFMU, Iterators.repeated((), 1000), optim; cb=()->callb(paramsNet))
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [1] for horizon 5 : 0.04198
[36m[1m└ [22m[39mAvg displacement in data: 0.20488
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [51] for horizon 5 : 0.00035
[36m[1m└ [22m[39mAvg displacement in data: 0.01877
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [101] for horizon 7 : 9.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.00959
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [151] for horizon 9 : 3.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.00537
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [201] for horizon 11 : 0.00016
[36m[1m└ [22m[39mAvg displacement in data: 0.01256
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [251] for horizon 13 : 0.00057
[36m[1m└ [22m[39mAvg displacement in data: 0.02384
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [301] for horizon 15 : 0.00129
[36m[1m└ [22m[39mAvg displacement in data: 0.03588
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [351] for horizon 17 : 0.00114
[36m[1m└ [22m[39mAvg displacement in data: 0.03383
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [401] for horizon 19 : 5.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.00711
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [451] for horizon 21 : 2.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.0044
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [501] for horizon 23 : 3.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.00521
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [551] for horizon 25 : 2.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.0045
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [601] for horizon 27 : 4.0e-5
[36m[1m└ [22m[39mAvg displacement in data: 0.00619
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [651] for horizon 29 : 0.0001
[36m[1m└ [22m[39mAvg displacement in data: 0.00993
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [701] for horizon 31 : 0.00024
[36m[1m└ [22m[39mAvg displacement in data: 0.01555
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [751] for horizon 33 : 0.00042
[36m[1m└ [22m[39mAvg displacement in data: 0.02048
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [801] for horizon 35 : 0.00043
[36m[1m└ [22m[39mAvg displacement in data: 0.02064
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [851] for horizon 37 : 0.00038
[36m[1m└ [22m[39mAvg displacement in data: 0.01948
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [901] for horizon 39 : 0.00034
[36m[1m└ [22m[39mAvg displacement in data: 0.01843
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m Loss [951] for horizon 41 : 0.00031
[36m[1m└ [22m[39mAvg displacement in data: 0.01752
Comparison of the plots
Here three plots are compared with each other and only the position of the mass is considered. The first plot represents the simpleFMU, the second represents the fricFMU (reference) and the third plot represents the result after training the NeuralFMU.
# plot results mass.s
plotResults()
Finally, the FMU is cleaned-up.
fmiUnload(simpleFMU)
Summary
Based on the plots, it can be seen that the NeuralFMU is able to adapt the friction model of the fricFMU. After 1000 training steps, the curves already overlap quite well, but this can be further improved by longer training or a better initialization.
Source
[1] Tobias Thummerer, Lars Mikelsons and Josef Kircher. 2021. NeuralFMU: towards structural integration of FMUs into neural networks. Martin Sjölund, Lena Buffoni, Adrian Pop and Lennart Ochel (Ed.). Proceedings of 14th Modelica Conference 2021, Linköping, Sweden, September 20-24, 2021. Linköping University Electronic Press, Linköping (Linköping Electronic Conference Proceedings ; 181), 297-306. DOI: 10.3384/ecp21181297