The Mathematica documentation includes a tutorial on neural network regression with uncertainty. That approach works in some circumstances but can be difficult to generalize, so I started looking for other ways to do it.
As it turns out, there is a link between neural networks and Gaussian processes that can be exploited to put error bands on regression predictions (see, e.g., this post by Yarin Gal and the PhD thesis by R.M. Neal, 1995). The basic idea is to keep the DropoutLayer active when evaluating the trained network, so that every forward pass drops a different random subset of units and returns a slightly different prediction. Sampling the network many times then gives a distribution of predictions whose spread indicates the uncertainty (though it's not quite as simple as I'm making it sound here).
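For reference, the quantities estimated by this sampling can be written out as follows. This is a sketch following the formulation in Gal's post: $T$ stochastic forward passes $\hat{y}_t(x)$ with dropout active, $N$ training points, dropout probability $p$, prior length scale $l$ and L2 coefficient $\lambda$ (Gal states $\tau$ in terms of the probability of *keeping* a unit, which is $1-p$ here). The constant $\tau^{-1}$ accounts for observation noise; it is the same term that `sampleNet` in the code below adds to the sample variance.

$$
\begin{aligned}
\mathbb{E}[y \mid x] &\approx \frac{1}{T}\sum_{t=1}^{T}\hat{y}_t(x),\\
\operatorname{Var}[y \mid x] &\approx \tau^{-1} + \frac{1}{T}\sum_{t=1}^{T}\hat{y}_t(x)^2 - \mathbb{E}[y \mid x]^2,
\qquad \tau = \frac{l^2\,(1-p)}{2N\lambda}.
\end{aligned}
$$

The code uses `Variance[]`, which divides by $T-1$ rather than $T$; for the 10 to 500 samples allowed by the controls the difference is small.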
Inspired by Yarin's post and his interactive example of a network that is continuously retrained on the example data, I decided to build something similar in Mathematica. The result is the code below, which generates an interactive example in which you can edit the network's training data (by clicking in the figure) and adjust the network parameters with the controls. I had some trouble keeping the code from causing strange front end issues, but it seems to work well now.
In the attached notebook I go into more detail about my implementation of this method and also show how to do regression with a non-constant noise level (heteroscedastic regression). I hope this is of use to someone here :)
Example 1: fit by a network that assumes a constant noise level (mean ± 1σ error bands)
Example 2: fit by a network that also fits the noise level to the data (heteroscedastic regression)
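A common way to set up heteroscedastic regression, given here only as a sketch of the standard formulation (it may differ in detail from the one in the attached notebook), is to give the network two outputs per input, a mean $\mu(x)$ and a log-variance $s(x) = \log\sigma^2(x)$, and to train it on the Gaussian negative log-likelihood (up to an additive constant) instead of the mean squared error:

$$
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\left[\frac{\bigl(y_i-\mu(x_i)\bigr)^2}{2\,e^{s(x_i)}} + \frac{s(x_i)}{2}\right]
$$

Predicting the log-variance keeps $\sigma^2$ positive without explicit constraints. With dropout sampling, the constant $\tau^{-1}$ term is then replaced by the learned noise: the predictive variance is roughly the average of $e^{s(x)}$ over the forward passes plus the sample variance of $\mu(x)$.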
Code for the interactive example:
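DynamicModule[{
    exampleData,
    net,
    prob = 0.2,          (* dropout probability p *)
    \[Lambda] = 0.01,    (* L2 regularisation coefficient *)
    rounds = 10,         (* training rounds per retraining step *)
    sampleNumber = 100,  (* number of stochastic forward passes *)
    samples,
    l = 2,               (* GP prior length scale *)
    nlayers = 300,       (* number of units in the (single) hidden layer *)
    activation = Ramp,
    init,
    sampleNet,
    xmin = -5,
    xmax = 5,
    ymin = -2,
    ymax = 2
  },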
  (* Initialise the training data: Sin[x] plus Gaussian noise at 15 random x positions *)
  exampleData = Table[
    {x, Sin[x] + RandomVariate[NormalDistribution[0, 0.15]]},
    {x, RandomVariate[UniformDistribution[{-3, 3}], 15]}
  ];
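  (* Sample the noisy net sampleNumber times and return {mean - sd, mean, mean + sd}
     at every x value. Dropout stays active via NetEvaluationMode -> "Train".
     The constant added to the sample variance is the inverse model precision
     1/tau = 2 N lambda / (l^2 (1 - p)), with N the number of training points. *)
  sampleNet[net_NetChain, xvalues_List, sampleNumber_Integer?Positive] :=
    PreemptProtect[
      TimeSeries[
        Map[
          With[{
              mean = Mean[#],
              stdv = Sqrt[Variance[#] + (2 \[Lambda] Length[exampleData])/(l^2 (1 - prob))]
            },
            mean + stdv*{-1, 0, 1}
          ] &,
          (* Drop failed evaluations, then transpose so that each row
             holds all samples for one x value *)
          Transpose@Select[
            Table[net[xvalues, NetEvaluationMode -> "Train"], {i, sampleNumber}],
            ListQ
          ]
        ],
        {xvalues},
        ValueDimensions -> 3
      ]
    ];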
  (* Network initialisation function; it must be rerun whenever one of the
     network parameters (activation, prob, nlayers) is changed *)
  init[] := PreemptProtect[
    net = NetInitialize@NetChain[
      {
        LinearLayer[nlayers],
        ElementwiseLayer[activation],
        DropoutLayer[prob],
        1  (* shorthand for LinearLayer[1], the output unit *)
      },
      "Input" -> "Scalar",
      "Output" -> "Scalar"
    ]
  ];
  init[];
  samples = sampleNet[net, N@Subdivide[xmin, xmax, 100], sampleNumber];
  DynamicWrapper[
    Grid[{
      (* Controls *)
      {
        Labeled[Manipulator[Dynamic[l], {0.01, 10}],
          Tooltip["l", "GP prior length scale"], Right],
        Labeled[Manipulator[Dynamic[\[Lambda]], {0.0001, 0.1}],
          Tooltip["\[Lambda]", "L2 regularisation coefficient"], Right]
      },
      {
        Labeled[Manipulator[Dynamic[sampleNumber], {10, 500, 1}], "# samples", Right],
        SpanFromLeft
      },
      {
        Labeled[Manipulator[Dynamic[prob], {0, 0.95}, ContinuousAction -> False],
          Tooltip["p", "Dropout probability"], Right],
        Labeled[Manipulator[Dynamic[nlayers], {20, 500, 1}, ContinuousAction -> False],
          "# hidden units", Right]
      },
      {
        Labeled[
          PopupMenu[
            Dynamic[activation],
            {
              Ramp, Tanh, ArcTan, LogisticSigmoid, "ExponentialLinearUnit",
              "ScaledExponentialLinearUnit",
              "SoftSign", "SoftPlus", "HardTanh", "HardSigmoid"
            },
            ContinuousAction -> False
          ],
          "Activation function", Right
        ],
        (* Invisible helper that reinitialises the network whenever one of
           the network parameters changes *)
        DynamicWrapper[
          "",
          init[],
          SynchronousUpdating -> False,
          TrackedSymbols :> {activation, prob, nlayers}
        ]
      },
      (* Main contents *)
      {
        Labeled[
          LocatorPane[
            Dynamic[exampleData],
            Dynamic[
              Show[
                ListPlot[exampleData, PlotStyle -> Red],
                ListPlot[
                  samples,
                  Joined -> True,
                  Filling -> {1 -> {2}, 3 -> {2}},
                  PlotStyle -> {Lighter[Blue], Blue, Lighter[Blue]}
                ],
                ImageSize -> 600,
                PlotRange -> {{xmin, xmax}, {ymin, ymax}}
              ],
              TrackedSymbols :> {samples, exampleData}
            ],
            ContinuousAction -> False,
            LocatorAutoCreate -> All
          ],
          "1 \[Sigma] error bands (\[AltKey] + click to delete points)",
          Top
        ],
        SpanFromLeft
      }
    },
    BaseStyle -> "Text",
    Alignment -> Left
  ],
    (* Continuously retrain the net on the current examples and resample it *)
    net = Quiet@With[{
        new = NetTrain[
          net,
          <|
            "Input" -> exampleData[[All, 1]],
            "Output" -> exampleData[[All, 2]]
          |>,
          LossFunction -> MeanSquaredLossLayer[],
          Method -> {"ADAM", "L2Regularization" -> \[Lambda], "LearningRate" -> 0.005},
          MaxTrainingRounds -> rounds,
          TrainingProgressReporting -> None
        ]
      },
      (* Keep the previous net if training failed *)
      If[Head[new] === NetChain, new, net]
    ];
    samples = sampleNet[net, N@Subdivide[xmin, xmax, 50], sampleNumber],
    SynchronousUpdating -> False
  ]
]
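The band computation inside `sampleNet` does not depend on the network, so it can be split into a pure function and checked on its own. The sketch below is a hypothetical refactoring (`mcBands` and `tauInverse` are illustrative names, not part of the original code): it takes a matrix of forward-pass outputs with one row per dropout sample and one column per x value, plus the inverse precision τ⁻¹, and returns one `{mean - σ, mean, mean + σ}` triple per x value, like the `Map` inside `sampleNet`.

```mathematica
(* Hypothetical helper, not part of the original code: turn a matrix of
   stochastic forward passes into 1-sigma bands.
   samples: rows = dropout samples, columns = x values
   tauInv:  constant inverse model precision added to the sample variance *)
mcBands[samples_?MatrixQ, tauInv_?NonNegative] :=
  Module[{mean, sd},
    mean = Mean[samples];                   (* column means: one per x value *)
    sd = Sqrt[Variance[samples] + tauInv];  (* column variances plus 1/tau *)
    Transpose[{mean - sd, mean, mean + sd}] (* one {lower, mean, upper} triple per x *)
  ]

(* Hypothetical helper: inverse precision as computed in sampleNet,
   2 N lambda / (l^2 (1 - p)), with p the dropout probability *)
tauInverse[n_, lambda_, l_, p_] := (2 n lambda)/(l^2 (1 - p))

mcBands[{{1, 2}, {3, 4}}, 0]
(* {{2 - Sqrt[2], 2, 2 + Sqrt[2]}, {3 - Sqrt[2], 3, 3 + Sqrt[2]}} *)
```

Because `Mean` and `Variance` act column-wise on a matrix, no `Transpose` of the samples is needed. Assuming the same variables as in the `DynamicModule`, the call would be `mcBands[Select[Table[net[xvalues, NetEvaluationMode -> "Train"], {sampleNumber}], ListQ], tauInverse[Length[exampleData], \[Lambda], l, prob]]`, with the result then wrapped in `TimeSeries` as before.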