
@bartolsthoorn
Created April 29, 2017 12:13
Simple multi-label classification example with PyTorch and MultiLabelSoftMarginLoss (https://en.wikipedia.org/wiki/Multi-label_classification)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Category -> features -> multi-hot target (3 labels):
# (1, 0) => [x, 0] => labels 0 and 2 => [1, 0, 1]
# (0, 1) => [0, x] => label 1        => [0, 1, 0]
# (0, 0) => [x, y] => label 2        => [0, 0, 1]
# Category (1, 1) is drawn but skipped, so only about 7,500 of the
# 10,000 draws end up as training samples.
train = []
labels = []
for i in range(10000):
    category = (np.random.choice([0, 1]), np.random.choice([0, 1]))
    if category == (1, 0):
        train.append([np.random.uniform(0.1, 1), 0])
        labels.append([1, 0, 1])
    if category == (0, 1):
        train.append([0, np.random.uniform(0.1, 1)])
        labels.append([0, 1, 0])
    if category == (0, 0):
        train.append([np.random.uniform(0.1, 1), np.random.uniform(0.1, 1)])
        labels.append([0, 0, 1])


class _classifier(nn.Module):
    def __init__(self, nlabel):
        super(_classifier, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, nlabel),
        )

    def forward(self, input):
        # Returns raw logits; MultiLabelSoftMarginLoss applies the sigmoid itself.
        return self.main(input)


nlabel = len(labels[0])  # => 3
classifier = _classifier(nlabel)
optimizer = optim.Adam(classifier.parameters())
criterion = nn.MultiLabelSoftMarginLoss()

epochs = 5
for epoch in range(epochs):
    losses = []
    for i, sample in enumerate(train):
        # One sample per step (batch size 1). torch.autograd.Variable is
        # deprecated since PyTorch 0.4; plain tensors can be used directly.
        inputv = torch.tensor(sample, dtype=torch.float32).view(1, -1)
        labelsv = torch.tensor(labels[i], dtype=torch.float32).view(1, -1)
        output = classifier(inputv)
        loss = criterion(output, labelsv)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())  # Python float instead of a 0-dim tensor
    print('[%d/%d] Loss: %.3f' % (epoch + 1, epochs, np.mean(losses)))
$ python multilabel.py
[1/5] Loss: 0.092
[2/5] Loss: 0.005
[3/5] Loss: 0.001
[4/5] Loss: 0.000
[5/5] Loss: 0.000
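To make explicit what nn.MultiLabelSoftMarginLoss computes, here is a small sketch on random logits and hypothetical multi-hot targets (not the gist's data). It recomputes the loss by hand as the per-class binary cross-entropy on sigmoid outputs, averaged over classes and then over the batch, and compares it with nn.BCEWithLogitsLoss, which gives the same value when every sample has the same number of classes. It also shows one common way to turn logits into multi-label predictions; the 0.5 threshold is an assumption, not something the gist specifies.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 3 labels each, multi-hot targets.
logits = torch.randn(4, 3)
targets = torch.tensor([[1., 0., 1.],
                        [0., 1., 0.],
                        [0., 0., 1.],
                        [1., 0., 1.]])

# Built-in loss, as used in the gist.
soft_margin = nn.MultiLabelSoftMarginLoss()(logits, targets)

# Manual version: binary cross-entropy per class on sigmoid(logits),
# averaged over the 3 classes, then averaged over the batch.
manual = -(targets * F.logsigmoid(logits)
           + (1 - targets) * F.logsigmoid(-logits)).mean(dim=1).mean()

# Same quantity via BCEWithLogitsLoss (default reduction='mean').
bce = nn.BCEWithLogitsLoss()(logits, targets)

print(torch.allclose(soft_margin, manual), torch.allclose(soft_margin, bce))

# Inference: each label is decided independently (assumed threshold 0.5).
probs = torch.sigmoid(logits)
preds = (probs > 0.5).long()
print(preds)
```

Because each class gets its own sigmoid, any number of labels (including none) can be active for a sample, which is what distinguishes this setup from a softmax over classes.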
@bartolsthoorn (author):
@jcfgonc No. You likely confused nn.MultiLabelSoftMarginLoss with nn.MultiLabelMarginLoss (note the Soft in the name). Despite the similar names, they require totally different label formats.

| Loss function | Use case | Expected label format | Example label |
|---|---|---|---|
| nn.MultiLabelSoftMarginLoss | Multi-label classification | Multi-hot vector (0s and 1s) | [1, 0, 1] (classes 0 and 2 are present) |
| nn.MultiLabelMarginLoss | Multi-label classification | Vector of class indices | [0, 2, -1] (indices of active classes, padded with -1) |
| nn.CrossEntropyLoss | Single-label (multi-class) classification | Single class index (scalar) | 2 (only class 2 is present) |
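The sketch below, using random logits and hypothetical targets, feeds each loss the label format listed for it in the table. The helper to_index_format is hypothetical, written only for illustration: it converts multi-hot rows into the -1-padded index layout that nn.MultiLabelMarginLoss expects (a LongTensor with the same shape as the input, where reading stops at the first -1).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Multi-hot targets, the format MultiLabelSoftMarginLoss expects.
multi_hot = torch.tensor([[1, 0, 1],
                          [0, 1, 0],
                          [0, 0, 1]])


def to_index_format(multi_hot):
    """Hypothetical helper: multi-hot rows -> active class indices, padded with -1."""
    n, c = multi_hot.shape
    out = torch.full((n, c), -1, dtype=torch.long)
    for row, labels in enumerate(multi_hot):
        idx = labels.nonzero(as_tuple=True)[0]  # positions of the 1s
        out[row, :len(idx)] = idx
    return out


indices = to_index_format(multi_hot)
print(indices.tolist())  # [[0, 2, -1], [1, -1, -1], [2, -1, -1]]

logits = torch.randn(3, 3)

# Each loss with its own target format.
soft = nn.MultiLabelSoftMarginLoss()(logits, multi_hot.float())
margin = nn.MultiLabelMarginLoss()(logits, indices)

# CrossEntropyLoss: exactly one class index per sample
# (hypothetical single-label targets, not derived from multi_hot).
ce = nn.CrossEntropyLoss()(logits, torch.tensor([2, 1, 2]))

print(soft.item(), margin.item(), ce.item())
```

Passing the multi-hot tensor to nn.MultiLabelMarginLoss would not raise a shape error, since the shapes match, but the 0s and 1s would be read as class indices, so the loss would silently train toward the wrong labels.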
