~dvshkn/experiment-cifar10-training-order

340456dee5116812cdeb0c907ffae1e380b80740 — David Knight 1 year, 14 days ago
initial commit
63 files changed, 942 insertions(+), 0 deletions(-)

A .gitignore
A cifar10_training_order.ipynb
A cifar10_training_order_test_only.ipynb
A plots/cifar10_training_order_accuracies-entropy-global-asc-all.png
A plots/cifar10_training_order_accuracies-entropy-global-desc-all.png
A plots/cifar10_training_order_accuracies-entropy-grouped-asc-all.png
A plots/cifar10_training_order_accuracies-entropy-grouped-desc-all.png
A plots/cifar10_training_order_accuracies-random-all.png
A plots/cifar10_training_order_losses-entropy-global-asc-all.png
A plots/cifar10_training_order_losses-entropy-global-desc-all.png
A plots/cifar10_training_order_losses-entropy-grouped-asc-all.png
A plots/cifar10_training_order_losses-entropy-grouped-desc-all.png
A plots/cifar10_training_order_losses-random-all.png
A training/cifar10_training_order-entropy-global-asc-0.1.net
A training/cifar10_training_order-entropy-global-asc-0.2.net
A training/cifar10_training_order-entropy-global-asc-0.3.net
A training/cifar10_training_order-entropy-global-asc-0.4.net
A training/cifar10_training_order-entropy-global-asc-0.5.net
A training/cifar10_training_order-entropy-global-asc-0.6.net
A training/cifar10_training_order-entropy-global-asc-0.7.net
A training/cifar10_training_order-entropy-global-asc-0.8.net
A training/cifar10_training_order-entropy-global-asc-0.9.net
A training/cifar10_training_order-entropy-global-asc-1.net
A training/cifar10_training_order-entropy-global-desc-0.1.net
A training/cifar10_training_order-entropy-global-desc-0.2.net
A training/cifar10_training_order-entropy-global-desc-0.3.net
A training/cifar10_training_order-entropy-global-desc-0.4.net
A training/cifar10_training_order-entropy-global-desc-0.5.net
A training/cifar10_training_order-entropy-global-desc-0.6.net
A training/cifar10_training_order-entropy-global-desc-0.7.net
A training/cifar10_training_order-entropy-global-desc-0.8.net
A training/cifar10_training_order-entropy-global-desc-0.9.net
A training/cifar10_training_order-entropy-global-desc-1.net
A training/cifar10_training_order-entropy-grouped-asc-0.1.net
A training/cifar10_training_order-entropy-grouped-asc-0.2.net
A training/cifar10_training_order-entropy-grouped-asc-0.3.net
A training/cifar10_training_order-entropy-grouped-asc-0.4.net
A training/cifar10_training_order-entropy-grouped-asc-0.5.net
A training/cifar10_training_order-entropy-grouped-asc-0.6.net
A training/cifar10_training_order-entropy-grouped-asc-0.7.net
A training/cifar10_training_order-entropy-grouped-asc-0.8.net
A training/cifar10_training_order-entropy-grouped-asc-0.9.net
A training/cifar10_training_order-entropy-grouped-asc-1.net
A training/cifar10_training_order-entropy-grouped-desc-0.1.net
A training/cifar10_training_order-entropy-grouped-desc-0.2.net
A training/cifar10_training_order-entropy-grouped-desc-0.3.net
A training/cifar10_training_order-entropy-grouped-desc-0.4.net
A training/cifar10_training_order-entropy-grouped-desc-0.5.net
A training/cifar10_training_order-entropy-grouped-desc-0.6.net
A training/cifar10_training_order-entropy-grouped-desc-0.7.net
A training/cifar10_training_order-entropy-grouped-desc-0.8.net
A training/cifar10_training_order-entropy-grouped-desc-0.9.net
A training/cifar10_training_order-entropy-grouped-desc-1.net
A training/cifar10_training_order-random-0.1.net
A training/cifar10_training_order-random-0.2.net
A training/cifar10_training_order-random-0.3.net
A training/cifar10_training_order-random-0.4.net
A training/cifar10_training_order-random-0.5.net
A training/cifar10_training_order-random-0.6.net
A training/cifar10_training_order-random-0.7.net
A training/cifar10_training_order-random-0.8.net
A training/cifar10_training_order-random-0.9.net
A training/cifar10_training_order-random-1.net
A  => .gitignore +3 -0
@@ 1,3 @@
.DS_Store
.ipynb_checkpoints/
data/

A  => cifar10_training_order.ipynb +498 -0
@@ 1,498 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment: CIFAR-10 Training Order\n",
    "## To what extent does training order affect image classifier performance?\n",
    "[Code is based on this PyTorch tutorial!](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The various constants below control the run mode of the experiment.\n",
    "- `ENTROPY_SORT_ENABLED`: This enables or disables entropy sorting of training data. No entropy sorting is the control case and makes this notebook equivalent to the original PyTorch tutorial.\n",
    "- `SORT_GROUPED`: If true this causes each CIFAR-10 image class to be entropy sorted independently. If false this applies entropy sorting globally over all training data. This constant is ignored if `ENTROPY_SORT_ENABLED` is false.\n",
    "- `SORT_ASCENDING`: Entropy sorting is applied in ascending order if true or descending order if false. This constant is ignored if `ENTROPY_SORT_ENABLED` is false.\n",
    "- `TRAINING_SLICE_FACTORS`: For every number in this array (each must lie in the range (0, 1]), a net is trained using that fraction of the CIFAR-10 training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Experiment Controls\n",
    "#\n",
    "\n",
    "ENTROPY_SORT_ENABLED = True\n",
    "SORT_GROUPED = True\n",
    "SORT_ASCENDING = True\n",
    "TRAINING_SLICE_FACTORS = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]\n",
    "\n",
    "def get_file_suffix(slice_factor):\n",
    "    if ENTROPY_SORT_ENABLED:\n",
    "        grouped_str = 'grouped' if SORT_GROUPED else 'global'\n",
    "        ascending_str = 'asc' if SORT_ASCENDING else 'desc'\n",
    "        return f'entropy-{grouped_str}-{ascending_str}-{slice_factor}'\n",
    "    else:\n",
    "        return f'random-{slice_factor}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import copy\n",
    "\n",
    "class EntropySampler(torch.utils.data.Sampler):\n",
    "    def entropy(self, datum):\n",
    "        arr = datum[0].numpy()\n",
    "        # grayscale coefficients from matlab rgb2gray\n",
    "        arr = 0.2989 * arr[0,:,:] + 0.5870 * arr[1,:,:] + 0.1140 * arr[2,:,:]\n",
    "        hist = np.histogram(arr, bins=256)[0]\n",
    "        hist = hist / np.sum(hist)\n",
    "        hist = hist[hist > 0]\n",
    "        return -np.sum(hist * np.log2(hist))\n",
    "\n",
    "    def __init__(self, data_source, global_sort=False, ascending=True, slice_factor=1):\n",
    "        # some CIFAR-10 magic numbers...\n",
    "        if global_sort:\n",
    "            num_classes = 1\n",
    "            images_per_class = 50000\n",
    "        else:\n",
    "            num_classes = 10\n",
    "            images_per_class = 5000        \n",
    "        \n",
    "        entropies_by_label = {}\n",
    "        source_idx_by_label = {}\n",
    "        pointers_by_label = {}\n",
    "        sorted_idx_by_label = {}\n",
    "        sort_dir_coeff = 1 if ascending else -1\n",
    "                \n",
    "        for i in range(num_classes):\n",
    "            entropies_by_label[i] = np.zeros(images_per_class)\n",
    "            source_idx_by_label[i] = np.zeros(images_per_class)\n",
    "            pointers_by_label[i] = 0\n",
    "        \n",
    "        for i in range(len(data_source)):\n",
    "            if global_sort:\n",
    "                label = 0\n",
    "            else:\n",
    "                label = data_source[i][1]\n",
    "            idx = pointers_by_label[label]\n",
    "            entropies_by_label[label][idx] = sort_dir_coeff * self.entropy(data_source[i])\n",
    "            source_idx_by_label[label][idx] = i\n",
    "            pointers_by_label[label] += 1\n",
    "        \n",
    "        for i in range(num_classes):\n",
    "            sorted_idx_by_label[i] = np.argsort(entropies_by_label[i])\n",
    "        \n",
    "        sliced_images_per_class = int(slice_factor * images_per_class)\n",
    "        \n",
    "        self.grouped_entropy_idx = np.zeros(sliced_images_per_class * num_classes, dtype=int)\n",
    "\n",
    "        for i in range(sliced_images_per_class):\n",
    "            for j in range(num_classes):\n",
    "                sorted_idx = sorted_idx_by_label[j][i]\n",
    "                source_idx = source_idx_by_label[j][sorted_idx]\n",
    "                self.grouped_entropy_idx[i*num_classes + j] = source_idx\n",
    "    \n",
    "    def create_sliced_copy(self, relative_slice_factor=1):\n",
    "        # because the array round-robbins the categories in group sort mode\n",
    "        # we can truncate it the same way as if it was globally sorted\n",
    "        dupe = copy.copy(self)\n",
    "        end_idx = int(relative_slice_factor * len(dupe.grouped_entropy_idx))\n",
    "        dupe.grouped_entropy_idx = dupe.grouped_entropy_idx[:end_idx]\n",
    "        return dupe\n",
    "    \n",
    "    def __iter__(self):\n",
    "        return iter(self.grouped_entropy_idx)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.grouped_entropy_idx)\n",
    "\n",
    "class ShuffleSampler(torch.utils.data.Sampler):\n",
    "    def __init__(self, length):\n",
    "        self.length = length\n",
    "        self.shuffled_idx = list(range(length))\n",
    "        np.random.shuffle(self.shuffled_idx)\n",
    "    \n",
    "    def __iter__(self):\n",
    "        return iter(self.shuffled_idx)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Data Prep\n",
    "#\n",
    "# ToTensor followed by a per-channel Normalize with mean 0.5 and std 0.5\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "\n",
    "# one sampler per slice factor, in TRAINING_SLICE_FACTORS order\n",
    "if ENTROPY_SORT_ENABLED:\n",
    "    # rank the full training set once; every slice is a prefix of that ranking\n",
    "    base_entropy_sampler = EntropySampler(trainset, not SORT_GROUPED, SORT_ASCENDING, 1)\n",
    "    samplers = [base_entropy_sampler.create_sliced_copy(factor) for factor in TRAINING_SLICE_FACTORS]\n",
    "else:\n",
    "    samplers = [ShuffleSampler(int(factor * len(trainset))) for factor in TRAINING_SLICE_FACTORS]\n",
    "\n",
    "# one loader per sampler; loader settings are identical for every run mode\n",
    "trainloaders = [\n",
    "    torch.utils.data.DataLoader(trainset, batch_size=4, drop_last=True, num_workers=2, sampler=sampler)\n",
    "    for sampler in samplers\n",
    "]\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)\n",
    "\n",
    "# class names, positioned by label index\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "plane   car  bird   cat\n"
     ]
    }
   ],
   "source": [
    "def imshow(img):\n",
    "    \"\"\"Show a channel-first image tensor after undoing the 0.5/0.5 normalization.\"\"\"\n",
    "    img = img / 2 + 0.5\n",
    "    npimg = img.numpy()\n",
    "    # reorder (channel, height, width) -> (height, width, channel) for matplotlib\n",
    "    plt.imshow(np.transpose(npimg, (1,2,0)))\n",
    "    plt.show()\n",
    "\n",
    "# preview the first batch of the first training loader\n",
    "dataiter = iter(trainloaders[0])\n",
    "# builtin next(): DataLoader iterators in newer PyTorch releases have no .next() method\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# label every image actually in the batch instead of assuming exactly 4\n",
    "print(' '.join('%5s' % classes[label] for label in labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `imshow` output above can be used as a quick sanity check that the entropy sorting is behaving. Entropy sorting in ascending order should cause 4 relatively low information images to be shown. Entropy sorting in descending order should cause 4 relatively high information (i.e. busy looking) images to be shown. If grouped sorting is enabled each image should be from a different class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Network Def\n",
    "#\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Training\n",
    "#\n",
    "num_epochs = 2\n",
    "loss_curves = []\n",
    "loss_labels = []\n",
    "# trainloaders was built in TRAINING_SLICE_FACTORS order, so the two pair up one-to-one\n",
    "for slice_factor, trainloader in zip(TRAINING_SLICE_FACTORS, trainloaders):\n",
    "    net = Net()\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "    suffix = get_file_suffix(slice_factor)\n",
    "    # Mean loss is sampled every `running_period` batches, scaled with the slice.\n",
    "    # It must be a positive int: with a float period the modulo test below\n",
    "    # never matches whenever 2000 * slice_factor is not a whole number,\n",
    "    # which would leave the loss curve empty.\n",
    "    running_period = max(1, round(2000 * slice_factor))\n",
    "    loss_samples = []\n",
    "    \n",
    "    print(f'START Training Time [{suffix}]')\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        # a partial window left over from the previous epoch is discarded\n",
    "        running_loss = 0.0\n",
    "    \n",
    "        for i, (inputs, labels) in enumerate(trainloader):\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item()\n",
    "            if i % running_period == running_period - 1:\n",
    "                curr_loss = running_loss / running_period\n",
    "                loss_samples.append(curr_loss)\n",
    "                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, curr_loss))\n",
    "                running_loss = 0.0\n",
    "    \n",
    "    # persist the trained weights under the run-mode suffix\n",
    "    PATH = f'./training/cifar10_training_order-{suffix}.net'\n",
    "    torch.save(net.state_dict(), PATH)\n",
    "    \n",
    "    # spread the loss samples evenly over the epoch axis\n",
    "    plt.plot(np.linspace(0, num_epochs, len(loss_samples)), loss_samples)\n",
    "    plt.xlabel('Epoch')\n",
    "    plt.ylabel('Loss')\n",
    "    plt.title(f'Training Loss ({suffix})')\n",
    "    plt.show()\n",
    "\n",
    "    loss_curves.append(loss_samples)\n",
    "    loss_labels.append(suffix)\n",
    "    \n",
    "    print(f'END [{suffix}]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Plot Losses\n",
    "#\n",
    "# one curve per trained net, labelled with its run suffix\n",
    "for suffix, loss_samples in zip(loss_labels, loss_curves):\n",
    "    epoch_axis = np.linspace(0, 2, len(loss_samples))\n",
    "    plt.plot(epoch_axis, loss_samples, label=suffix)\n",
    "\n",
    "all_suffix = get_file_suffix('all')\n",
    "plt.title('Training Losses')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.ylim(1, 2.5)\n",
    "plt.legend(bbox_to_anchor=(1.52, 1.0))\n",
    "# save before show() so the written file contains the rendered figure\n",
    "plt.savefig(f'./plots/cifar10_training_order_losses-{all_suffix}.png', bbox_inches='tight', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Test Time\n",
    "#\n",
    "dataiter = iter(testloader)\n",
    "# builtin next(): DataLoader iterators in newer PyTorch releases have no .next() method\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels))\n",
    "\n",
    "accuracies = []\n",
    "\n",
    "for slice_factor in TRAINING_SLICE_FACTORS:\n",
    "    suffix = get_file_suffix(slice_factor)\n",
    "    print(f'START Test Time [{suffix}]')\n",
    "    \n",
    "    # load the weights saved by the training cell for this run mode and slice\n",
    "    net = Net()\n",
    "    net_path = f'./training/cifar10_training_order-{suffix}.net'\n",
    "    net.load_state_dict(torch.load(net_path))\n",
    "    net.eval()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    class_correct = [0.] * len(classes)\n",
    "    class_total = [0.] * len(classes)\n",
    "    # a single pass over the test set feeds both the overall and the per-class tallies\n",
    "    with torch.no_grad():\n",
    "        for images, labels in testloader:\n",
    "            outputs = net(images)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            hits = (predicted == labels)\n",
    "            total += labels.size(0)\n",
    "            correct += hits.sum().item()\n",
    "            # walk the actual batch so any batch size (including a short last batch) is counted\n",
    "            for label, hit in zip(labels.tolist(), hits.tolist()):\n",
    "                class_correct[label] += hit\n",
    "                class_total[label] += 1\n",
    "\n",
    "    accuracy = 100 * correct / total\n",
    "    print('Accuracy of the network on the 10000 test images: %d %%' % accuracy)\n",
    "    accuracies.append(accuracy)\n",
    "\n",
    "    for i in range(len(classes)):\n",
    "        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))\n",
    "        \n",
    "    print(f'END [{suffix}]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Plot Accuracy vs. Training Set Size\n",
    "#\n",
    "all_suffix = get_file_suffix('all')\n",
    "# slice factors scaled by the 50000-image training set give absolute set sizes\n",
    "training_set_sizes = 50000 * np.array(TRAINING_SLICE_FACTORS)\n",
    "plt.plot(training_set_sizes, accuracies, '-o', label=all_suffix)\n",
    "plt.title('Test Accuracy vs. Training Set Size')\n",
    "plt.xlabel('Number of Training Images')\n",
    "plt.ylabel('Accuracy (%)')\n",
    "plt.ylim(20, 65)\n",
    "plt.legend(loc=4)\n",
    "# save before show() so the written file contains the rendered figure\n",
    "plt.savefig(f'./plots/cifar10_training_order_accuracies-{all_suffix}.png', bbox_inches='tight', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

A  => cifar10_training_order_test_only.ipynb +441 -0
@@ 1,441 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Experiment Controls\n",
    "#\n",
    "\n",
    "ENTROPY_SORT_ENABLED = True\n",
    "SORT_GROUPED = True\n",
    "SORT_ASCENDING = False\n",
    "TRAINING_SLICE_FACTORS = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]\n",
    "\n",
    "def get_file_suffix(slice_factor):\n",
    "    if ENTROPY_SORT_ENABLED:\n",
    "        grouped_str = 'grouped' if SORT_GROUPED else 'global'\n",
    "        ascending_str = 'asc' if SORT_ASCENDING else 'desc'\n",
    "        return f'entropy-{grouped_str}-{ascending_str}-{slice_factor}'\n",
    "    else:\n",
    "        return f'random-{slice_factor}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import copy\n",
    "\n",
    "class EntropySampler(torch.utils.data.Sampler):\n",
    "    def entropy(self, datum):\n",
    "        arr = datum[0].numpy()\n",
    "        # grayscale coefficients from matlab rgb2gray\n",
    "        arr = 0.2989 * arr[0,:,:] + 0.5870 * arr[1,:,:] + 0.1140 * arr[2,:,:]\n",
    "        hist = np.histogram(arr, bins=256)[0]\n",
    "        hist = hist / np.sum(hist)\n",
    "        hist = hist[hist > 0]\n",
    "        return -np.sum(hist * np.log2(hist))\n",
    "\n",
    "    def __init__(self, data_source, global_sort=False, ascending=True, slice_factor=1):\n",
    "        # some CIFAR-10 magic numbers...\n",
    "        if global_sort:\n",
    "            num_classes = 1\n",
    "            images_per_class = 50000\n",
    "        else:\n",
    "            num_classes = 10\n",
    "            images_per_class = 5000        \n",
    "        \n",
    "        entropies_by_label = {}\n",
    "        source_idx_by_label = {}\n",
    "        pointers_by_label = {}\n",
    "        sorted_idx_by_label = {}\n",
    "        sort_dir_coeff = 1 if ascending else -1\n",
    "                \n",
    "        for i in range(num_classes):\n",
    "            entropies_by_label[i] = np.zeros(images_per_class)\n",
    "            source_idx_by_label[i] = np.zeros(images_per_class)\n",
    "            pointers_by_label[i] = 0\n",
    "        \n",
    "        for i in range(len(data_source)):\n",
    "            if global_sort:\n",
    "                label = 0\n",
    "            else:\n",
    "                label = data_source[i][1]\n",
    "            idx = pointers_by_label[label]\n",
    "            entropies_by_label[label][idx] = sort_dir_coeff * self.entropy(data_source[i])\n",
    "            source_idx_by_label[label][idx] = i\n",
    "            pointers_by_label[label] += 1\n",
    "        \n",
    "        for i in range(num_classes):\n",
    "            sorted_idx_by_label[i] = np.argsort(entropies_by_label[i])\n",
    "        \n",
    "        sliced_images_per_class = int(slice_factor * images_per_class)\n",
    "        \n",
    "        self.grouped_entropy_idx = np.zeros(sliced_images_per_class * num_classes, dtype=int)\n",
    "\n",
    "        for i in range(sliced_images_per_class):\n",
    "            for j in range(num_classes):\n",
    "                sorted_idx = sorted_idx_by_label[j][i]\n",
    "                source_idx = source_idx_by_label[j][sorted_idx]\n",
    "                self.grouped_entropy_idx[i*num_classes + j] = source_idx\n",
    "    \n",
    "    def create_sliced_copy(self, relative_slice_factor=1):\n",
    "        # because the array round-robbins the categories in group sort mode\n",
    "        # we can truncate it the same way as if it was globally sorted\n",
    "        dupe = copy.copy(self)\n",
    "        end_idx = int(relative_slice_factor * len(dupe.grouped_entropy_idx))\n",
    "        dupe.grouped_entropy_idx = dupe.grouped_entropy_idx[:end_idx]\n",
    "        return dupe\n",
    "    \n",
    "    def __iter__(self):\n",
    "        return iter(self.grouped_entropy_idx)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.grouped_entropy_idx)\n",
    "\n",
    "class ShuffleSampler(torch.utils.data.Sampler):\n",
    "    def __init__(self, length):\n",
    "        self.length = length\n",
    "        self.shuffled_idx = list(range(length))\n",
    "        np.random.shuffle(self.shuffled_idx)\n",
    "    \n",
    "    def __iter__(self):\n",
    "        return iter(self.shuffled_idx)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Data Prep\n",
    "#\n",
    "# ToTensor followed by a per-channel Normalize with mean 0.5 and std 0.5\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))\n",
    "])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "\n",
    "# one sampler per slice factor, in TRAINING_SLICE_FACTORS order\n",
    "if ENTROPY_SORT_ENABLED:\n",
    "    # rank the full training set once; every slice is a prefix of that ranking\n",
    "    base_entropy_sampler = EntropySampler(trainset, not SORT_GROUPED, SORT_ASCENDING, 1)\n",
    "    samplers = [base_entropy_sampler.create_sliced_copy(factor) for factor in TRAINING_SLICE_FACTORS]\n",
    "else:\n",
    "    samplers = [ShuffleSampler(int(factor * len(trainset))) for factor in TRAINING_SLICE_FACTORS]\n",
    "\n",
    "# one loader per sampler; loader settings are identical for every run mode\n",
    "trainloaders = [\n",
    "    torch.utils.data.DataLoader(trainset, batch_size=4, drop_last=True, num_workers=2, sampler=sampler)\n",
    "    for sampler in samplers\n",
    "]\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)\n",
    "\n",
    "# class names, positioned by label index\n",
    "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Network Def\n",
    "#\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GroundTruth:    cat  ship  ship plane\n",
      "START Test Time [entropy-grouped-desc-1]\n",
      "Accuracy of the network on the 10000 test images: 36 %\n",
      "Accuracy of plane : 10 %\n",
      "Accuracy of   car : 15 %\n",
      "Accuracy of  bird :  6 %\n",
      "Accuracy of   cat : 26 %\n",
      "Accuracy of  deer : 36 %\n",
      "Accuracy of   dog : 37 %\n",
      "Accuracy of  frog : 50 %\n",
      "Accuracy of horse : 63 %\n",
      "Accuracy of  ship : 29 %\n",
      "Accuracy of truck : 93 %\n",
      "END [entropy-grouped-desc-1]\n",
      "START Test Time [entropy-grouped-desc-0.9]\n",
      "Accuracy of the network on the 10000 test images: 55 %\n",
      "Accuracy of plane : 40 %\n",
      "Accuracy of   car : 67 %\n",
      "Accuracy of  bird : 36 %\n",
      "Accuracy of   cat : 40 %\n",
      "Accuracy of  deer : 41 %\n",
      "Accuracy of   dog : 55 %\n",
      "Accuracy of  frog : 72 %\n",
      "Accuracy of horse : 59 %\n",
      "Accuracy of  ship : 70 %\n",
      "Accuracy of truck : 71 %\n",
      "END [entropy-grouped-desc-0.9]\n",
      "START Test Time [entropy-grouped-desc-0.8]\n",
      "Accuracy of the network on the 10000 test images: 52 %\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Traceback (most recent call last):\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/queues.py\", line 242, in _feed\n",
      "    send_bytes(obj)\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/connection.py\", line 200, in send_bytes\n",
      "    self._send_bytes(m[offset:offset + size])\n",
      "Traceback (most recent call last):\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/queues.py\", line 242, in _feed\n",
      "    send_bytes(obj)\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/connection.py\", line 200, in send_bytes\n",
      "    self._send_bytes(m[offset:offset + size])\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/connection.py\", line 404, in _send_bytes\n",
      "    self._send(header + buf)\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/connection.py\", line 368, in _send\n",
      "    n = write(self._handle, buf)\n",
      "BrokenPipeError: [Errno 32] Broken pipe\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/connection.py\", line 404, in _send_bytes\n",
      "    self._send(header + buf)\n",
      "  File \"/home/user/anaconda/lib/python3.7/multiprocessing/connection.py\", line 368, in _send\n",
      "    n = write(self._handle, buf)\n",
      "BrokenPipeError: [Errno 32] Broken pipe\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-f7cc59df6325>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     45\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtestloader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m             \u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m             \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m             \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m             \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mpredicted\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    530\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    531\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    533\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    534\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-4-1ca181200171>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m16\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    530\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    531\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    533\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    534\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    343\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    344\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 345\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2d_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    346\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    347\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mconv2d_forward\u001b[0;34m(self, input, weight)\u001b[0m\n\u001b[1;32m    340\u001b[0m                             _pair(0), self.dilation, self.groups)\n\u001b[1;32m    341\u001b[0m         return F.conv2d(input, weight, self.bias, self.stride,\n\u001b[0;32m--> 342\u001b[0;31m                         self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m    343\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    344\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#\n",
    "# Test Time\n",
    "#\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1,2,0)))\n",
    "    plt.show()\n",
    "\n",
    "dataiter = iter(testloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))\n",
    "\n",
    "accuracies = []\n",
    "\n",
    "for slice_factor in TRAINING_SLICE_FACTORS:\n",
    "    suffix = get_file_suffix(slice_factor)\n",
    "    print(f'START Test Time [{suffix}]')\n",
    "    \n",
    "    net = Net()\n",
    "    net_path = f'./training/cifar10_training_order-{suffix}.net'\n",
    "    net.load_state_dict(torch.load(net_path))\n",
    "\n",
    "    outputs = net(images)\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for data in testloader:\n",
    "            images, labels = data\n",
    "            outputs = net(images)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    accuracy = 100 * correct / total\n",
    "    print('Accuracy of the network on the 10000 test images: %d %%' % accuracy)\n",
    "    accuracies.append(accuracy)\n",
    "\n",
    "    class_correct = list(0. for i in range(10))\n",
    "    class_total = list(0. for i in range(10))\n",
    "    with torch.no_grad():\n",
    "        for data in testloader:\n",
    "            images, labels = data\n",
    "            outputs = net(images)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            c = (predicted == labels).squeeze()\n",
    "            for i in range(4):\n",
    "                label = labels[i]\n",
    "                class_correct[label] += c[i].item()\n",
    "                class_total[label] += 1\n",
    "\n",
    "    for i in range(10):\n",
    "        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))\n",
    "        \n",
    "    print(f'END [{suffix}]')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Plot Accuracy vs. Training Set Size\n",
    "#\n",
    "training_set_sizes = np.array(TRAINING_SLICE_FACTORS) * 50000\n",
    "all_suffix = get_file_suffix('all')\n",
    "plt.plot(training_set_sizes, accuracies, '-o', label=all_suffix)\n",
    "plt.xlabel('Number of Training Images')\n",
    "plt.ylabel('Accuracy (%)')\n",
    "plt.ylim(20, 65)\n",
    "plt.title(f'Test Accuracy vs. Training Set Size')\n",
    "plt.legend(loc=4)\n",
    "plt.savefig(f'./plots/cifar10_training_order_accuracies-{all_suffix}.png', bbox_inches='tight', dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

A  => plots/cifar10_training_order_accuracies-entropy-global-asc-all.png +0 -0
A  => plots/cifar10_training_order_accuracies-entropy-global-desc-all.png +0 -0
A  => plots/cifar10_training_order_accuracies-entropy-grouped-asc-all.png +0 -0
A  => plots/cifar10_training_order_accuracies-entropy-grouped-desc-all.png +0 -0
A  => plots/cifar10_training_order_accuracies-random-all.png +0 -0
A  => plots/cifar10_training_order_losses-entropy-global-asc-all.png +0 -0
A  => plots/cifar10_training_order_losses-entropy-global-desc-all.png +0 -0
A  => plots/cifar10_training_order_losses-entropy-grouped-asc-all.png +0 -0
A  => plots/cifar10_training_order_losses-entropy-grouped-desc-all.png +0 -0
A  => plots/cifar10_training_order_losses-random-all.png +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.1.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.2.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.3.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.4.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.5.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.6.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.7.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.8.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-0.9.net +0 -0
A  => training/cifar10_training_order-entropy-global-asc-1.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.1.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.2.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.3.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.4.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.5.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.6.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.7.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.8.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-0.9.net +0 -0
A  => training/cifar10_training_order-entropy-global-desc-1.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.1.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.2.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.3.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.4.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.5.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.6.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.7.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.8.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-0.9.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-asc-1.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.1.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.2.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.3.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.4.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.5.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.6.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.7.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.8.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-0.9.net +0 -0
A  => training/cifar10_training_order-entropy-grouped-desc-1.net +0 -0
A  => training/cifar10_training_order-random-0.1.net +0 -0
A  => training/cifar10_training_order-random-0.2.net +0 -0
A  => training/cifar10_training_order-random-0.3.net +0 -0
A  => training/cifar10_training_order-random-0.4.net +0 -0
A  => training/cifar10_training_order-random-0.5.net +0 -0
A  => training/cifar10_training_order-random-0.6.net +0 -0
A  => training/cifar10_training_order-random-0.7.net +0 -0
A  => training/cifar10_training_order-random-0.8.net +0 -0
A  => training/cifar10_training_order-random-0.9.net +0 -0
A  => training/cifar10_training_order-random-1.net +0 -0