In this work, the authors build on a recent paper about representation networks with sine activation functions (SIREN) and try to make such representations generate a latent space. In a nutshell, they add extra dimensions to the input space and use the gradient of the loss function with respect to these extra inputs, taken at the origin, as a first approximation of the latent code, which is then fed into a more global loss function.
So there are actually two loss functions, one nested inside the other, and the network function F ends up being evaluated twice in the global loss function.
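For reference, this is how I understand the two nested losses of the Gradient Origin Network. The notation is mine, not necessarily the paper's: $c$ are the pixel coordinates, $x$ the target image, $z$ the latent vector (the extra input dimensions), $F(c, z)$ the network and $\mathcal{L}$ the reconstruction loss.

```latex
\begin{aligned}
z_0 &= \mathbf{0}, \\
z &= -\nabla_{z_0} \mathcal{L}\bigl(x, F(c, z_0)\bigr), \\
G_x &= \mathcal{L}\bigl(x, F(c, z)\bigr)
     = \mathcal{L}\Bigl(x, F\bigl(c, -\nabla_{z_0} \mathcal{L}(x, F(c, z_0))\bigr)\Bigr).
\end{aligned}
```

$F$ appears once inside the gradient and once outside it, so training the weights of $F$ on $G_x$ means differentiating through the inner gradient, i.e. taking second-order derivatives.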
However, I'm not sure how to implement this in Mathematica.
First, I would need to compute the gradient of the inner loss function. Then I would have to write the global loss function and pass it to NetTrain, I suppose? Would Mathematica be able to do the differentiation automatically?
Here is my code to train on a single image, i.e. to train a SIREN without any attempt at implementing the Gradient Origin Network. It works well and quickly on such a tiny (28×28) image.
mnist = ResourceObject["MNIST"]
trainingData = ResourceData[mnist, "TrainingData"]
img = Keys[trainingData] // RandomChoice
linspace[n_] := Range[(1 - n)/2, (n - 1)/2];
grid[w_, h_] := Pi/2 Tuples[{linspace[h], linspace[w]}] // N;
input = grid @@ ImageDimensions[img]
output = Partition[Flatten[2 # - 1 &@ImageData@img], 1]
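net = NetInitialize[
NetChain[
(* 28 // {...} & expands to {28, Sin, 28, Sin, 1, Sin}:
   two hidden layers of width 28 with Sin activations, one output in [-1, 1] *)
28 // {
#, Sin,
#, Sin,
1, Sin
} &,
"Input" -> 2
]
];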
net = NetTrain[net, (input -> output)]
(* evaluate all pixels as one batch and map [-1, 1] back to [0, 1] *)
Image@Partition[(Flatten[net[input]] + 1)/2, 28]
How would I go about extending this code to use the method described in the paper, so that I can build a latent space representing the whole MNIST digit dataset?
P.S. I suppose I could define a gradient function:
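latentSpaceSize = 32;
(* the net computes in single precision (Real32), so a very small step mostly measures rounding noise *)
\[Epsilon] = 10.^-3;
(* central-difference derivative of the network output at pixel (x, y)
   with respect to each latent coordinate, evaluated at latent = 0 *)
grad[x_, y_] :=
Module[{v = Join[{x, y}, ConstantArray[0., latentSpaceSize]], e},
Table[
e = \[Epsilon] UnitVector[2 + latentSpaceSize, i];
First[net[v + e] - net[v - e]]/(2 \[Epsilon]),
{i, 2 + Range[latentSpaceSize]}
]
]
(* requires net to accept 2 + latentSpaceSize inputs, i.e. the augmented net defined below *)
grad[0, 0]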
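The function above approximates, by central differences, the derivative of a single pixel's network output with respect to each latent coordinate at $z = 0$. The inner step, however, needs the gradient of the loss over all pixels. Assuming (my choice, for illustration) a summed squared-error loss over the $P = 784$ pixels with coordinates $c_p$ and target values $x_p$, the chain rule relates the two:

```latex
\begin{aligned}
\mathcal{L}(z) &= \sum_{p=1}^{P} \bigl(F(c_p, z) - x_p\bigr)^2, \\
\nabla_z \mathcal{L}(z) &= \sum_{p=1}^{P} 2\bigl(F(c_p, z) - x_p\bigr)\,\nabla_z F(c_p, z), \\
\frac{\partial F(c, z)}{\partial z_i} &\approx \frac{F(c, z + \epsilon e_i) - F(c, z - \epsilon e_i)}{2\epsilon}.
\end{aligned}
```

So the per-pixel gradients would have to be weighted by the residuals $F(c_p, 0) - x_p$ and summed before taking $z = -\nabla_z \mathcal{L}(0)$ as the latent code.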
I would also have to increase the size of the input:
input = Table[
xy ~Join~ConstantArray[0, latentSpaceSize],
{xy, grid @@ ImageDimensions[img]}
];
net = NetInitialize[
NetChain[
28 // {
#, Sin,
#, Sin,
1, Sin
} &,
"Input" -> 2 + latentSpaceSize
]
];
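A minimal sketch of the inner (latent) step without finite differences, using `NetPortGradient` on a graph that appends a squared-error loss to the SIREN. It assumes that evaluating the graph on a batch with `NetPortGradient["Input"]` returns one input gradient per example; `lossNet`, `latentStep` and the random stand-in target are hypothetical names of mine. In practice `target` would be the `output` list of one MNIST image and the network would be the augmented `net` above.

```mathematica
latentSpaceSize = 32;

(* 28 x 28 pixel coordinates, built the same way as grid[28, 28] *)
coords = N[Pi/2 Tuples[{Range[-27/2, 27/2], Range[-27/2, 27/2]}]];

(* hypothetical stand-in for one image: values in [-1, 1], shape {784, 1} *)
target = RandomReal[{-1, 1}, {Length[coords], 1}];

(* SIREN taking 2 coordinates + latentSpaceSize latent inputs *)
siren = NetInitialize@NetChain[
    {28, Sin, 28, Sin, 1, Sin},
    "Input" -> 2 + latentSpaceSize];

(* attach a per-pixel squared-error loss; the graph output is "Loss" *)
lossNet = NetGraph[
   <|"siren" -> siren, "loss" -> MeanSquaredLossLayer[]|>,
   {"siren" -> NetPort["loss", "Input"],
    NetPort["Target"] -> NetPort["loss", "Target"]}];

(* inner step: z = -(gradient of the summed loss w.r.t. the latent at z0).
   The latent vector is shared by every pixel, so its gradient is the sum
   over pixels of the latent part of each per-pixel input gradient
   (assumption: batch evaluation returns one gradient per pixel) *)
latentStep[lnet_, xy_, tgt_, z0_] := Module[{grads},
   grads = lnet[
     <|"Input" -> (Join[#, z0] & /@ xy), "Target" -> tgt|>,
     NetPortGradient["Input"]];
   -Total[grads][[3 ;;]]];

z = latentStep[lossNet, coords, target, ConstantArray[0., latentSpaceSize]]
```

This only covers the inner step. The outer loss would evaluate `lossNet` again on `Join[#, z] & /@ coords`, but training the weights on it requires differentiating through `latentStep` itself, which is the part I don't know how to express for `NetTrain`.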