diff --git a/10_knn/README.md b/10_knn/README.md
new file mode 100644
index 0000000..b84a5ff
--- /dev/null
+++ b/10_knn/README.md
@@ -0,0 +1,11 @@
+## Digit recognition with KNN
+
+Here's some example MATLAB code that shows KNN in action, guessing handwritten digits. Here's what the output looks like:
+
+![KNN guesses using 500 comparison images: 16 of 20 correct](images/16_correct_500_comparisons.tif)
+
+Each row represents a guess. The last column contains the image that we're trying to guess the digit for. The first three columns show the 3 nearest neighbors for that last image.
+
+Images with boxes around them represent images that we did not guess correctly. You can see KNN works pretty well -- with just 500 comparison images we are getting 80% accuracy. This jumps to 90% with 5000 comparison images (see the images/ directory).
+
+[Uses the MNIST dataset](http://yann.lecun.com/exdb/mnist/)
diff --git a/10_knn/images/13_correct_50_comparisons.tif b/10_knn/images/13_correct_50_comparisons.tif
new file mode 100644
index 0000000..9a31279
Binary files /dev/null and b/10_knn/images/13_correct_50_comparisons.tif differ
diff --git a/10_knn/images/16_correct_500_comparisons.tif b/10_knn/images/16_correct_500_comparisons.tif
new file mode 100644
index 0000000..3022027
Binary files /dev/null and b/10_knn/images/16_correct_500_comparisons.tif differ
diff --git a/10_knn/images/18_correct_5000_comparisons.tif b/10_knn/images/18_correct_5000_comparisons.tif
new file mode 100644
index 0000000..c1a7b93
Binary files /dev/null and b/10_knn/images/18_correct_5000_comparisons.tif differ
diff --git a/10_knn/images/19_correct_50000_comparisons.tif b/10_knn/images/19_correct_50000_comparisons.tif
new file mode 100644
index 0000000..13ab748
Binary files /dev/null and b/10_knn/images/19_correct_50000_comparisons.tif differ
diff --git a/10_knn/main.m b/10_knn/main.m
new file mode 100644
index 0000000..b92e8a9
--- /dev/null
+++ b/10_knn/main.m
@@ -0,0 +1,113 @@
+function foo = main()
+%MAIN Load the MNIST data and show a grid of k-nearest-neighbour guesses.
+%   The output FOO is a placeholder that is never assigned, so call MAIN
+%   without requesting a return value.
+    addpath('mnistHelper');
+    train_images = loadMNISTImages('train-images-idx3-ubyte');
+    train_labels = loadMNISTLabels('train-labels-idx1-ubyte');
+
+    test_images = loadMNISTImages('t10k-images-idx3-ubyte');
+    test_labels = loadMNISTLabels('t10k-labels-idx1-ubyte');
+
+    % showData(train_images, 100, 100);
+    % Classify 20 random test digits from their 3 nearest neighbours.
+    guesses(20, 3, train_images, train_labels, test_images, test_labels);
+end
+
+function foo = showData(images, rows, cols)
+%SHOWDATA Tile 28x28 images into a rows-by-cols mosaic and display it.
+%   IMAGES holds one 784-pixel image per column. Columns are consumed in
+%   order and laid out left to right, then top to bottom, so IMAGES needs
+%   at least rows*cols columns. FOO is a placeholder and is never assigned.
+    side = 28;
+    assert(size(images, 2) >= rows * cols, ...
+        'showData needs at least rows*cols images');
+    % Preallocate the mosaic instead of growing it one tile at a time.
+    mosaic = zeros(side * rows, side * cols, 'like', images);
+    for idx = 1:(rows * cols)
+        r = floor((idx - 1) / cols);
+        c = mod(idx - 1, cols);
+        mosaic(r * side + (1:side), c * side + (1:side)) = ...
+            reshape(images(:, idx), side, side);
+    end
+    imshow(mosaic);
+end
+
+function d = distance(train_image, test_image)
+%DISTANCE Euclidean distance between two image vectors of equal length.
+%   Row and column vectors are both accepted.
+    % Convert before subtracting: integer classes saturate, so a uint8
+    % difference would clamp negative values to 0 and skew the distance.
+    v = double(train_image(:)) - double(test_image(:));
+    d = sqrt(v' * v);
+end
+
+function result = border(image, value)
+%BORDER Replace the outermost pixel ring of a 28x28 image with VALUE.
+%   IMAGE is a 784-element vector. The result is a 784x1 column whose
+%   interior (rows and columns 2..27) is copied from IMAGE.
+    side = 28;
+    image = reshape(image, side, side);
+    result = zeros(side, side);
+    result(:, :) = value;
+    result(2:side-1, 2:side-1) = image(2:side-1, 2:side-1);
+    result = reshape(result, side * side, 1);
+end
+
+function foo = guesses(count, k, train_images, train_labels, test_images, test_labels, max_train)
+%GUESSES Classify random test digits with k-nearest-neighbours and plot them.
+%   COUNT test images are drawn at random, with replacement. Each one is
+%   compared against the first MAX_TRAIN training images and labelled with
+%   the most common label among its K nearest neighbours. The figure has
+%   one row per guess: the K neighbours followed by the test image, which
+%   gets a border when the guess is wrong.
+%
+%   MAX_TRAIN is optional and defaults to 50000. It is clamped to the
+%   number of training images supplied.
+%
+%   FOO is the number of correct guesses, which is also printed.
+    num_train_images = size(train_images, 2);
+    num_test_images = size(test_images, 2);
+
+    if nargin < 7 || isempty(max_train)
+        max_train = 50000;
+    end
+    % Never index past the end of the training set.
+    num_compare = min(max_train, num_train_images);
+    assert(k >= 1 && k <= num_compare, ...
+        'k must be between 1 and the number of training images compared');
+
+    correct = 0;
+    tile_cols = k + 1;
+    tiles = zeros(size(train_images, 1), count * tile_cols, 'like', train_images);
+    for g = 1:count
+        % randi yields 1..num_test_images. floor(rand() * n) could yield 0,
+        % an invalid index, and could never select the last image.
+        x = randi(num_test_images);
+        test_image = test_images(:, x);
+        correct_label = test_labels(x);
+
+        % Preallocated: growing the array in the loop copies it every pass.
+        dist = zeros(num_compare, 1);
+        for i = 1:num_compare
+            dist(i) = distance(train_images(:, i), test_image);
+        end
+        % sort is stable, so tied distances keep the lower index first.
+        [~, order] = sort(dist);
+        nearest = order(1:k);
+
+        first = (g - 1) * tile_cols;
+        tiles(:, first + (1:k)) = train_images(:, nearest);
+        guess_label = mode(train_labels(nearest));
+        if guess_label == correct_label
+            correct = correct + 1;
+            tiles(:, first + tile_cols) = test_image;
+        else
+            tiles(:, first + tile_cols) = border(test_image, 255);
+        end
+    end
+    correct
+
+    showData(tiles, count, k+1);
+    foo = correct;
+end
diff --git a/10_knn/mnistHelper/loadMNISTImages.m b/10_knn/mnistHelper/loadMNISTImages.m
new file mode 100644
index 0000000..6eb2304
--- /dev/null
+++ 
b/10_knn/mnistHelper/loadMNISTImages.m
@@ -0,0 +1,33 @@
+function images = loadMNISTImages(filename)
+%loadMNISTImages returns a [number of pixels]x[number of MNIST images]
+%matrix of doubles in [0,1]; each column is one image and
+%reshape(images(:, i), numRows, numCols) restores it
+
+fp = fopen(filename, 'rb');
+assert(fp ~= -1, ['Could not open ', filename, '']);
+% Close the file however this function exits, including a failed assert.
+closeFile = onCleanup(@() fclose(fp));
+
+magic = fread(fp, 1, 'int32', 0, 'ieee-be');
+assert(magic == 2051, ['Bad magic number in ', filename, '']);
+
+numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
+numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
+numCols = fread(fp, 1, 'int32', 0, 'ieee-be');
+
+images = fread(fp, inf, 'unsigned char');
+% A truncated or padded file would otherwise fail inside reshape with an
+% unrelated-looking error.
+assert(numel(images) == numRows * numCols * numImages, ...
+    ['Mismatch in pixel count in ', filename, '']);
+images = reshape(images, numCols, numRows, numImages);
+% The pixels were read with the column index varying fastest; swap the
+% first two dimensions so each slice is numRows x numCols.
+images = permute(images,[2 1 3]);
+
+% Reshape to #pixels x #examples
+images = reshape(images, size(images, 1) * size(images, 2), size(images, 3));
+% Convert to double and rescale to [0,1]
+images = double(images) / 255;
+
+end
diff --git a/10_knn/mnistHelper/loadMNISTLabels.m b/10_knn/mnistHelper/loadMNISTLabels.m
new file mode 100644
index 0000000..06c07a4
--- /dev/null
+++ b/10_knn/mnistHelper/loadMNISTLabels.m
@@ -0,0 +1,19 @@
+function labels = loadMNISTLabels(filename)
+%loadMNISTLabels returns a [number of MNIST images]x1 matrix containing
+%the labels for the MNIST images
+
+fp = fopen(filename, 'rb');
+assert(fp ~= -1, ['Could not open ', filename, '']);
+% Close the file however this function exits, including a failed assert.
+closeFile = onCleanup(@() fclose(fp));
+
+magic = fread(fp, 1, 'int32', 0, 'ieee-be');
+assert(magic == 2049, ['Bad magic number in ', filename, '']);
+
+numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');
+
+labels = fread(fp, inf, 'unsigned char');
+
+assert(size(labels,1) == numLabels, 'Mismatch in label count');
+
+end
diff --git a/10_knn/t10k-images-idx3-ubyte b/10_knn/t10k-images-idx3-ubyte
new file mode 100644
index 0000000..1170b2c
Binary files /dev/null and b/10_knn/t10k-images-idx3-ubyte differ
diff --git a/10_knn/t10k-labels-idx1-ubyte b/10_knn/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000..d1c3a97
Binary files /dev/null and b/10_knn/t10k-labels-idx1-ubyte 
differ diff --git a/10_knn/train-images-idx3-ubyte b/10_knn/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/10_knn/train-images-idx3-ubyte differ diff --git a/10_knn/train-labels-idx1-ubyte b/10_knn/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/10_knn/train-labels-idx1-ubyte differ