You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
468 lines
335 KiB
468 lines
335 KiB
2 years ago
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from __future__ import print_function, division\n",
|
||
|
"\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.optim as optim\n",
|
||
|
"from torch.optim import lr_scheduler\n",
|
||
|
"\n",
|
||
|
"import torchvision\n",
|
||
|
"import torchvision.transforms as transforms\n",
|
||
|
"from torchvision import datasets, models\n",
|
||
|
"\n",
|
||
|
"import os\n",
|
||
|
"import time\n",
|
||
|
"import copy\n",
|
||
|
"import pickle\n",
|
||
|
"\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pylab as plt\n",
|
||
|
"import numpy as np\n",
|
||
|
"from PIL import Image\n",
|
||
|
"import glob\n",
|
||
|
"import random"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Device:: cuda:0\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Setup the device to run the computations\n",
|
||
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
|
"print('Device::', device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 38,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Script runtime options\n",
|
||
|
"model_name = 'resnet152'\n",
|
||
|
"model_func = models.resnet152\n",
|
||
|
"root_dir = '../data'\n",
|
||
|
"data_dir = os.path.join(root_dir,'species_dataset')\n",
|
||
|
"working_dir = os.path.join('models/classification', model_name)\n",
|
||
|
"batch_size = 16\n",
|
||
|
"num_workers = 4\n",
|
||
|
"num_epochs = 40"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Setup the model and optimiser\n",
|
||
|
"model_ft = models.resnet152(pretrained=True)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<div>\n",
|
||
|
"<style scoped>\n",
|
||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||
|
" vertical-align: middle;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe tbody tr th {\n",
|
||
|
" vertical-align: top;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe thead th {\n",
|
||
|
" text-align: right;\n",
|
||
|
" }\n",
|
||
|
"</style>\n",
|
||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||
|
" <thead>\n",
|
||
|
" <tr style=\"text-align: right;\">\n",
|
||
|
" <th></th>\n",
|
||
|
" <th>sp_code</th>\n",
|
||
|
" <th>sp_latin</th>\n",
|
||
|
" <th>sp_fr</th>\n",
|
||
|
" </tr>\n",
|
||
|
" </thead>\n",
|
||
|
" <tbody>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>0</th>\n",
|
||
|
" <td>cyacae</td>\n",
|
||
|
" <td>Cyanistes caeruleus</td>\n",
|
||
|
" <td>Mésange bleue</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>1</th>\n",
|
||
|
" <td>parmaj</td>\n",
|
||
|
" <td>Parus major</td>\n",
|
||
|
" <td>Mésange charbonnière</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>2</th>\n",
|
||
|
" <td>erirub</td>\n",
|
||
|
" <td>Erithacus rubecula</td>\n",
|
||
|
" <td>Rougegorge familier</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>3</th>\n",
|
||
|
" <td>prumod</td>\n",
|
||
|
" <td>Prunella modularis</td>\n",
|
||
|
" <td>Accenteur mouchet</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>4</th>\n",
|
||
|
" <td>pasdom</td>\n",
|
||
|
" <td>Passer domesticus</td>\n",
|
||
|
" <td>Moineau domestique</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>5</th>\n",
|
||
|
" <td>turmer</td>\n",
|
||
|
" <td>Turdus merula</td>\n",
|
||
|
" <td>Merle noir</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>6</th>\n",
|
||
|
" <td>felcat</td>\n",
|
||
|
" <td>Felix catus</td>\n",
|
||
|
" <td>Chat domestique</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>7</th>\n",
|
||
|
" <td>fricoe</td>\n",
|
||
|
" <td>Fringilla coelebs</td>\n",
|
||
|
" <td>Pinson des arbres</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>8</th>\n",
|
||
|
" <td>stedec</td>\n",
|
||
|
" <td>Streptopelia decaocto</td>\n",
|
||
|
" <td>Tourterelle turque</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>9</th>\n",
|
||
|
" <td>carcar</td>\n",
|
||
|
" <td>Carduelis carduelis</td>\n",
|
||
|
" <td>Chardonneret élégant</td>\n",
|
||
|
" </tr>\n",
|
||
|
" </tbody>\n",
|
||
|
"</table>\n",
|
||
|
"</div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
" sp_code sp_latin sp_fr\n",
|
||
|
"0 cyacae Cyanistes caeruleus Mésange bleue\n",
|
||
|
"1 parmaj Parus major Mésange charbonnière\n",
|
||
|
"2 erirub Erithacus rubecula Rougegorge familier\n",
|
||
|
"3 prumod Prunella modularis Accenteur mouchet\n",
|
||
|
"4 pasdom Passer domesticus Moineau domestique\n",
|
||
|
"5 turmer Turdus merula Merle noir\n",
|
||
|
"6 felcat Felix catus Chat domestique\n",
|
||
|
"7 fricoe Fringilla coelebs Pinson des arbres\n",
|
||
|
"8 stedec Streptopelia decaocto Tourterelle turque\n",
|
||
|
"9 carcar Carduelis carduelis Chardonneret élégant"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"sp_names = pd.read_csv('/home/ortion/Documents/projects/TensorBird/data/sp_names.csv')\n",
|
||
|
"class_names = sp_names[\"sp_code\"]\n",
|
||
|
"sp_names\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 44,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Setup transformations\n",
|
||
|
"data_transform = transforms.Compose([\n",
|
||
|
" transforms.RandomSizedCrop(224),\n",
|
||
|
" transforms.RandomHorizontalFlip(),\n",
|
||
|
" transforms.ToTensor(),\n",
|
||
|
" transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
|
||
|
" std=[0.229, 0.224, 0.225])\n",
|
||
|
" ])\n",
|
||
|
"# Setup data loaders with augmentation transforms\n",
|
||
|
"image_datasets = {x: datasets.ImageFolder(data_dir)\n",
|
||
|
" for x in ['train', 'test']}\n",
|
||
|
"dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,\n",
|
||
|
" shuffle=True, num_workers=num_workers)\n",
|
||
|
" for x in ['train', 'test']}"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"num_ftrs = model_ft.fc.in_features\n",
|
||
|
"# Here the size of each output sample is set to 2.\n",
|
||
|
"# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).\n",
|
||
|
"model_ft.fc = nn.Linear(num_ftrs, len(class_names))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"model_ft = model_ft.to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||
|
"\n",
|
||
|
"# Observe that all parameters are being optimized\n",
|
||
|
"optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n",
|
||
|
"\n",
|
||
|
"# Decay LR by a factor of 0.1 every 7 epochs\n",
|
||
|
"exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def load_images():\n",
|
||
|
" images_per_class = 10\n",
|
||
|
" validation_rate = 0.2\n",
|
||
|
" X_dataset = []\n",
|
||
|
" y_dataset = []\n",
|
||
|
" \n",
|
||
|
" train_dataset = {'x': [], 'y': []}\n",
|
||
|
" val_dataset = {'x': [], 'y': []}\n",
|
||
|
" \n",
|
||
|
" for sp_code in class_names:\n",
|
||
|
" image_paths = glob.glob(f\"../data/species_dataset/{sp_code}/*.jpg\")\n",
|
||
|
" for i in range(images_per_class - 1):\n",
|
||
|
" if (i >= len(image_paths)):\n",
|
||
|
" break\n",
|
||
|
" im = Image.open(image_paths[i])\n",
|
||
|
" im = im.resize((100, 75))\n",
|
||
|
" X_dataset.append(np.array(im))\n",
|
||
|
" y_dataset.append(sp_code)\n",
|
||
|
" \n",
|
||
|
" for i in range(int(len(X_dataset) - len(X_dataset) * validation_rate)):\n",
|
||
|
" train_dataset['x'].append(X_dataset[i])\n",
|
||
|
" train_dataset['y'].append(y_dataset[i])\n",
|
||
|
" \n",
|
||
|
" for i in range(int(len(X_dataset) - len(X_dataset) * validation_rate), len(X_dataset) - 1):\n",
|
||
|
" val_dataset['x'].append(X_dataset[i])\n",
|
||
|
" val_dataset['y'].append(y_dataset[i])\n",
|
||
|
" return (train_dataset, val_dataset)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"(train_dataset, val_dataset) = load_images()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Show dataset sample using matplotlib"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlAAAAIqCAYAAADrW3TiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9eaxty7ofBv2+qhpjNqvZ+3T33Hvuu+/eF7cxUYjBxhGRSJSQBmywMGAIkUkkx09AULASCRsTKVZEGiQUDIoU4gSiQBLZiR0SAxFBBEKXzjhxZOJnW8/vvneb0+92rTXnHGNU1ccf3/dV1RhzrrX3ae4++7y7vnvX2XOOOUaNGjW+vitiZtzDPdzDPdzDPdzDPdzDy4P7qidwD/dwD/dwD/dwD/fwdYN7Beoe7uEe7uEe7uEe7uEzwr0CdQ/3cA/3cA/3cA/38BnhXoG6h3u4h3u4h3u4h3v4jHCvQN3DPdzDPdzDPdzDPXxGuFeg7uEe7uEe7uEe7uEePiPcK1CvORDRHyKif+YLXP+Hieif/zLndA/38FXBPT3cwz28PBDR94iIiSh81XP51Qj3i/qaAzP/I1/1HO7hHl4XuKeHe7iHe3hd4N4D9RrDi6yGe6viHn6a4J4e7uEe7uF1gnsF6isAInqPiP4kEX1CRN8nor9Xj/9hIvoTRPTPE9FzAH9XG3Jo3LG/l4h+AOD/RkR/AxH9aDH+LxPRf745tCaiP05EV0T0HxDRf/KVPew93MML4J4e7uGnFRQ3/0dE9OeJ6AkR/bNEtCaiN4jo/6g08UQ//0xz3d9FRL+kOPx9Ivo79Lgnov8ZEX1KRL8E4Lcv7vceEf0pInpMRL9IRL+v+e0PE9G/rPR2RUR/joh+vc7vYyL6IRH9La9scb4GcK9AvWIgIgfg/wDgPwLwbQB/E4DfT0R/q57yOwH8CQAPAfwLtwzz1wP4KwH8rbf8voTfCeBfBvAmgH8RwL9KRN3nmf893MOXCff0cA/3gL8Dgru/BsCvB/APQGTzPwvguwB+FsAewD8BAER0BuB/CeC/wMwXAP6zAP6sjvX7APwOAL8ZwG8B8F9b3OuPAfgRgPf0t3+EiP7G5vf/EoD/HYA3APyHAP4Nncu3AfxDAP6pL+eRf3XAvQL16uG3AniHmf8hZh6Z+ZcA/NMA/pv6+7/DzP8qM2dm3t8yxh9m5ps7fl/Cn2HmP8HME4B/HMAawF/7hZ7iHu7hy4F7eriHn3b4J5j5h8z8GMA/DOBvZ+ZHzPwnmXnHzFd6/K9vrskA/ioi2jDzB8z8H+vx3w3gjzTj/aN2ARF9B8BfB+APMPOBmf8sgH8GwH+7Gff/xcz/BjNHiJHxDoB/TGnljwH4HhE9/AmswdcS7hWoVw/fBfAeET21PwB/CMC7+vsPX2KMlznn5PnMnFEtkHu4h68a7unhHn7aocXfX4HQw5aI/iki+hUNX/8/ATwkIs/MNwD+GwD+OwA+IKL/ExH9Rr3+vRPjofntsSpk7e/fbr5/1HzeA/iUmVPzHQDOP8cz/qqEewXq1cMPAXyfmR82fxfM/F/U3/klxmjPuQGwtS9E5CFWQwvfaX53AH4GwPufa/b3cA9fLtzTwz38tMN3ms8/C8HFvx/AbwDw25j5EsB/Tn8nAFAv0d8M4FsA/gLEawsAH5wYz+B9AG8S0cXi9x9/Sc/xUwf3CtSrh38fwBUR/QEi2mjS319FRL/1c473lyBJsb9d8zj+AQCrxTn/aSL6XVql9PsBDAD+3c/7APdwD18i3NPDPfy0w99DRD9DRG8C+B8D+OMALiAen6d6/B+0k4noXSL6nZoLNQC4hoT0AOBfAvD36nhvAPiDdh0z/xDAvw3gH9VE9b8awO8FcN8X7XPCvQL1ikHdob8DwF8D4PsAPoXEoR98zvGeAfjv6Rg/hljgP1qc9q9BXL5PAPweAL9LY9r3cA9fKdzTwz3cA/5FAP8XAL8E4C8D+J8A+CMANhB6+HcB/J+b8x2Avw/iUXoMyY367+pv/zQk8fs/AvAfAPhXFvf62wF8T6/93wP4B5n5//olP89PDRDzy3jI7+Ee7uEe7uEe7uHLBCL6ZQB/970S8/WEew/UPdzDPdzDPdzDPdzDZ4R7Beoe7uEe7uEe7uEe7uEzwhdSoIjobyOiv6gdTf/gi6+4h3v41Q33NHEP9zCHe5q4HZj5e/fhu68vfO4cKC0P/ksA/mZIkuafhjQA+/Nf3vTu4R6+PnBPE/dwD3O4p4l7+NUMX8QD9Z8B8IvM/EvMPEK6lP7OL2da93APX0u4p4l7uIc53NPEPfyqhS+ye/m3Me94+iMAv+2uC9566y3+2e/+7PEPt3rBGNo37CXAxnjZ8180zqmxPou3jnX2NDt2etzTY8+OnFgjItKfbpmX3YbbdTx17nJeyzX4bF7KF57N8w/cvOdTq/XiwV6wnrwci8EM/PjHH+DJ42dfFGFa+JJp4mXX/WXwtH23t63yqXf92d//XXNhBugzrniLH6d5xW00+9no9ejIC+Z6+11P3Pclp1KuPYnWt72f5gYLUmbWj3Tb0tXZMhjv/+hDPHnymtLEPQhw+c9XDC+Sz0uu+xLDLceeoXdDkLetQUuwn3GdTtMt44c//DEeP35y8kG/iAL1UkBEPw/g5wHgO9/5Gfxb/4//+9E50gomo1U3lg9DM+ZBL1woVgEkzMM1x6OMRm5xzW3KAyrnISrjLp6g4U7Nb0SzcWXXCFWriCAOQNYxE1B+d/UZ7X7k5dwc59MF5HzOuo6G1K48NudUzmTOMg/tzk/ONfdrnyfPGG6d+y3vgOpa5eZ+oOVauTqHehJybsYneT+GAwQHZlkn5307KXBzP3JesIR5Nn5mW2N5jpwzMmf8rv/y342vAl6aJjja+XbURmj+bY/N8W3+vgxFbV1cRS+0NHE02+Ya/c7zu3K5T51PpZPcTIEKTcgjZXCuNGHvvVzPsRnHye8gvV8u5+Zk69SsH0joqcEDJoLTcyqOUsFH27HCOTdfy7KkLU0ojhXGbs/QrCHp6jA3z3mKpVPBWcN4IgLnDLa1J+UZXMeWcY0mWj7DyKmlCT02owlbY6UJpYv/+n/l9x3N7lXAy9DETxfcbdi8XOYNn/484wOGs3qPRr7Nrm5pyWQnc8VLHcvkdOU1Df9pJl+Pitwh5dGMXGU3ebkXJ+EQRf7WAcmsg6WlU3QCxpGDgVr+VBhCoQ+RiUqPOeJv+1v+q7gNvogC9WPMW8b/DE60hGfmPwrgjwLAb/5P/eZbXnvL+E9bkAzWl3ObxrsQMiQM1K6q4Oui36U0nbJ4b7N+j5QnfTEzE7bej+AW3wEu0dTmpc4wOC8OoN6DYxE+gmSmRIgCSSRKS05R8ZUFMcmUtONntTHAWc9zqnSZUDkWuCKETEEURDRErUOTMnWPnKISH8F7UkeZEhKj/FYV62pGi8KlR8nrcwniU0NMRZDCiNLBOTmXlkT3xeFLpAnAFGyBU7hquHKXR6k978RwACobWCpfgBk27QU8o8sl0C24qt/ZjAG7n/4177qCPT+XIZjaoW6hCTCQU8FbowkUBUVwRpRpVb4UN+ZGjyqPSh/MhuOmvDkAqoixKXhu9ghCQ1VBdGQ0MV9HGc8jJaNlAnkP5sbIYptT8689G1qjgVQImEAjEM05Iem5rHTjCj2/7jTx0whfxnKcMAiO7jE3uOpxEnqxkRqcowYfXy7awtUQYm5OVZwHgakxJnJDb+SqUUX1fvIbypyOFbQTshwmY/QSUyQXMp6cP1bOGvgiOVB/GsCvI6KfI6Iesnv6n/o8A5mIXKo6p86aCYIX4lVlfOXIjFGf0Fr59KGXul25ovl3dvFt91YtfsbA6r/MZqFWRnws0ExxaBQDU4JaZbDxppVxTix7QcZqE9enotvelSG0MGNHTpQ66L8qzKg874k1wHwNWuumPgJXBc+eZbZutpZugVdyT3LuFm/LF4YvkSaUoRzh63zNTuLxYqRj5G0EKd1xH7b1poo6LyKCo99Pz62i4vyes+PGNFucKMRMZZCKSksGfkwTMzqaT+QEHtkwPL+1/jCLJtASVxf
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x720 with 9 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig = plt.figure(figsize=(10, 10))\n",
|
||
|
"rows = 3\n",
|
||
|
"cols = 3\n",
|
||
|
"axes=[]\n",
|
||
|
"for a in range(rows*cols):\n",
|
||
|
" idx = random.randint(a, len(train_dataset['x'])-1)\n",
|
||
|
" b = train_dataset['x'][idx]\n",
|
||
|
" axes.append(fig.add_subplot(rows, cols, a+1) )\n",
|
||
|
" subplot_title=(\"Photo\"+str(a))\n",
|
||
|
" axes[-1].set_title(train_dataset['y'][idx]) \n",
|
||
|
" plt.imshow(b)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 36,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def train_model(model, criterion, optimizer, scheduler, dataset, device, num_epochs=25,\n",
|
||
|
" return_history=False, log_history=True, working_dir='output'):\n",
|
||
|
" since = time.time()\n",
|
||
|
" best_acc = 0.0\n",
|
||
|
" history = {'epochs': [], 'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': []}\n",
|
||
|
" \n",
|
||
|
" for epochs in range(num_epochs):\n",
|
||
|
" print(f\"Epoch{epochs}/{num_epochs -1}\")\n",
|
||
|
" print('-'*10)\n",
|
||
|
" \n",
|
||
|
" for phase in ['train', 'test']:\n",
|
||
|
" if phase == 'train':\n",
|
||
|
" model.train()\n",
|
||
|
" else:\n",
|
||
|
" model.eval()\n",
|
||
|
" \n",
|
||
|
" running_loss = 0.0\n",
|
||
|
" running_corrects = 0\n",
|
||
|
" \n",
|
||
|
" # Iterate over data.\n",
|
||
|
" for inputs, labels in dataloaders[phase]:\n",
|
||
|
" inputs = inputs.to(device)\n",
|
||
|
" labels = labels.to(device)\n",
|
||
|
"\n",
|
||
|
" # zero the parameter gradients\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
"\n",
|
||
|
" # forward\n",
|
||
|
" # track history if only in train\n",
|
||
|
" with torch.set_grad_enabled(phase == 'train'):\n",
|
||
|
" outputs = model(inputs)\n",
|
||
|
" _, preds = torch.max(outputs, 1)\n",
|
||
|
" loss = criterion(outputs, labels)\n",
|
||
|
"\n",
|
||
|
" # backward + optimize only if in training phase\n",
|
||
|
" if phase == 'train':\n",
|
||
|
" loss.backward()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" \n",
|
||
|
" if phase == 'train':\n",
|
||
|
" scheduler.step()\n",
|
||
|
" \n",
|
||
|
" print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
|
||
|
" phase, epoch_loss, epoch_acc))\n",
|
||
|
" history['epoch'].append(epoch)\n",
|
||
|
" history[phase+'_loss'].append(epoch_loss)\n",
|
||
|
" history[phase+'_acc'].append(epoch_acc)\n",
|
||
|
"\n",
|
||
|
" time_elapsed = time.time() - since\n",
|
||
|
" print('Training complete in {:.0f}m {:.0f}s'.format(\n",
|
||
|
" time_elapsed // 60, time_elapsed % 60))\n",
|
||
|
" print('Best val Acc: {:4f}'.format(best_acc))\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" if return_history:\n",
|
||
|
" return model, history\n",
|
||
|
" else:\n",
|
||
|
" return model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch0/39\n",
|
||
|
"----------\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"ename": "AttributeError",
|
||
|
"evalue": "'list' object has no attribute 'to'",
|
||
|
"output_type": "error",
|
||
|
"traceback": [
|
||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
|
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||
|
"\u001b[0;32m<ipython-input-37-c71c5d0ad91e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'test'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mval_dataset\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mhistory\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_ft\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_ft\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_ft\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexp_lr_scheduler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m40\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||
|
"\u001b[0;32m<ipython-input-36-a0d40fd12a57>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, criterion, optimizer, scheduler, dataset, device, num_epochs, return_history, log_history, working_dir)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'x'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'to'"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"dataset = {'train': train_dataset, 'test': val_dataset}\n",
|
||
|
"history, model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dataloaders, device, num_epochs=40)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "tb-venv",
|
||
|
"language": "python",
|
||
|
"name": "tb-venv"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|