TensorBird/src/TensorBirdId.ipynb

468 lines
335 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function, division\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.optim import lr_scheduler\n",
"\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torchvision import datasets, models\n",
"\n",
"import os\n",
"import time\n",
"import copy\n",
"import pickle\n",
"\n",
"import pandas as pd\n",
"import matplotlib.pylab as plt\n",
"import numpy as np\n",
"from PIL import Image\n",
"import glob\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device:: cuda:0\n"
]
}
],
"source": [
"# Setup the device to run the computations\n",
"# Prefer the first CUDA GPU when available; otherwise fall back to CPU.\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print('Device::', device)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# Script runtime options\n",
"model_name = 'resnet152'\n",
"# Constructor used to build the network (kept next to model_name so they stay in sync)\n",
"model_func = models.resnet152\n",
"root_dir = '../data'\n",
"# Dataset root: one subdirectory per species code, each holding .jpg images\n",
"data_dir = os.path.join(root_dir,'species_dataset')\n",
"# Where trained classification models for this architecture are stored\n",
"working_dir = os.path.join('models/classification', model_name)\n",
"batch_size = 16\n",
"num_workers = 4\n",
"num_epochs = 40"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Setup the model and optimiser\n",
"# Start from ImageNet-pretrained weights. Use the constructor selected in the\n",
"# runtime-options cell (model_func) instead of hardcoding models.resnet152, so\n",
"# switching architectures only requires editing one cell.\n",
"model_ft = model_func(pretrained=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sp_code</th>\n",
" <th>sp_latin</th>\n",
" <th>sp_fr</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>cyacae</td>\n",
" <td>Cyanistes caeruleus</td>\n",
" <td>Mésange bleue</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>parmaj</td>\n",
" <td>Parus major</td>\n",
" <td>Mésange charbonnière</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>erirub</td>\n",
" <td>Erithacus rubecula</td>\n",
" <td>Rougegorge familier</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>prumod</td>\n",
" <td>Prunella modularis</td>\n",
" <td>Accenteur mouchet</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>pasdom</td>\n",
" <td>Passer domesticus</td>\n",
" <td>Moineau domestique</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>turmer</td>\n",
" <td>Turdus merula</td>\n",
" <td>Merle noir</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>felcat</td>\n",
" <td>Felix catus</td>\n",
" <td>Chat domestique</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>fricoe</td>\n",
" <td>Fringilla coelebs</td>\n",
" <td>Pinson des arbres</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>stedec</td>\n",
" <td>Streptopelia decaocto</td>\n",
" <td>Tourterelle turque</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>carcar</td>\n",
" <td>Carduelis carduelis</td>\n",
" <td>Chardonneret élégant</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sp_code sp_latin sp_fr\n",
"0 cyacae Cyanistes caeruleus Mésange bleue\n",
"1 parmaj Parus major Mésange charbonnière\n",
"2 erirub Erithacus rubecula Rougegorge familier\n",
"3 prumod Prunella modularis Accenteur mouchet\n",
"4 pasdom Passer domesticus Moineau domestique\n",
"5 turmer Turdus merula Merle noir\n",
"6 felcat Felix catus Chat domestique\n",
"7 fricoe Fringilla coelebs Pinson des arbres\n",
"8 stedec Streptopelia decaocto Tourterelle turque\n",
"9 carcar Carduelis carduelis Chardonneret élégant"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Species lookup table (code / latin / french). The original used an absolute\n",
"# path under /home/ortion/..., which breaks on any other machine; build the\n",
"# path from root_dir instead.\n",
"sp_names = pd.read_csv(os.path.join(root_dir, 'sp_names.csv'))\n",
"class_names = sp_names[\"sp_code\"]\n",
"sp_names\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# Setup transformations\n",
"# RandomSizedCrop was renamed RandomResizedCrop in torchvision; the old name is deprecated.\n",
"data_transform = transforms.Compose([\n",
"    transforms.RandomResizedCrop(224),\n",
"    transforms.RandomHorizontalFlip(),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
"                         std=[0.229, 0.224, 0.225])\n",
"])\n",
"# Setup data loaders with augmentation transforms.\n",
"# The original never passed data_transform to ImageFolder, so the loaders would\n",
"# yield raw PIL images that cannot be collated into tensors.\n",
"# NOTE(review): both phases currently read the same directory (data_dir holds one\n",
"# folder per species, with no train/test split visible here), so train == test --\n",
"# confirm the intended split layout before trusting test metrics.\n",
"image_datasets = {x: datasets.ImageFolder(data_dir, data_transform)\n",
"                  for x in ['train', 'test']}\n",
"dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,\n",
"                                              shuffle=True, num_workers=num_workers)\n",
"               for x in ['train', 'test']}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"num_ftrs = model_ft.fc.in_features\n",
"# Replace the final fully-connected layer so the output size matches the number\n",
"# of species classes (len(class_names)) instead of ImageNet's 1000.\n",
"# (The previous comment claimed the output size was 2, which did not match the code.)\n",
"model_ft.fc = nn.Linear(num_ftrs, len(class_names))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Move the model's parameters to the selected compute device (GPU if available)\n",
"model_ft = model_ft.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Cross-entropy over the species classes (expects raw logits from the model)\n",
"criterion = nn.CrossEntropyLoss()\n",
"\n",
"# Observe that all parameters are being optimized (full fine-tuning, not just the new fc layer)\n",
"optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n",
"\n",
"# Decay LR by a factor of 0.1 every 7 epochs\n",
"exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def load_images(images_per_class=10, validation_rate=0.2, size=(100, 75)):\n",
"    \"\"\"Load up to `images_per_class` images per species and split train/validation.\n",
"\n",
"    Images are read from data_dir/<sp_code>/*.jpg, resized to `size`, and the\n",
"    pool is shuffled before splitting. The original implementation had three\n",
"    defects fixed here: range(images_per_class - 1) silently dropped one image\n",
"    per class; range(..., len(X_dataset) - 1) dropped the final sample; and the\n",
"    unshuffled tail split put only the last species into the validation set.\n",
"\n",
"    Returns:\n",
"        (train_dataset, val_dataset): dicts with keys 'x' (numpy image arrays)\n",
"        and 'y' (species codes).\n",
"    \"\"\"\n",
"    X_dataset = []\n",
"    y_dataset = []\n",
"\n",
"    for sp_code in class_names:\n",
"        image_paths = glob.glob(os.path.join(data_dir, sp_code, '*.jpg'))\n",
"        # Slicing handles classes with fewer than images_per_class images.\n",
"        for path in image_paths[:images_per_class]:\n",
"            im = Image.open(path)\n",
"            X_dataset.append(np.array(im.resize(size)))\n",
"            y_dataset.append(sp_code)\n",
"\n",
"    # Shuffle images and labels together so both splits contain every species.\n",
"    pairs = list(zip(X_dataset, y_dataset))\n",
"    random.shuffle(pairs)\n",
"\n",
"    split = int(len(pairs) * (1 - validation_rate))\n",
"    train_dataset = {'x': [x for x, _ in pairs[:split]], 'y': [y for _, y in pairs[:split]]}\n",
"    val_dataset = {'x': [x for x, _ in pairs[split:]], 'y': [y for _, y in pairs[split:]]}\n",
"    return (train_dataset, val_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Build the small in-memory train/validation sample used for visualisation below\n",
"(train_dataset, val_dataset) = load_images()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Show dataset sample using matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlAAAAIqCAYAAADrW3TiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9eaxty7ofBv2+qhpjNqvZ+3T33Hvuu+/eF7cxUYjBxhGRSJSQBmywMGAIkUkkx09AULASCRsTKVZEGiQUDIoU4gSiQBLZiR0SAxFBBEKXzjhxZOJnW8/vvneb0+92rTXnHGNU1ccf3/dV1RhzrrX3ae4++7y7vnvX2XOOOUaNGjW+vitiZtzDPdzDPdzDPdzDPdzDy4P7qidwD/dwD/dwD/dwD/fwdYN7Beoe7uEe7uEe7uEe7uEzwr0CdQ/3cA/3cA/3cA/38BnhXoG6h3u4h3u4h3u4h3v4jHCvQN3DPdzDPdzDPdzDPXxGuFeg7uEe7uEe7uEe7uEePiPcK1CvORDRHyKif+YLXP+Hieif/zLndA/38FXBPT3cwz28PBDR94iIiSh81XP51Qj3i/qaAzP/I1/1HO7hHl4XuKeHe7iHe3hd4N4D9RrDi6yGe6viHn6a4J4e7uEe7uF1gnsF6isAInqPiP4kEX1CRN8nor9Xj/9hIvoTRPTPE9FzAH9XG3Jo3LG/l4h+AOD/RkR/AxH9aDH+LxPRf745tCaiP05EV0T0HxDRf/KVPew93MML4J4e7uGnFRQ3/0dE9OeJ6AkR/bNEtCaiN4jo/6g08UQ//0xz3d9FRL+kOPx9Ivo79Lgnov8ZEX1KRL8E4Lcv7vceEf0pInpMRL9IRL+v+e0PE9G/rPR2RUR/joh+vc7vYyL6IRH9La9scb4GcK9AvWIgIgfg/wDgPwLwbQB/E4DfT0R/q57yOwH8CQAPAfwLtwzz1wP4KwH8rbf8voTfCeBfBvAmgH8RwL9KRN3nmf893MOXCff0cA/3gL8Dgru/BsCvB/APQGTzPwvguwB+FsAewD8BAER0BuB/CeC/wMwXAP6zAP6sjvX7APwOAL8ZwG8B8F9b3OuPAfgRgPf0t3+EiP7G5vf/EoD/HYA3APyHAP4Nncu3AfxDAP6pL+eRf3XAvQL16uG3AniHmf8hZh6Z+ZcA/NMA/pv6+7/DzP8qM2dm3t8yxh9m5ps7fl/Cn2HmP8HME4B/HMAawF/7hZ7iHu7hy4F7eriHn3b4J5j5h8z8GMA/DOBvZ+ZHzPwnmXnHzFd6/K9vrskA/ioi2jDzB8z8H+vx3w3gjzTj/aN2ARF9B8BfB+APMPOBmf8sgH8GwH+7Gff/xcz/BjNHiJHxDoB/TGnljwH4HhE9/AmswdcS7hWoVw/fBfAeET21PwB/CMC7+vsPX2KMlznn5PnMnFEtkHu4h68a7unhHn7aocXfX4HQw5aI/iki+hUNX/8/ATwkIs/MNwD+GwD+OwA+IKL/ExH9Rr3+vRPjofntsSpk7e/fbr5/1HzeA/iUmVPzHQDOP8cz/qqEewXq1cMPAXyfmR82fxfM/F/U3/klxmjPuQGwtS9E5CFWQwvfaX53AH4GwPufa/b3cA9fLtzTwz38tMN3ms8/C8HFvx/AbwDw25j5EsB/Tn8nAFAv0d8M4FsA/gLEawsAH5wYz+B9AG8S0cXi9x9/Sc/xUwf3CtSrh38fwBUR/QEi2mjS319FRL/1c473lyBJsb9d8zj+AQCrxTn/aSL6XVql9PsBDAD+3c/7APdwD18i3NPDPfy0w99DRD9DRG8C+B8D+OMALiAen6d6/B+0k4noXSL6nZoLNQC4hoT0AOBfAvD36nhvAPiDdh0z/xDAvw3gH9VE9b8awO8FcN8X7XPCvQL1ikHdob8DwF8D4PsAPoXEoR98zvGeAfjv6Rg/hljgP1qc9q9BXL5PAPweAL9LY9r3cA9fKdzTwz3cA/5FAP8XAL8E4C8D+J8A+C
MANhB6+HcB/J+b8x2Avw/iUXoMyY367+pv/zQk8fs/AvAfAPhXFvf62wF8T6/93wP4B5n5//olP89PDRDzy3jI7+Ee7uEe7uEe7uHLBCL6ZQB/970S8/WEew/UPdzDPdzDPdzDPdzDZ4R7Beoe7uEe7uEe7uEe7uEzwhdSoIjobyOiv6gdTf/gi6+4h3v41Q33NHEP9zCHe5q4HZj5e/fhu68vfO4cKC0P/ksA/mZIkuafhjQA+/Nf3vTu4R6+PnBPE/dwD3O4p4l7+NUMX8QD9Z8B8IvM/EvMPEK6lP7OL2da93APX0u4p4l7uIc53NPEPfyqhS+ye/m3Me94+iMAv+2uC9566y3+2e/+7PEPt3rBGNo37CXAxnjZ8180zqmxPou3jnX2NDt2etzTY8+OnFgjItKfbpmX3YbbdTx17nJeyzX4bF7KF57N8w/cvOdTq/XiwV6wnrwci8EM/PjHH+DJ42dfFGFa+JJp4mXX/WXwtH23t63yqXf92d//XXNhBugzrniLH6d5xW00+9no9ejIC+Z6+11P3Pclp1KuPYnWt72f5gYLUmbWj3Tb0tXZMhjv/+hDPHnymtLEPQhw+c9XDC+Sz0uu+xLDLceeoXdDkLetQUuwn3GdTtMt44c//DEeP35y8kG/iAL1UkBEPw/g5wHgO9/5Gfxb/4//+9E50gomo1U3lg9DM+ZBL1woVgEkzMM1x6OMRm5xzW3KAyrnISrjLp6g4U7Nb0SzcWXXCFWriCAOQNYxE1B+d/UZ7X7k5dwc59MF5HzOuo6G1K48NudUzmTOMg/tzk/ONfdrnyfPGG6d+y3vgOpa5eZ+oOVauTqHehJybsYneT+GAwQHZlkn5307KXBzP3JesIR5Nn5mW2N5jpwzMmf8rv/y342vAl6aJjja+XbURmj+bY/N8W3+vgxFbV1cRS+0NHE02+Ya/c7zu3K5T51PpZPcTIEKTcgjZXCuNGHvvVzPsRnHye8gvV8u5+Zk69SsH0joqcEDJoLTcyqOUsFH27HCOTdfy7KkLU0ojhXGbs/QrCHp6jA3z3mKpVPBWcN4IgLnDLa1J+UZXMeWcY0mWj7DyKmlCT02owlbY6UJpYv/+n/l9x3N7lXAy9DETxfcbdi8XOYNn/484wOGs3qPRr7Nrm5pyWQnc8VLHcvkdOU1Df9pJl+Pitwh5dGMXGU3ebkXJ+EQRf7WAcmsg6WlU3QCxpGDgVr+VBhCoQ+RiUqPOeJv+1v+q7gNvogC9WPMW8b/DE60hGfmPwrgjwLAb/5P/eZbXnvL+E9bkAzWl3ObxrsQMiQM1K6q4Oui36U0nbJ4b7N+j5QnfTEzE7bej+AW3wEu0dTmpc4wOC8OoN6DYxE+gmSmRIgCSSRKS05R8ZUFMcmUtONntTHAWc9zqnSZUDkWuCKETEEURDRErUOTMnWPnKISH8F7UkeZEhKj/FYV62pGi8KlR8nrcwniU0NMRZDCiNLBOTmXlkT3xeFLpAnAFGyBU7hquHKXR6k978RwACobWCpfgBk27QU8o8sl0C24qt/ZjAG7n/4177qCPT+XIZjaoW6hCTCQU8FbowkUBUVwRpRpVb4UN+ZGjyqPSh/MhuOmvDkAqoixKXhu9ghCQ1VBdGQ0MV9HGc8jJaNlAnkP5sbIYptT8689G1qjgVQImEAjEM05Iem5rHTjCj2/7jTx0whfxnKcMAiO7jE3uOpxEnqxkRqcowYfXy7awtUQYm5OVZwHgakxJnJDb+SqUUX1fvIbypyOFbQTshwmY/QSUyQXMp6cP1bOGvgiOVB/GsCvI6KfI6Iesnv6n/o8A5mIXKo6p86aCYIX4lVlfOXIjFGf0Fr59KGXul25ovl3dvFt91YtfsbA6r/MZqFWRnws0ExxaBQDU4JaZbDxppVxTix7QcZqE9enotvelSG0MGNHTpQ66L8qzKg874k1wHwNWuumPgJXBc+eZbZutpZugVdyT3LuFm/LF4YvkSaUoRzh63zNTu
LxYqRj5G0EKd1xH7b1poo6LyKCo99Pz62i4vyes+PGNFucKMRMZZCKSksGfkwTMzqaT+QEHtkwPL+1/jCLJtASVxf
"text/plain": [
"<Figure size 720x720 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Show a 3x3 grid of randomly sampled training images labelled with their species code\n",
"fig = plt.figure(figsize=(10, 10))\n",
"rows = 3\n",
"cols = 3\n",
"for a in range(rows * cols):\n",
"    # Sample uniformly over the whole dataset. The original random.randint(a, ...)\n",
"    # used the panel index as the lower bound, biasing later panels toward later\n",
"    # images and never showing early indices after the first panel.\n",
"    idx = random.randrange(len(train_dataset['x']))\n",
"    ax = fig.add_subplot(rows, cols, a + 1)\n",
"    ax.set_title(train_dataset['y'][idx])\n",
"    ax.imshow(train_dataset['x'][idx])"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, criterion, optimizer, scheduler, dataset, device, num_epochs=25,\n",
"                return_history=False, log_history=True, working_dir='output'):\n",
"    \"\"\"Fine-tune `model`, keeping the weights from the best test-accuracy epoch.\n",
"\n",
"    Args:\n",
"        dataset: dict mapping 'train' and 'test' to torch DataLoader objects.\n",
"        return_history: when True, return (model, history) instead of just model.\n",
"        log_history: when True, record per-epoch loss/accuracy in the history dict.\n",
"        working_dir: currently unused; kept for interface compatibility.\n",
"\n",
"    Returns:\n",
"        The model with its best weights restored, plus the history dict when\n",
"        return_history is True.\n",
"    \"\"\"\n",
"    since = time.time()\n",
"    best_acc = 0.0\n",
"    best_model_wts = copy.deepcopy(model.state_dict())\n",
"    history = {'epochs': [], 'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': []}\n",
"\n",
"    for epoch in range(num_epochs):\n",
"        print(f\"Epoch {epoch}/{num_epochs - 1}\")\n",
"        print('-' * 10)\n",
"        history['epochs'].append(epoch)\n",
"\n",
"        for phase in ['train', 'test']:\n",
"            if phase == 'train':\n",
"                model.train()\n",
"            else:\n",
"                model.eval()\n",
"\n",
"            running_loss = 0.0\n",
"            running_corrects = 0\n",
"\n",
"            # Iterate over data. The original looped over the global `dataloaders`\n",
"            # and ignored the `dataset` parameter entirely.\n",
"            for inputs, labels in dataset[phase]:\n",
"                inputs = inputs.to(device)\n",
"                labels = labels.to(device)\n",
"\n",
"                # zero the parameter gradients\n",
"                optimizer.zero_grad()\n",
"\n",
"                # forward; track gradients only in the training phase\n",
"                with torch.set_grad_enabled(phase == 'train'):\n",
"                    outputs = model(inputs)\n",
"                    _, preds = torch.max(outputs, 1)\n",
"                    loss = criterion(outputs, labels)\n",
"\n",
"                    # backward + optimize only if in training phase\n",
"                    if phase == 'train':\n",
"                        loss.backward()\n",
"                        optimizer.step()\n",
"\n",
"                # Accumulate statistics. The original never accumulated anything\n",
"                # and then printed undefined epoch_loss/epoch_acc.\n",
"                running_loss += loss.item() * inputs.size(0)\n",
"                running_corrects += torch.sum(preds == labels.data)\n",
"\n",
"            if phase == 'train':\n",
"                scheduler.step()\n",
"\n",
"            num_samples = len(dataset[phase].dataset)\n",
"            epoch_loss = running_loss / num_samples\n",
"            epoch_acc = running_corrects.double() / num_samples\n",
"\n",
"            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
"                phase, epoch_loss, epoch_acc))\n",
"            if log_history:\n",
"                # The original appended to history['epoch'], a key that does not exist.\n",
"                history[phase + '_loss'].append(epoch_loss)\n",
"                history[phase + '_acc'].append(epoch_acc.item())\n",
"\n",
"            # Track the best test accuracy (the original never updated best_acc).\n",
"            if phase == 'test' and epoch_acc > best_acc:\n",
"                best_acc = epoch_acc\n",
"                best_model_wts = copy.deepcopy(model.state_dict())\n",
"\n",
"    time_elapsed = time.time() - since\n",
"    print('Training complete in {:.0f}m {:.0f}s'.format(\n",
"        time_elapsed // 60, time_elapsed % 60))\n",
"    print('Best val Acc: {:4f}'.format(best_acc))\n",
"\n",
"    # Restore the best-performing weights before returning.\n",
"    model.load_state_dict(best_model_wts)\n",
"\n",
"    if return_history:\n",
"        return model, history\n",
"    else:\n",
"        return model"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch0/39\n",
"----------\n"
]
},
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute 'to'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-37-c71c5d0ad91e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'test'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mval_dataset\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mhistory\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_ft\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_ft\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_ft\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexp_lr_scheduler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m40\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-36-a0d40fd12a57>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, criterion, optimizer, scheduler, dataset, device, num_epochs, return_history, log_history, working_dir)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'x'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'to'"
]
}
],
"source": [
"# Request the history explicitly: with the default return_history=False the\n",
"# function returns only the model, so the original `history, model_ft = ...`\n",
"# unpacking failed -- and its order was reversed vs the (model, history) return.\n",
"model_ft, history = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,\n",
"                                dataloaders, device, num_epochs=num_epochs, return_history=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tb-venv",
"language": "python",
"name": "tb-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}