knn example with digit recognition
This commit is contained in:
11
10_knn/README.md
Normal file
11
10_knn/README.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
## Digit recognition with KNN
|
||||||
|
|
||||||
|
Here's some example MATLAB code that uses k-nearest neighbors (KNN) to guess handwritten digits. Here's what the output looks like:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Each row represents a guess. The last column contains the image that we're trying to guess the digit for. The first three columns show the 3 nearest neighbors for that last image.
|
||||||
|
|
||||||
|
Test images drawn with a box around them are the ones we did not guess correctly. You can see KNN works pretty well -- with just 500 comparison images we are getting 80% accuracy. This jumps to 90% with 5000 comparison images (see the images/ directory).
|
||||||
|
|
||||||
|
[Uses the MNIST dataset](http://yann.lecun.com/exdb/mnist/)
|
||||||
BIN
10_knn/images/13_correct_50_comparisons.tif
Normal file
BIN
10_knn/images/13_correct_50_comparisons.tif
Normal file
Binary file not shown.
BIN
10_knn/images/16_correct_500_comparisons.tif
Normal file
BIN
10_knn/images/16_correct_500_comparisons.tif
Normal file
Binary file not shown.
BIN
10_knn/images/18_correct_5000_comparisons.tif
Normal file
BIN
10_knn/images/18_correct_5000_comparisons.tif
Normal file
Binary file not shown.
BIN
10_knn/images/19_correct_50000_comparisons.tif
Normal file
BIN
10_knn/images/19_correct_50000_comparisons.tif
Normal file
Binary file not shown.
79
10_knn/main.m
Normal file
79
10_knn/main.m
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
function foo = main()
% MAIN Load the MNIST data files and run the KNN digit-guessing demo.
%   Puts the mnistHelper directory on the path, reads the training and test
%   image/label files, then asks guesses to classify 20 test images using
%   the 3 nearest neighbours.
%
%   NOTE: the output variable foo is never assigned in this function, so
%   main must be called without requesting an output.

    addpath('mnistHelper');

    num_guesses = 20;
    num_neighbours = 3;

    % Comparison (training) set.
    train_labels = loadMNISTLabels('train-labels-idx1-ubyte');
    train_images = loadMNISTImages('train-images-idx3-ubyte');

    % Images whose digits we try to guess.
    test_labels = loadMNISTLabels('t10k-labels-idx1-ubyte');
    test_images = loadMNISTImages('t10k-images-idx3-ubyte');

    % Uncomment to preview the first 100x100 training images instead:
    % showData(train_images, 100, 100);

    guesses(num_guesses, num_neighbours, ...
            train_images, train_labels, test_images, test_labels);
end
|
||||||
|
|
||||||
|
function foo = showData(images, rows, cols)
% SHOWDATA Display images as a mosaic of 28x28 tiles.
%   images - matrix with one image per column; each column is reshaped to
%            28x28, so it must hold 784 values
%   rows   - number of tile rows in the mosaic
%   cols   - number of tiles per row
%
%   Columns of images are consumed in order: the first cols columns form
%   the top row of tiles, the next cols columns the second row, and so on.
%   The finished mosaic is passed to imshow.
%
%   NOTE: foo is never assigned; call this function without an output.

    mosaic = [];
    next_column = 0;

    for tile_row = 1:rows
        % Lay the tiles of this row side by side...
        strip = [];
        for tile_col = 1:cols
            next_column = next_column + 1;
            tile = reshape(images(:, next_column), 28, 28);
            strip = [strip tile];
        end
        % ...then stack the finished strip underneath the previous ones.
        mosaic = [mosaic; strip];
    end

    imshow(mosaic);
end
|
||||||
|
|
||||||
|
function d = distance(train_image, test_image)
% DISTANCE Euclidean distance between two images given as vectors.
%   d = distance(train_image, test_image) returns the square root of the
%   sum of squared element-wise differences.
%
%   train_image, test_image - vectors with the same number of elements;
%                             row or column orientation does not matter
%                             because both are flattened to columns.
%
%   Both inputs are converted to double BEFORE subtracting. Integer classes
%   such as uint8 saturate on subtraction (uint8(0) - uint8(4) is 0, not
%   -4), which would silently make the distance too small.

    a = double(train_image(:));
    b = double(test_image(:));
    v = a - b;

    % v is a column vector, so v' * v is the scalar sum of squares.
    d = sqrt(v' * v);
end
|
||||||
|
|
||||||
|
function result = border(image, value)
% BORDER Paint a one-pixel frame around a 28x28 image.
%   image  - 784 pixel values (any shape that reshapes to 28x28)
%   value  - pixel value used for the frame
%
%   Returns a 784x1 column vector: the outermost ring of pixels is set to
%   value, and the interior (rows and columns 2..27) is copied from image.

    side = 28;
    inner = 2:(side - 1);

    pixels = reshape(image, side, side);

    % Start from a canvas filled entirely with the frame value, then paste
    % the interior of the original image over it.
    framed = zeros(side, side);
    framed(:, :) = value;
    framed(inner, inner) = pixels(inner, inner);

    result = framed(:);
end
|
||||||
|
|
||||||
|
function foo = guesses(count, k, train_images, train_labels, test_images, test_labels)
% GUESSES Classify random test digits with k-nearest-neighbours and plot them.
%   For each of count randomly chosen test images this function measures
%   the distance (see distance) to every comparison training image, keeps
%   the k closest, and guesses the most frequent label among them (mode).
%   The k neighbours followed by the test image are appended to a grid that
%   showData draws with one row per guess; test images that were guessed
%   wrongly are drawn with a border (see border).
%
%   count        - number of test images to guess
%   k            - number of nearest neighbours to vote with
%   train_images - one training image per column
%   train_labels - label of each training column
%   test_images  - one test image per column
%   test_labels  - label of each test column
%
%   The number of correct guesses is echoed to the console.
%   foo is set to size(test_images, 1); this is an incidental value kept
%   only so existing callers that request an output keep working.

    % Upper limit on how many training images are compared against. It is
    % only a cap: smaller training sets are used in full instead of being
    % indexed out of range.
    max_comparisons = 50000;

    foo = size(test_images, 1);
    num_test_images = size(test_images, 2);
    num_train_images = min(size(train_images, 2), max_comparisons);

    % Cannot ask for more neighbours than there are comparison images.
    k = min(k, num_train_images);

    correct = 0;
    grid = [];

    for guess = 1:count
        % randi draws from 1..num_test_images. floor(rand() * n) could
        % return 0 (not a valid index) and could never pick the last image.
        x = randi(num_test_images);
        test_image = test_images(:, x);
        correct_label = test_labels(x);

        % Preallocate instead of growing the array on every iteration.
        dist = zeros(num_train_images, 1);
        for i = 1:num_train_images
            dist(i) = distance(train_images(:, i)', test_image');
        end

        % sort is stable, so ties keep training-set order.
        [sorted_dist, order] = sort(dist);
        nearest = order(1:k);

        grid = [grid train_images(:, nearest)];
        neighbour_labels = train_labels(nearest);
        guess_label = mode(neighbour_labels(:));

        if guess_label == correct_label
            correct = correct + 1;
            grid = [grid test_image];
        else
            % Frame the test image to flag an incorrect guess.
            grid = [grid border(test_image, 255)];
        end
    end

    % No semicolon on purpose: echo the number of correct guesses.
    correct

    showData(grid, count, k+1);
end
|
||||||
26
10_knn/mnistHelper/loadMNISTImages.m
Normal file
26
10_knn/mnistHelper/loadMNISTImages.m
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
function images = loadMNISTImages(filename)
%loadMNISTImages returns a [pixels per image]x[number of MNIST images]
%matrix containing the MNIST images, one image per column, with pixel
%values converted to double and rescaled to [0,1]
%
%   filename - path of an MNIST image file (magic number 2051)
%
%   Raises an error if the file cannot be opened, if the magic number is
%   wrong, or if the amount of pixel data does not match the header.

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);

% Close the file however this function is left, including when one of the
% checks below raises an error; otherwise the handle would leak.
closeFile = onCleanup(@() fclose(fp));

% isequal (rather than ==) so a truncated file, where fread returns [],
% fails with the intended message.
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(isequal(magic, 2051), ['Bad magic number in ', filename, '']);

numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

images = fread(fp, inf, 'unsigned char');
assert(isequal(numel(images), numCols * numRows * numImages), ...
       'Mismatch in image count');

% Pixels are stored row by row, but reshape fills column by column: read
% each image as cols x rows and then swap the first two dimensions.
images = reshape(images, numCols, numRows, numImages);
images = permute(images,[2 1 3]);

% Reshape to #pixels x #examples
images = reshape(images, size(images, 1) * size(images, 2), size(images, 3));
% Convert to double and rescale to [0,1]
images = double(images) / 255;

end
|
||||||
19
10_knn/mnistHelper/loadMNISTLabels.m
Normal file
19
10_knn/mnistHelper/loadMNISTLabels.m
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
function labels = loadMNISTLabels(filename)
%loadMNISTLabels returns a [number of MNIST images]x1 matrix containing
%the labels for the MNIST images
%
%   filename - path of an MNIST label file (magic number 2049)
%
%   Raises an error if the file cannot be opened, if the magic number is
%   wrong, or if the number of labels read differs from the header count.

fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);

% Close the file however this function is left, including when one of the
% checks below raises an error; otherwise the handle would leak.
closeFile = onCleanup(@() fclose(fp));

% isequal (rather than ==) so a truncated file, where fread returns [],
% fails with the intended message.
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(isequal(magic, 2049), ['Bad magic number in ', filename, '']);

numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');

labels = fread(fp, inf, 'unsigned char');

assert(isequal(size(labels,1), numLabels), 'Mismatch in label count');

end
|
||||||
BIN
10_knn/t10k-images-idx3-ubyte
Normal file
BIN
10_knn/t10k-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
10_knn/t10k-labels-idx1-ubyte
Normal file
BIN
10_knn/t10k-labels-idx1-ubyte
Normal file
Binary file not shown.
BIN
10_knn/train-images-idx3-ubyte
Normal file
BIN
10_knn/train-images-idx3-ubyte
Normal file
Binary file not shown.
BIN
10_knn/train-labels-idx1-ubyte
Normal file
BIN
10_knn/train-labels-idx1-ubyte
Normal file
Binary file not shown.
Reference in New Issue
Block a user