(Cross-posted from my answer on StackExchange.) Here is a complete example of making predictions in MXNet (Python binding) with a net exported from the Wolfram Language. First, create and initialize a small net:
SeedRandom[1234];
net = NetChain[{LinearLayer[5], LinearLayer[3]}, "Input" -> 2];
net = NetInitialize[net];
net[{1, 2}]
{1.27735, -1.21455, -1.02647}
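NetChain[{LinearLayer[5], LinearLayer[3]}] has no nonlinearity, so it computes y = W2.(W1.x + b1) + b2, with W1 of shape 5×2 and W2 of shape 3×5. The NumPy sketch below spells that out. It uses random stand-in parameters, not the values NetInitialize produced, so its numbers will not match the output above; the function name chain_forward is hypothetical. It also checks that the two layers collapse into a single affine map.

```python
import numpy as np

def chain_forward(x, w1, b1, w2, b2):
    """Forward pass of two stacked linear (affine) layers.

    x may be a single vector of shape (2,) or a batch of shape (N, 2);
    each layer computes x @ W.T + b, i.e. W.x + b applied row by row.
    """
    h = x @ w1.T + b1      # first layer: 2 -> 5
    return h @ w2.T + b2   # second layer: 5 -> 3

# Random stand-in parameters with the shapes LinearLayer[5] and LinearLayer[3]
# get for a length-2 input (NOT the values from the Wolfram session).
rng = np.random.default_rng(1234)
w1, b1 = rng.normal(size=(5, 2)), rng.normal(size=5)
w2, b2 = rng.normal(size=(3, 5)), rng.normal(size=3)

x = np.array([1.0, 2.0])
y = chain_forward(x, w1, b1, w2, b2)

# With no nonlinearity in between, the chain is one affine map A.x + c.
A = w2 @ w1
c = w2 @ b1 + b2
assert np.allclose(y, A @ x + c)
```

The same function accepts a batch of shape (N, 2), which is the layout MXNet expects for its input.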
Exporting:
Export["example.json", net, "MXNet"]
creates two files, "example.json" (the network graph) and "example.params" (the parameter arrays). Now let's run this in MXNet. The first thing to note: the export is designed to be used with the lowest-level Symbol and NDArray APIs, NOT higher-level APIs like Gluon or Model. These higher-level APIs are often Python-specific, so sticking to the lowest level means this approach works in every language binding MXNet supports.
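The exported graph file can be examined before MXNet is involved at all. The sketch below reads the symbol JSON with Python's standard json module and separates the variable nodes (the input and the parameters) from the operator nodes. It assumes MXNet's usual symbol-JSON layout: a top-level "nodes" list whose entries have "op" and "name" fields, with variables marked as op "null". The helper name summarize_symbol_json is hypothetical, and no particular operator names from the Wolfram exporter are assumed.

```python
import json

def summarize_symbol_json(text):
    """Split the nodes of an MXNet symbol JSON string into variables and ops.

    Assumption: the JSON has a top-level "nodes" list, each node has
    "op" and "name", and variables (inputs, weights, biases) use op "null".
    """
    graph = json.loads(text)
    variables = [node["name"] for node in graph["nodes"] if node["op"] == "null"]
    operators = [(node["name"], node["op"])
                 for node in graph["nodes"] if node["op"] != "null"]
    return variables, operators

# Usage on the exported file:
# with open("example.json") as f:
#     variables, operators = summarize_symbol_json(f.read())
# print(variables)
# print(operators)
```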
The first step is to load the symbol (stored in the JSON file) and the parameter NDArrays (stored in the .params file):
import mxnet as mx
import numpy as np
sym = mx.symbol.load('example.json')
nd = mx.nd.load('example.params')
If you tried to bind now via e = sym.bind(mx.cpu(), nd), it would fail, because you have not supplied an input NDArray, and the input is what tells MXNet which shapes the executor should use. Instead, first check what the inputs to the symbol are:
sym.list_arguments()
['Input', '1.Weights', '1.Biases', '2.Weights', '2.Biases']
In MXNet, weights and biases are ordinary input arguments, just like the data input. Looking at the keys of our NDArray dictionary:
list(nd)
['2.Weights', '2.Biases', '1.Biases', '1.Weights']
we see it does not contain an array for "Input". Let's add it:
inputND = mx.nd.array(np.array([[1,2]]))
nd["Input"] = inputND
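For a net this small, comparing the two lists by eye is enough; for larger exported nets the same check can be done programmatically. A minimal sketch, where the helper name missing_arguments is hypothetical and plain Python lists and dicts stand in for the MXNet objects:

```python
def missing_arguments(arg_names, params):
    """Return the symbol arguments that have no array in the loaded params,
    in the order given by arg_names.

    arg_names: names as returned by sym.list_arguments()
    params:    the name -> array dictionary returned by mx.nd.load(...)
    """
    return [name for name in arg_names if name not in params]

# The lists printed above, as plain Python stand-ins:
arg_names = ['Input', '1.Weights', '1.Biases', '2.Weights', '2.Biases']
params = dict.fromkeys(['2.Weights', '2.Biases', '1.Biases', '1.Weights'])
print(missing_arguments(arg_names, params))  # ['Input']
```

With MXNet loaded, the call would be missing_arguments(sym.list_arguments(), nd).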
Why did I use [[1,2]] and not [1, 2]? Because MXNet always expects a leading batch dimension: the input has shape (batch size, 2), so this example uses a batch size of 1. Now bind and do a forward pass:
e = sym.bind(mx.cpu(), nd)
out = e.forward()
This returns a list of NDArrays:
[
[[ 1.27735424 -1.21454573 -1.02646875]]
<NDArray 1x3 @cpu(0)>]
The first (and here only) output can be converted to a NumPy array via out[0].asnumpy(). Further predictions can be made by writing new values directly into the "Input" NDArray and calling forward again:
inputND[:] = np.array([[3,4]])
a = e.forward()
returning:
[
[[ 3.56568646 -3.15509319 -3.13344383]]
<NDArray 1x3 @cpu(0)>]
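The inputND[:] = ... form is what makes this work: it writes the new values into the array the executor was bound with, whereas inputND = ... would only rebind the Python name to a fresh array the executor never sees. The NumPy sketch below imitates that with a hypothetical ToyExecutor class (not an MXNet API) holding one affine layer with random stand-in parameters. It also shows that the bound shape is fixed: in NumPy, writing a (2, 2) batch into the (1, 2) buffer raises a ValueError. Since the executor's shapes come from the input it was bound with, predicting on a different batch size would presumably mean binding again with a differently shaped "Input".

```python
import numpy as np

class ToyExecutor:
    """Hypothetical stand-in for a bound executor (not an MXNet class).

    It keeps a reference to the input buffer it was bound with and reads
    that buffer again on every forward() call.
    """

    def __init__(self, input_buf, weight, bias):
        self.input_buf = input_buf
        self.weight = weight
        self.bias = bias

    def forward(self):
        # One affine layer, x @ W.T + b, evaluated on the bound buffer.
        return self.input_buf @ self.weight.T + self.bias

rng = np.random.default_rng(0)
weight, bias = rng.normal(size=(3, 2)), rng.normal(size=3)

input_buf = np.array([[1.0, 2.0]])     # batch of one sample, shape (1, 2)
e = ToyExecutor(input_buf, weight, bias)
first = e.forward()

input_buf[:] = np.array([[3.0, 4.0]])  # in-place write: the executor sees it
second = e.forward()

input_buf = np.array([[5.0, 6.0]])     # rebinding the name: the executor does not
third = e.forward()
print(np.allclose(second, third))      # True: the rebind had no effect

# The bound shape is fixed: a batch of two does not fit the (1, 2) buffer.
try:
    e.input_buf[:] = np.zeros((2, 2))
except ValueError as err:
    print("shape mismatch:", err)
```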