Machine learning with weighted data

POSTED BY: Seth Chandler
2 Replies

Thanks for an extremely helpful and lucid response.

POSTED BY: Seth Chandler

Dear Seth,

Here is a way to handle weighted data with NetTrain in the current framework:
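(The steps below use the net and training data from the original question, which are not shown in this excerpt. To make the example self-contained, here is a minimal hypothetical stand-in; the input size, layer sizes, and number of examples are assumptions, and only the class labels "evil" and "good" come from the code that follows.)

(* hypothetical toy setup, for illustration only *)
net = NetChain[{LinearLayer[32], Ramp, LinearLayer[2], SoftmaxLayer[]},
   "Input" -> 4];

training = {RandomReal[1, {200, 4}],      (* inputs: 200 feature vectors *)
   RandomChoice[{"evil", "good"}, 200]};  (* targets: class labels *)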


Define a weighted loss:

weightedCrossEntropy = NetGraph[
  <|"times" -> ThreadingLayer[Times],  (* multiplies the per-example loss by its weight *)
   "loss" -> CrossEntropyLossLayer["Index"]|>,
  {{NetPort["Weight"], "loss"} -> "times"}]


Make a graph with your network and this weighted loss:

netWithLoss = NetGraph[
  {net, weightedCrossEntropy},
  {1 -> NetPort[2, "Input"],  (* the net's output feeds the loss's "Input" port *)
   2 -> NetPort["Loss"]},     (* expose the weighted loss as the graph's "Loss" output *)
  "Target" -> NetEncoder[{"Class", {"evil", "good"}}]]


Suppose you have some weights:

trainingWeights = RandomReal[{0, 1}, Length@First@training]; (* example weights; they should be non-negative *)
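As an aside, if the weights are meant to counter class imbalance (a common reason for weighting, though not stated in the question), a possible choice is inverse class-frequency weights. This is only a sketch, and the variable names are mine:

(* hypothetical alternative: weight each example inversely to its class frequency *)
counts = Counts[Last@training];
classWeights = N[Length[Last@training]/(Length[counts]*counts)];
trainingWeights = Lookup[classWeights, Last@training];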

Train with these weights:

netTrainedWithLoss = 
 NetTrain[netWithLoss, <|"Weight" -> trainingWeights, 
   "Input" -> First@training, "Target" -> Last@training|>, 
  ValidationSet -> Scaled[0.25]]


Extract the trained net from the graph (Caution: NetEncoder & NetDecoder have to be "re-attached"):

netTrained = 
 NetReplacePart[NetExtract[netTrainedWithLoss, 1], 
  "Output" -> NetDecoder[{"Class", {"evil", "good"}}]]

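As a quick check, using the hypothetical toy net and data sketched at the top (so the 4-feature input is an assumption), the extracted net now behaves like an ordinary classifier:

netTrained[RandomReal[1, 4]]                    (* a class label, e.g. "good" *)
netTrained[RandomReal[1, 4], "Probabilities"]   (* <|"evil" -> ..., "good" -> ...|> *)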

Here you go!


There will be a more straightforward way to do this in the future.

It should also be supported by Classify and Predict at some point. The way to handle weights efficiently is actually specific to each classification method, and, as you can see here, there is a natural way of doing it with neural networks and any gradient-based method. Note that over-sampling the data with higher weights (as in SMOTE) is an option that would work with most methods, but it can be awkward (it only roughly approximates the weights, and it has higher computational and memory costs).
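For illustration, here is a rough sketch of that over-sampling workaround; the replication rule (Ceiling of each weight over the mean weight) and the variable names are my own, not from this post. Each example is repeated roughly in proportion to its non-negative weight, and the replicated data is then passed to an unweighted learner such as Classify:

(* hypothetical: repeat each example roughly in proportion to its weight *)
repeats = Ceiling[trainingWeights/Mean[trainingWeights]];
oversampled = Join @@ MapThread[
    ConstantArray[#1 -> #2, #3] &,
    {First@training, Last@training, repeats}];
classifier = Classify[oversampled];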
