diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3abc738 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +records\ +vo\ diff --git a/.ipynb_checkpoints/eval-checkpoint.py b/.ipynb_checkpoints/eval-checkpoint.py new file mode 100644 index 0000000..7896f93 --- /dev/null +++ b/.ipynb_checkpoints/eval-checkpoint.py @@ -0,0 +1,410 @@ +import numpy as np +import torch +import torch.nn.functional as F +from typing import Union +import pandas as pd +import os +from sklearn.metrics import confusion_matrix +import matplotlib.pyplot as plt + + +@torch.no_grad() +def cal_train_metrics(args, msg: dict, outs: dict, labels: torch.Tensor, batch_size: int): + """ + only present top-1 training accuracy + """ + + total_loss = 0.0 + + if args.use_fpn: + for i in range(1, 5): + acc = top_k_corrects(outs["layer"+str(i)].mean(1), labels, tops=[1])["top-1"] / batch_size + acc = round(acc * 100, 2) + msg["train_acc/layer{}_acc".format(i)] = acc + loss = F.cross_entropy(outs["layer"+str(i)].mean(1), labels) + msg["train_loss/layer{}_loss".format(i)] = loss.item() + total_loss += loss.item() + + if args.use_selection: + for name in outs: + if "select_" not in name: + continue + B, S, _ = outs[name].size() + logit = outs[name].view(-1, args.num_classes) + labels_0 = labels.unsqueeze(1).repeat(1, S).flatten(0) + acc = top_k_corrects(logit, labels_0, tops=[1])["top-1"] / (B*S) + acc = round(acc * 100, 2) + msg["train_acc/{}_acc".format(name)] = acc + labels_0 = torch.zeros([B * S, args.num_classes]) - 1 + labels_0 = labels_0.to(args.device) + loss = F.mse_loss(F.tanh(logit), labels_0) + msg["train_loss/{}_loss".format(name)] = loss.item() + total_loss += loss.item() + + for name in outs: + if "drop_" not in name: + continue + B, S, _ = outs[name].size() + logit = outs[name].view(-1, args.num_classes) + labels_1 = labels.unsqueeze(1).repeat(1, S).flatten(0) + acc = top_k_corrects(logit, labels_1, tops=[1])["top-1"] / (B*S) + acc = round(acc * 100, 2) + msg["train_acc/{}_acc".format(name)] = acc + loss = F.cross_entropy(logit, labels_1) + msg["train_loss/{}_loss".format(name)] = loss.item() + total_loss += loss.item() + + if args.use_combiner: + acc = top_k_corrects(outs['comb_outs'], labels, tops=[1])["top-1"] / batch_size + acc = round(acc * 100, 2) + msg["train_acc/combiner_acc"] = acc + loss = F.cross_entropy(outs['comb_outs'], labels) + msg["train_loss/combiner_loss"] = loss.item() + total_loss += loss.item() + + if "ori_out" in outs: + acc = top_k_corrects(outs["ori_out"], labels, tops=[1])["top-1"] / batch_size + acc = round(acc * 100, 2) + msg["train_acc/ori_acc"] = acc + loss = F.cross_entropy(outs["ori_out"], labels) + msg["train_loss/ori_loss"] = loss.item() + total_loss += loss.item() + + msg["train_loss/total_loss"] = total_loss + + + +@torch.no_grad() +def top_k_corrects(preds: torch.Tensor, labels: torch.Tensor, tops: list = [1, 3, 5]): + """ + preds: [B, C] (C is num_classes) + labels: [B, ] + """ + if preds.device != torch.device('cpu'): + preds = preds.cpu() + if labels.device != torch.device('cpu'): + labels = labels.cpu() + tmp_cor = 0 + corrects = {"top-"+str(x):0 for x in tops} + sorted_preds = torch.sort(preds, dim=-1, descending=True)[1] + for i in range(tops[-1]): + tmp_cor += sorted_preds[:, i].eq(labels).sum().item() + # records + if "top-"+str(i+1) in corrects: + corrects["top-"+str(i+1)] = tmp_cor + return corrects + + +@torch.no_grad() +def _cal_evalute_metric(corrects: dict, + total_samples: dict, + logits: torch.Tensor, + labels: torch.Tensor, + this_name: str, + 
scores: Union[list, None] = None,
+                         score_names: Union[list, None] = None):
+
+    tmp_score = torch.softmax(logits, dim=-1)
+    tmp_corrects = top_k_corrects(tmp_score, labels, tops=[1, 3]) # return top-1 and top-3 correct counts (matches tops=[1, 3])
+
+    ### each layer's top-1, top-3 accuracy
+    for name in tmp_corrects:
+        eval_name = this_name + "-" + name
+        if eval_name not in corrects:
+            corrects[eval_name] = 0
+            total_samples[eval_name] = 0
+        corrects[eval_name] += tmp_corrects[name]
+        total_samples[eval_name] += labels.size(0)
+
+    if scores is not None:
+        scores.append(tmp_score)
+    if score_names is not None:
+        score_names.append(this_name)
+
+
+@torch.no_grad()
+def _average_top_k_result(corrects: dict, total_samples: dict, scores: list, labels: torch.Tensor, tops: list = [1, 2, 3, 4, 5]):
+    """
+    scores is a list containing:
+    [
+        tensor1,
+        tensor2,...
+    ] where tensor1 and tensor2 have the same size [B, num_classes]
+    """
+    # initial
+    for t in tops:
+        eval_name = "highest-{}".format(t)
+        if eval_name not in corrects:
+            corrects[eval_name] = 0
+            total_samples[eval_name] = 0
+        total_samples[eval_name] += labels.size(0)
+
+    if labels.device != torch.device('cpu'):
+        labels = labels.cpu()
+
+    batch_size = labels.size(0)
+    scores_t = torch.cat([s.unsqueeze(1) for s in scores], dim=1) # B, S, C
+
+    if scores_t.device != torch.device('cpu'):
+        scores_t = scores_t.cpu()
+
+    max_scores = torch.max(scores_t, dim=-1)[0]
+    # sorted_ids = torch.sort(max_scores, dim=-1, descending=True)[1] # this id represents different layers outputs, not samples
+
+    for b in range(batch_size):
+        tmp_logit = None
+        # sort descending so the most confident outputs are accumulated first,
+        # matching the "highest-k" semantics (the earlier ascending sort would
+        # have averaged the least confident outputs first)
+        ids = torch.sort(max_scores[b], dim=-1, descending=True)[1] # S
+        for i in range(tops[-1]):
+            top_i_id = ids[i]
+            if tmp_logit is None:
+                tmp_logit = scores_t[b][top_i_id]
+            else:
+                tmp_logit += scores_t[b][top_i_id]
+            # record results
+            if i+1 in tops:
+                if torch.max(tmp_logit, dim=-1)[1] == labels[b]:
+                    eval_name = "highest-{}".format(i+1)
+                    corrects[eval_name] += 1
+
+
+def evaluate(args, model, test_loader):
+    """
+    [Notice: Custom Model]
+    If you use a custom model, please change the fpn module return name (under
+    if args.use_fpn: ...)
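+
+    [Worked example: average-highest-k]
+    Illustrative numbers only: suppose one sample's score vectors (layer1..layer4,
+    combiner) have max confidences (0.9, 0.4, 0.7, 0.6, 0.5). The vectors are
+    ranked by confidence, so "highest-1" uses the 0.9 vector alone, "highest-2"
+    the sum of the 0.9 and 0.7 vectors, and so on; the sample counts as correct
+    when the argmax of the running sum equals its label.
+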
+    [Evaluation Metrics]
+    We calculate each layer's accuracy, the combiner's accuracy, and the
+    average-highest-1 ~ average-highest-5 accuracy (average-highest-5 averages
+    all prediction scores to form the final prediction).
+    """
+
+    model.eval()
+    corrects = {}
+    total_samples = {}
+
+    total_batchs = len(test_loader) # just for log
+    show_progress = [x/10 for x in range(11)] # just for log
+    progress_i = 0
+
+    with torch.no_grad():
+        """ accumulate """
+        for batch_id, (ids, datas, labels) in enumerate(test_loader):
+
+            score_names = []
+            scores = []
+            datas = datas.to(args.device)
+
+            outs = model(datas)
+
+            if args.use_fpn:
+                for i in range(1, 5):
+                    this_name = "layer" + str(i)
+                    _cal_evalute_metric(corrects, total_samples, outs[this_name].mean(1), labels, this_name, scores, score_names)
+
+            ### for research
+            if args.use_selection:
+                for name in outs:
+                    if "select_" not in name:
+                        continue
+                    this_name = name
+                    S = outs[name].size(1)
+                    logit = outs[name].view(-1, args.num_classes)
+                    labels_1 = labels.unsqueeze(1).repeat(1, S).flatten(0)
+                    _cal_evalute_metric(corrects, total_samples, logit, labels_1, this_name)
+
+                for name in outs:
+                    if "drop_" not in name:
+                        continue
+                    this_name = name
+                    S = outs[name].size(1)
+                    logit = outs[name].view(-1, args.num_classes)
+                    labels_0 = labels.unsqueeze(1).repeat(1, S).flatten(0)
+                    _cal_evalute_metric(corrects, total_samples, logit, labels_0, this_name)
+
+            if args.use_combiner:
+                this_name = "combiner"
+                _cal_evalute_metric(corrects, total_samples, outs["comb_outs"], labels, this_name, scores, score_names)
+
+            if "ori_out" in outs:
+                this_name = "original"
+                _cal_evalute_metric(corrects, total_samples, outs["ori_out"], labels, this_name)
+
+            _average_top_k_result(corrects, total_samples, scores, labels)
+
+            eval_progress = (batch_id + 1) / total_batchs
+
+            if eval_progress > show_progress[progress_i]:
+                print(".."+str(int(show_progress[progress_i]*100))+"%", end='', flush=True)
+                progress_i += 1
+
+    """ calculate accuracy """
+    # total_samples = len(test_loader.dataset)
+
+    best_top1 = 0.0
+    best_top1_name = ""
+    eval_acces = {}
+    for name in corrects:
+        acc = corrects[name] / total_samples[name]
+        acc = round(100 * acc, 3)
+        eval_acces[name] = acc
+        ### only compare top-1 accuracy
+        if "top-1" in name or "highest" in name:
+            if acc >= best_top1:
+                best_top1 = acc
+                best_top1_name = name
+
+    return best_top1, best_top1_name, eval_acces
+
+
+def evaluate_cm(args, model, test_loader):
+    """
+    [Notice: Custom Model]
+    If you use a custom model, please change the fpn module return name (under
+    if args.use_fpn: ...)
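+
+    [Output files]
+    Besides the returned accuracies, this function writes one xlsx row per
+    image, e.g. ("birds/0001.jpg", 3, 3, 0.91) for the columns
+    (id, original_label, predict_label, goal) — the example row is
+    illustrative only — and renders a confusion-matrix image via
+    plot_confusion_matrix().
+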
+    [Evaluation Metrics]
+    We calculate each layer's accuracy, the combiner's accuracy, and the
+    average-highest-1 ~ average-highest-5 accuracy (average-highest-5 averages
+    all prediction scores to form the final prediction).
+    """
+
+    model.eval()
+    corrects = {}
+    total_samples = {}
+    results = []
+
+    with torch.no_grad():
+        """ accumulate """
+        for batch_id, (ids, datas, labels) in enumerate(test_loader):
+
+            score_names = []
+            scores = []
+            datas = datas.to(args.device)
+            outs = model(datas)
+
+            # if args.use_fpn and (0 < args.highest < 5):
+            #     this_name = "layer" + str(args.highest)
+            #     _cal_evalute_metric(corrects, total_samples, outs[this_name].mean(1), labels, this_name, scores, score_names)
+
+            if args.use_combiner:
+                this_name = "combiner"
+                _cal_evalute_metric(corrects, total_samples, outs["comb_outs"], labels, this_name, scores, score_names)
+
+            # _average_top_k_result(corrects, total_samples, scores, labels)
+
+            for i in range(scores[0].shape[0]):
+                results.append([test_loader.dataset.data_infos[ids[i].item()]['path'], int(labels[i].item()),
+                                int(scores[0][i].argmax().item()),
+                                scores[0][i][scores[0][i].argmax().item()].item()])  # image path, ground-truth label, predicted label, score
+
+    """ write xlsx """
+    writer = pd.ExcelWriter(args.save_dir + 'infer_result.xlsx')
+    df = pd.DataFrame(results, columns=["id", "original_label", "predict_label", "goal"])
+    df.to_excel(writer, index=False, sheet_name="Sheet1")
+    writer.close()  # close() saves the file; ExcelWriter.save() was removed in pandas 2.0
+
+    """ calculate accuracy """
+
+    best_top1 = 0.0
+    best_top1_name = ""
+    eval_acces = {}
+    for name in corrects:
+        acc = corrects[name] / total_samples[name]
+        acc = round(100 * acc, 3)
+        eval_acces[name] = acc
+        ### only compare top-1 accuracy
+        if "top-1" in name or "highest" in name:
+            if acc > best_top1:
+                best_top1 = acc
+                best_top1_name = name
+
+    """ confusion matrix """
+    results_mat = np.mat(results)
+    y_actual = results_mat[:, 1].transpose().tolist()[0]
+    y_actual = list(map(int, y_actual))
+    y_predict = results_mat[:, 2].transpose().tolist()[0]
+    y_predict = list(map(int, y_predict))
+
+    folders = os.listdir(args.val_root)
+    folders.sort()  # sort by alphabet
+    print("[dataset] class:", folders)
+    df_confusion = confusion_matrix(y_actual, y_predict)
+    plot_confusion_matrix(df_confusion, folders, args.save_dir + "infer_cm.png", accuracy=best_top1)
+
+    return best_top1, best_top1_name, eval_acces
+
+
+@torch.no_grad()
+def eval_and_save(args, model, val_loader, tlogger):
+    tlogger.print("Start Evaluating")
+    acc, eval_name, eval_acces = evaluate(args, model, val_loader)
+    tlogger.print("....BEST_ACC: {} {}%".format(eval_name, acc))
+    ### build records.txt
+    msg = "[Evaluation Results]\n"
+    msg += "Project: {}, Experiment: {}\n".format(args.project_name, args.exp_name)
+    msg += "Samples: {}\n".format(len(val_loader.dataset))
+    msg += "\n"
+    for name in eval_acces:
+        msg += "    {} {}%\n".format(name, eval_acces[name])
+    msg += "\n"
+    msg += "BEST_ACC: {} {}% ".format(eval_name, acc)
+
+    with open(args.save_dir + "eval_results.txt", "w") as ftxt:
+        ftxt.write(msg)
+
+
+@torch.no_grad()
+def eval_and_cm(args, model, val_loader, tlogger):
+    tlogger.print("Start Evaluating")
+    acc, eval_name, eval_acces = evaluate_cm(args, model, val_loader)
+    tlogger.print("....BEST_ACC: {} {}%".format(eval_name, acc))
+    ### build records.txt
+    msg = "[Evaluation Results]\n"
+    msg += "Project: {}, Experiment: {}\n".format(args.project_name, args.exp_name)
+    msg += "Samples: {}\n".format(len(val_loader.dataset))
+    msg += "\n"
+    for name in eval_acces:
+        msg += "    {} {}%\n".format(name, eval_acces[name])
+    msg += "\n"
+    msg += "BEST_ACC: {} {}% ".format(eval_name, acc)
+
+    with open(args.save_dir + "infer_results.txt", "w") as ftxt:
+        ftxt.write(msg)
+
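A minimal usage sketch for plot_confusion_matrix() as defined just below (made-up values; numpy and matplotlib are already imported at the top of this file):

    import numpy as np
    demo_cm = np.array([[8, 2],
                        [1, 9]])   # rows: actual class, columns: predicted class
    plot_confusion_matrix(demo_cm, ["class_a", "class_b"], "demo_cm.png", accuracy=85.0)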
+
+def plot_confusion_matrix(cm, label_names, save_name, title='Confusion Matrix acc = ', accuracy=0):
+    plt.rcParams['font.sans-serif'] = ['SimHei']
+    plt.figure(figsize=(len(label_names) / 2, len(label_names) / 2), dpi=100)
+    np.set_printoptions(precision=2)
+    # print("cm:\n",cm)
+
+    # per-cell percentage, normalized over each true-label row
+    x, y = np.meshgrid(np.arange(len(cm)), np.arange(len(cm)))
+    row_sums = np.sum(cm, axis=1)
+    for x_val, y_val in zip(x.flatten(), y.flatten()):
+        # guard against empty rows explicitly: numpy indexing/division never
+        # raises KeyError (the old try/except was dead code), it just yields nan
+        c = (cm[y_val][x_val] / row_sums[y_val]) * 100 if row_sums[y_val] > 0 else 0
+        if c > 0.001:
+            plt.text(x_val, y_val, "%0.1f" % (c,), color='red', fontsize=15, va='center', ha='center')
+
+    plt.imshow(cm, interpolation='nearest', cmap=plt.get_cmap('Blues'))
+    plt.title(title + str('{:.3f}'.format(accuracy)))
+    plt.colorbar()
+    plt.xticks(np.arange(len(label_names)), label_names, rotation=45)
+    plt.yticks(np.arange(len(label_names)), label_names)
+    plt.ylabel('Actual label')
+    plt.xlabel('Predict label')
+
+    # offset the tick
+    tick_marks = np.array(range(len(label_names))) + 0.5
+    plt.gca().set_xticks(tick_marks, minor=True)
+    plt.gca().set_yticks(tick_marks, minor=True)
+    plt.gca().xaxis.set_ticks_position('none')
+    plt.gca().yaxis.set_ticks_position('none')
+    plt.grid(True, which='minor', linestyle='-')
+    plt.gcf().subplots_adjust(bottom=0.15)
+
+    # save confusion matrix
+    plt.savefig(save_name, format='png')
+    # plt.show()
diff --git a/.ipynb_checkpoints/how_to_build_pim_model-checkpoint.ipynb b/.ipynb_checkpoints/how_to_build_pim_model-checkpoint.ipynb
new file mode 100644
index 0000000..2aa6850
--- /dev/null
+++ b/.ipynb_checkpoints/how_to_build_pim_model-checkpoint.ipynb
@@ -0,0 +1,450 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "7dbd109f",
+   "metadata": {},
+   "source": [
+    "### Import packages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "cefb7412",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torchvision.models.feature_extraction import get_graph_node_names\n",
+    "\n",
+    "from models.pim_module.pim_module import PluginMoodel"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "27338b6f",
+   "metadata": {},
+   "source": [
+    "### custom model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "136f4006",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Model(nn.Module):\n",
+    "    \n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Sequential(\n",
+    "            nn.Conv2d(3, 64, 3, padding=1),\n",
+    "            nn.BatchNorm2d(64),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Conv2d(64, 64, 3, stride=2, padding=1),\n",
+    "            nn.BatchNorm2d(64),\n",
+    "            nn.ReLU(),\n",
+    "        )\n",
+    "        self.conv2 = nn.Sequential(\n",
+    "            nn.Conv2d(64, 128, 3, padding=1),\n",
+    "            nn.BatchNorm2d(128),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Conv2d(128, 128, 3, stride=2, padding=1),\n",
+    "            nn.BatchNorm2d(128),\n",
+    "            nn.ReLU()\n",
+    "        )\n",
+    "        self.pool = nn.AdaptiveAvgPool2d((1, 1))\n",
+    "        self.classifier = nn.Linear(128, 10)\n",
+    "    \n",
+    "    def forward(self, x):\n",
+    "        x = self.conv1(x)\n",
+    "        x = self.conv2(x)\n",
+    "        x = self.pool(x)\n",
+    "        x = x.flatten(1)\n",
+    "        x = self.classifier(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "efcd0eb6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = Model()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0d27d9ca",
+   "metadata": {},
+   "source": [
+    "### 
get model name" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "53d7ff7c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model(\n", + " (conv1): Sequential(\n", + " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU()\n", + " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU()\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU()\n", + " (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU()\n", + " )\n", + " (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n", + " (classifier): Linear(in_features=128, out_features=10, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "print(model) ### structure" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "de5d91fa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(['x', 'conv1.0', 'conv1.1', 'conv1.2', 'conv1.3', 'conv1.4', 'conv1.5', 'conv2.0', 'conv2.1', 'conv2.2', 'conv2.3', 'conv2.4', 'conv2.5', 'pool', 'flatten', 'classifier'], ['x', 'conv1.0', 'conv1.1', 'conv1.2', 'conv1.3', 'conv1.4', 'conv1.5', 'conv2.0', 'conv2.1', 'conv2.2', 'conv2.3', 'conv2.4', 'conv2.5', 'pool', 'flatten', 'classifier'])\n" + ] + } + ], + "source": [ + "print(get_graph_node_names(model))" + ] + }, + { + "cell_type": "markdown", + "id": "7b3d4bb1", + "metadata": {}, + "source": [ + "### prepare material to build PluginMoodel" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4b14b3ff", + "metadata": {}, + "outputs": [], + "source": [ + "# if we want conv1 output and conv2 output\n", + "return_nodes = {\n", + " \"conv1.5\":\"layer1\",\n", + " \"conv2.5\":\"layer2\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b169b258", + "metadata": {}, + "outputs": [], + "source": [ + "# notice that 'layer1' and 'layer2' must match return_nodes's value\n", + "num_selects = {\n", + " \"layer1\":64, \n", + " \"layer2\":64\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d3a868a3", + "metadata": {}, + "outputs": [], + "source": [ + "IMG_SIZE = 224\n", + "USE_FPN = True\n", + "FPN_SIZE = 128 # fpn projection size, if do not use fpn, you can set fpn_size to None" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "3b71c25b", + "metadata": {}, + "outputs": [], + "source": [ + "# proj_type : you can choose 'Conv' or 'Linear', 'Conv' is design for 4d image input (resnet, efficientnet, vgg...),\n", + "# 'Linear' is for 3d image input (Vit, Swin-T...)\n", + "PROJ_TYPE = \"Conv\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "054a3c5b", + "metadata": {}, + "outputs": [], + "source": [ + "# upsample_type : [\"Bilinear\", \"Conv\", \"Fc\"]\n", + "# for convolution neural network (e.g. ResNet, EfficientNet), recommand 'Bilinear'. \n", + "# for Vit, \"Fc\". 
and Swin-T, \"Conv\"\n", + "UPSAMPLE_TYPE = \"Bilinear\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8313f994", + "metadata": {}, + "outputs": [], + "source": [ + "pim_model = \\\n", + "PluginMoodel(backbone = model,\n", + " return_nodes = return_nodes,\n", + " img_size = IMG_SIZE,\n", + " use_fpn = USE_FPN,\n", + " fpn_size = FPN_SIZE,\n", + " proj_type = PROJ_TYPE,\n", + " upsample_type = UPSAMPLE_TYPE,\n", + " use_selection = True,\n", + " num_classes = 10,\n", + " num_selects = num_selects, \n", + " use_combiner = True,\n", + " comb_proj_size = None)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b93f9062", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Chou\\Anaconda3\\envs\\pytorch\\lib\\site-packages\\torch\\nn\\functional.py:3631: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "rand_inp = torch.randn(1, 3, 224, 224)\n", + "outs = pim_model(rand_inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "173f0049", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['layer1', 'layer2', 'preds_1', 'preds_0', 'comb_outs']\n" + ] + } + ], + "source": [ + "print([name for name in outs])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "7f48774e", + "metadata": {}, + "outputs": [], + "source": [ + "# 'layer1' : logits of 'layer1' , size [B, num_classes]\n", + "# 'layer2' : logits of 'layer2' , size [B, num_classes]\n", + "# 'preds_1'(dict) : logits of selected region, size [B, num_classes]\n", + "# 'preds_0'(dict) : logits of NOT selected region, size [B, num_classes]\n", + "# 'comb_outs' : logits of Combiner , size [B, num_classes]" + ] + }, + { + "cell_type": "markdown", + "id": "5678b51d", + "metadata": {}, + "source": [ + "### some error raise while get_graph_node_names() or create_feature_extractor()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "31ff6bbe", + "metadata": {}, + "outputs": [], + "source": [ + "### change model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1cd7f87e", + "metadata": {}, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " \n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Sequential(\n", + " nn.Conv2d(3, 64, 3, padding=1),\n", + " nn.BatchNorm2d(64),\n", + " nn.ReLU(),\n", + " nn.Conv2d(64, 64, 3, stride=2, padding=1),\n", + " nn.BatchNorm2d(64),\n", + " nn.ReLU(),\n", + " )\n", + " self.conv2 = nn.Sequential(\n", + " nn.Conv2d(64, 128, 3, padding=1),\n", + " nn.BatchNorm2d(128),\n", + " nn.ReLU(),\n", + " nn.Conv2d(128, 128, 3, stride=2, padding=1),\n", + " nn.BatchNorm2d(128),\n", + " nn.ReLU()\n", + " )\n", + " self.pool = nn.AdaptiveAvgPool2d((1, 1))\n", + " self.classifier = nn.Linear(128, 10)\n", + " \n", + " def forward(self, x):\n", + " x1 = self.conv1(x)\n", + " x2 = self.conv2(x1)\n", + " x = self.pool(x2)\n", + " x = x.flatten(1)\n", + " x = self.classifier(x)\n", + " return {\"layer1\":x1, \"layer2\":x2}" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "4bce0b80", + "metadata": {}, + "outputs": [], + "source": [ + "model = Model()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9c233716", 
+ "metadata": {}, + "outputs": [], + "source": [ + "## set return_nodes to None" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "89c7ea39", + "metadata": {}, + "outputs": [], + "source": [ + "pim_model = \\\n", + "PluginMoodel(backbone = model,\n", + " return_nodes = None,\n", + " img_size = IMG_SIZE,\n", + " use_fpn = USE_FPN,\n", + " fpn_size = FPN_SIZE,\n", + " proj_type = PROJ_TYPE,\n", + " upsample_type = UPSAMPLE_TYPE,\n", + " use_selection = True,\n", + " num_classes = 10,\n", + " num_selects = num_selects, \n", + " use_combiner = True,\n", + " comb_proj_size = None)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "95fd924c", + "metadata": {}, + "outputs": [], + "source": [ + "rand_inp = torch.randn(1, 3, 224, 224)\n", + "outs = pim_model(rand_inp)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "2ec4ac34", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['layer1', 'layer2', 'preds_1', 'preds_0', 'comb_outs']\n" + ] + } + ], + "source": [ + "print([name for name in outs])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2bc12ee", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/.ipynb_checkpoints/infer-checkpoint.py b/.ipynb_checkpoints/infer-checkpoint.py new file mode 100644 index 0000000..282c295 --- /dev/null +++ b/.ipynb_checkpoints/infer-checkpoint.py @@ -0,0 +1,85 @@ +""" +infer version1.0 +2022.06.07 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import load_yaml, build_record_folder, get_args + +warnings.simplefilter("ignore") + + +def set_environment(args, tlogger): + print("Setting Environment...") + + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### = = = = Dataset and Data Loader = = = = + tlogger.print("Building Dataloader....") + + _, val_loader = build_loader(args) + + print("[Only Evaluation]") + + tlogger.print() + + ### = = = = Model = = = = + tlogger.print("Building Model....") + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner, + ) # about return_nodes, we use our default setting + print(model) + + checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model']) + start_epoch = checkpoint['epoch'] + + + # model = torch.nn.DataParallel(model, device_ids=None) # device_ids : None --> use all gpus. + model.to(args.device) + tlogger.print() + + """ + if you have multi-gpu device, you can use torch.nn.DataParallel in single-machine multi-GPU + situation and use torch.nn.parallel.DistributedDataParallel to use multi-process parallelism. 
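+    (illustrative sketch, not active in this file:
+        model = torch.nn.DataParallel(model)
+    wraps the model so each batch is split across all visible GPUs)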
+    more detail: https://pytorch.org/tutorials/beginner/dist_overview.html
+    """
+
+    return val_loader, model
+
+
+def main_test(args, tlogger):
+    """
+    infer and confusion matrix
+    """
+
+    val_loader, model = set_environment(args, tlogger)
+    from eval import eval_and_cm
+    eval_and_cm(args, model, val_loader, tlogger)
+
+
+if __name__ == "__main__":
+    tlogger = timeLogger()
+
+    tlogger.print("Reading Config...")
+    args = get_args()
+    assert args.c != "", "Please provide config file (.yaml)"
+    load_yaml(args, args.c)
+    build_record_folder(args)
+    tlogger.print()
+
+    main_test(args, tlogger)
\ No newline at end of file
diff --git a/.ipynb_checkpoints/infer_test-checkpoint.py b/.ipynb_checkpoints/infer_test-checkpoint.py
new file mode 100644
index 0000000..08166a9
--- /dev/null
+++ b/.ipynb_checkpoints/infer_test-checkpoint.py
@@ -0,0 +1,218 @@
+import os
+import torch
+import pandas as pd
+import torchvision.transforms as transforms
+from PIL import Image, ImageFile
+import warnings
+import contextlib
+from argparse import Namespace
+
+# -------------------------- 1. Import the project's core modules (make sure the paths are correct) --------------------------
+# If this script is not in the project root, add the project path to the Python environment (e.g. sys.path.append("../"))
+from models.builder import MODEL_GETTER
+from utils.config_utils import load_yaml, get_args
+from utils.costom_logger import timeLogger
+
+# -------------------------- 2. Utility: robust image loader (handles corrupted/truncated images) --------------------------
+warnings.filterwarnings("ignore", category=UserWarning, module="PIL.Image")
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+_CORRUPTED_WARNED = set()
+
+
+def robust_pil_loader(path, data_size):
+    """Robust image loader: handles corrupted images and returns an RGB PIL image"""
+    try:
+        with Image.open(path) as img:
+            if img.mode == 'P':
+                img = img.convert('RGBA')
+            if img.mode != 'RGB':
+                img = img.convert('RGB')
+            img.load()
+            return img.copy()
+    except Exception as e:
+        if path not in _CORRUPTED_WARNED:
+            print(f"Warning: Skipping corrupted image {path}. Error: {e}")
+            _CORRUPTED_WARNED.add(path)
+        # for corrupted images, return a gray placeholder (same size as the model input)
+        return Image.new('RGB', (data_size, data_size), color=(128, 128, 128))
+
+
+# -------------------------- 3. Test-set Dataset (for a single folder of unlabeled images) --------------------------
+class TestDataset(torch.utils.data.Dataset):
+    def __init__(self, test_root: str, data_size: int):
+        self.test_root = test_root
+        self.data_size = data_size
+        # test-time transforms (same as validation: no random ops, so results are stable)
+        self.transform = transforms.Compose([
+            transforms.Resize((510, 510), Image.BILINEAR),  # same Resize as training
+            transforms.CenterCrop((data_size, data_size)),  # fixed center crop
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization (same as training)
+                                 std=[0.229, 0.224, 0.225])
+        ])
+        # collect the path and filename of every test image ("id" is the filename)
+        self.image_infos = self._get_image_infos()
+
+    def _get_image_infos(self):
+        """Walk the test folder and collect image paths and filenames"""
+        image_infos = []
+        for filename in os.listdir(self.test_root):
+            file_path = os.path.join(self.test_root, filename)
+            if os.path.isfile(file_path):  # keep files only (skip subfolders)
+                # "id" is the filename (e.g. "test_img_001.jpg"), written to the CSV later
+                image_infos.append({"path": file_path, "id": filename})
+        print(f"[Test Dataset] Total images loaded: {len(image_infos)}")
+        return image_infos
+
+    def __len__(self):
+        return len(self.image_infos)
+
+    def __getitem__(self, index):
+        info = self.image_infos[index]
+        img_path = info["path"]
+        img_id = info["id"]  # the CSV "id" column (filename)
+
+        # load the image and apply the transforms
+        img = robust_pil_loader(img_path, self.data_size)
+        img_tensor = self.transform(img)
+
+        return img_tensor, img_id  # returns: image tensor, image id (filename)
+
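A minimal usage sketch for TestDataset (hypothetical path and size; data_size should match the model input set in the config):

    dataset = TestDataset(test_root="./data/test", data_size=384)   # hypothetical values
    img_tensor, img_id = dataset[0]
    print(img_tensor.shape, img_id)   # e.g. torch.Size([3, 384, 384]) some_image.jpg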
+# -------------------------- 4. Core inference functions --------------------------
+def load_best_model(args, device):
+    """Load the best.pt model saved during training"""
+    # 1. build the same model structure as training
+    model = MODEL_GETTER[args.model_name](
+        use_fpn=args.use_fpn,
+        fpn_size=args.fpn_size,
+        use_selection=args.use_selection,
+        num_classes=args.num_classes,
+        num_selects=args.num_selects,
+        use_combiner=args.use_combiner
+    )
+    # 2. load the best.pt weights (path: save_dir/backup/best.pt)
+    best_model_path = os.path.join(args.save_dir, "backup", "best.pt")
+#     best_model_path = os.path.join(args.save_dir, "backup", "last.pt")
+    if not os.path.exists(best_model_path):
+        raise FileNotFoundError(f"Best model not found! Path: {best_model_path}")
+
+    checkpoint = torch.load(best_model_path, map_location=device)
+    model.load_state_dict(checkpoint["model"])  # load model weights
+    model.to(device)
+    model.eval()  # eval mode (disables Dropout, freezes BatchNorm statistics)
+    print(f"Successfully loaded best model from: {best_model_path}")
+    return model
+
+
+def get_class_names(train_root):
+    """Get the class-name list (aligned with the training class ids, so predicted names are correct)"""
+    # training class ids were generated from sorted folder names; keep the same order here
+    class_names = sorted(os.listdir(train_root))
+    print(f"[Class Info] Total classes: {len(class_names)}")
+    return class_names
+
+
+def infer_test_set(args, device, tlogger):
+    """Run batched inference on the test set and save the results to CSV"""
+    # 1. load the test set
+    tlogger.print("Loading test dataset...")
+    test_dataset = TestDataset(
+        test_root=args.test_root,
+        data_size=args.data_size
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=args.batch_size,  # batch size from the config (fit GPU memory)
+        shuffle=False,  # no shuffling for the test set
+        num_workers=args.num_workers,  # worker count from the config
+        pin_memory=True  # speeds up host-to-GPU transfer
+    )
+
+    # 2. load the best model and class names
+    model = load_best_model(args, device)
+    class_names = get_class_names(args.train_root)  # class names from the training-set path (keeps ids aligned)
+
+    # 3. set up the AMP context (same as training; nullcontext when AMP is off)
+    if args.use_amp:
+        amp_context = torch.cuda.amp.autocast
+    else:
+        amp_context = contextlib.nullcontext
+
+    # 4. batched inference
+    tlogger.print("Start inferring test set...")
+    infer_results = []  # final results: [{"id": "...", "class": "..."}]
+    total_batches = len(test_loader)
+
+    with torch.no_grad():  # no gradients (saves memory, speeds up inference)
+        for batch_idx, (img_tensors, img_ids) in enumerate(test_loader):
+            # move data to the device (GPU/CPU)
+            img_tensors = img_tensors.to(device)
+
+            # forward pass (get model outputs)
+            with amp_context():
+                outs = model(img_tensors)
+
+            # pick the final prediction (for the PIM model, prefer comb_outs; fall back to ori_out)
+            if "comb_outs" in outs:
+                pred_logits = outs["comb_outs"]  # combiner output (the main prediction during training)
+            elif "ori_out" in outs:
+                pred_logits = outs["ori_out"]  # original output (when no PIM modules are used)
+            else:
+                raise KeyError("Model output has no 'comb_outs' or 'ori_out'! Check model structure.")
+
+            # predicted class id = argmax over logits
+            pred_class_ids = torch.argmax(pred_logits, dim=1).cpu().numpy()  # to CPU + numpy
+
+            # map class ids to names and assemble the results
+            for img_id, pred_id in zip(img_ids, pred_class_ids):
+                pred_class_name = class_names[pred_id]  # id -> class name (e.g. 0 -> "Brewer_Blackbird")
+                infer_results.append({
+                    "id": img_id,  # column 1: image id (filename)
+                    "class": pred_class_name  # column 2: predicted class name
+                })
+
+            # print inference progress
+            if (batch_idx + 1) % args.log_freq == 0 or (batch_idx + 1) == total_batches:
+                progress = (batch_idx + 1) / total_batches * 100
+                tlogger.print(f"Infer Progress: {progress:.1f}% (Batch {batch_idx + 1}/{total_batches})")
+
+    # 5. save results to CSV (path: save_dir/test_infer_results_adamw25.csv)
+    result_df = pd.DataFrame(infer_results)
+    csv_save_path = os.path.join(args.save_dir, "test_infer_results_adamw25.csv")
+    result_df.to_csv(csv_save_path, index=False, encoding="utf-8")  # no index column; UTF-8 handles non-ASCII class names
+    tlogger.print(f"Infer completed! Results saved to: {csv_save_path}")
+
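To sanity-check the output, the CSV can be read back with pandas (the save directory below is a hypothetical stand-in for args.save_dir):

    import pandas as pd
    save_dir = "./records/exp1/"   # hypothetical
    df = pd.read_csv(save_dir + "test_infer_results_adamw25.csv")
    print(df.head())                    # two columns: id, class
    print(df["class"].value_counts())   # predictions per class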
+# -------------------------- 5. Script entry (parse config + launch inference) --------------------------
+if __name__ == "__main__":
+    # initialize the time logger (tracks inference time)
+    tlogger = timeLogger()
+    tlogger.print("=" * 50)
+    tlogger.print("Starting Test Set Inference Script")
+    tlogger.print("=" * 50)
+
+    # 1. parse command-line arguments (get the config file path)
+    args = get_args()
+    # a config file (.yaml) must be provided
+    assert args.c != "", "Please provide config file via '-c your_config.yaml'!"
+
+    # 2. load the yaml config (all parameters come from the config, nothing hard-coded)
+    tlogger.print(f"Loading config file: {args.c}")
+    load_yaml(args, args.c)
+
+    # 3. check key paths (the test-set path and save directory must exist)
+    if not os.path.exists(args.test_root):
+        raise ValueError(f"Test set path not exist! Check 'test_root' in config: {args.test_root}")
+    if not os.path.exists(args.save_dir):
+        raise ValueError(f"Save directory not exist! Check 'save_dir' in config: {args.save_dir}")
+
+    # 4. select the device (GPU if available, otherwise CPU)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tlogger.print(f"Using device: {device}")
+
+    # 5. launch test-set inference
+    infer_test_set(args, device, tlogger)
+    tlogger.print("Inference Script Finished!")
+    tlogger.print("=" * 50)
\ No newline at end of file
diff --git a/.ipynb_checkpoints/main-checkpoint.py b/.ipynb_checkpoints/main-checkpoint.py
new file mode 100644
index 0000000..11dd5b6
--- /dev/null
+++ b/.ipynb_checkpoints/main-checkpoint.py
@@ -0,0 +1,297 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import contextlib
+import wandb
+import warnings
+
+from models.builder import MODEL_GETTER
+from data.dataset import build_loader
+from utils.costom_logger import timeLogger
+from utils.config_utils import load_yaml, build_record_folder, get_args
+from utils.lr_schedule import cosine_decay, adjust_lr, get_lr
+from eval import evaluate, cal_train_metrics
+
+warnings.simplefilter("ignore")
+
+def eval_freq_schedule(args, epoch: int):
+    if epoch >= args.max_epochs * 0.95:
+        args.eval_freq = 1
+    elif epoch >= args.max_epochs * 0.9:
+        args.eval_freq = 1
+    elif epoch >= args.max_epochs * 0.8:
+        args.eval_freq = 2
+
+def set_environment(args, tlogger):
+
+    print("Setting Environment...")
+
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    ### = = = = Dataset and Data Loader = = = =
+    tlogger.print("Building Dataloader....")
+
+    train_loader, val_loader = build_loader(args)
+
+    if train_loader is None and val_loader is None:
+        raise ValueError("Find nothing to train or evaluate.")
+
+    if train_loader is not None:
+        print("    Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader)))
+    else:
+        # raise ValueError("Build train loader fail, please provide legal path.")
+        print("    Train Samples: 0 ~~~~~> [Only Evaluation]")
+    if val_loader is not None:
+        print("    Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader)))
+    else:
+        print("    Validation Samples: 0 ~~~~~> [Only Training]")
+    tlogger.print()
+
+    ### = = = = Model = = = =
+    tlogger.print("Building Model....")
+    model = MODEL_GETTER[args.model_name](
+        use_fpn = args.use_fpn,
+        fpn_size = args.fpn_size,
+        use_selection = args.use_selection,
+        num_classes = args.num_classes,
+        num_selects = args.num_selects,
+        use_combiner = args.use_combiner,
+    ) # about return_nodes, we use our default setting
+    if args.pretrained is not None:
+        checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu'))
+        model.load_state_dict(checkpoint['model'])
+        start_epoch = checkpoint['epoch']
+    else:
+        start_epoch = 0
+
+    # model = torch.nn.DataParallel(model, device_ids=None) # device_ids : None --> use all gpus.
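+    # An illustrative sketch of enabling the line above (assumption, not active here):
+    #   if torch.cuda.device_count() > 1:
+    #       model = torch.nn.DataParallel(model)
+    # Note that main() below already unwraps with `model.module` before saving checkpoints.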
+
+    model.to(args.device)
+    tlogger.print()
+
+    """
+    If you have a multi-GPU device, you can use torch.nn.DataParallel for
+    single-machine multi-GPU training, or torch.nn.parallel.DistributedDataParallel
+    for multi-process parallelism.
+    more detail: https://pytorch.org/tutorials/beginner/dist_overview.html
+    """
+
+    if train_loader is None:
+        # include start_epoch so this matches the 8-value unpack in main()
+        # (the original 7-value return broke evaluation-only runs)
+        return train_loader, val_loader, model, None, None, None, None, start_epoch
+
+    ### = = = = Optimizer = = = =
+    tlogger.print("Building Optimizer....")
+    if args.optimizer == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9, weight_decay=args.wdecay)
+    elif args.optimizer == "AdamW":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr)
+    else:
+        raise ValueError("Unsupported optimizer: {}".format(args.optimizer))
+
+    if args.pretrained is not None:
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+    tlogger.print()
+
+    schedule = cosine_decay(args, len(train_loader))
+
+    if args.use_amp:
+        scaler = torch.cuda.amp.GradScaler()
+        amp_context = torch.cuda.amp.autocast
+    else:
+        scaler = None
+        amp_context = contextlib.nullcontext
+
+    return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch
+
+
+def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader):
+
+    optimizer.zero_grad()
+    total_batchs = len(train_loader) # just for log
+    show_progress = [x/10 for x in range(11)] # just for log
+    progress_i = 0
+    for batch_id, (ids, datas, labels) in enumerate(train_loader):
+        model.train()
+        """ = = = = adjust learning rate = = = = """
+        iterations = epoch * len(train_loader) + batch_id
+        adjust_lr(iterations, optimizer, schedule)
+
+        batch_size = labels.size(0)
+
+        """ = = = = forward and calculate loss = = = = """
+        datas, labels = datas.to(args.device), labels.to(args.device)
+
+        with amp_context():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting)
+            'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting)
+            'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that were NOT selected by the Selector.
+            'preds_1': logits that were selected by the Selector.
+            'comb_outs': the prediction of the Combiner.
+            """
+            outs = model(datas)
+
+            loss = 0.
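+            # Loss composition assembled by the loop below (descriptive summary
+            # of the existing code, assuming all modules are enabled):
+            #   loss = lambda_s*L_select + lambda_n*L_drop
+            #          + lambda_b*sum(L_layer_i) + lambda_c*L_comb (+ L_ori)
+            # and it is then divided by update_freq for gradient accumulation.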
+
+            for name in outs:
+
+                if "select_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+                    if args.lambda_s != 0:
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        loss_s = nn.CrossEntropyLoss()(logit,
+                                                       labels.unsqueeze(1).repeat(1, S).flatten(0))
+                        loss += args.lambda_s * loss_s
+                    else:
+                        loss_s = 0.0
+
+                elif "drop_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+
+                    if args.lambda_n != 0:
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        n_preds = nn.Tanh()(logit)
+                        labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1
+                        labels_0 = labels_0.to(args.device)
+                        loss_n = nn.MSELoss()(n_preds, labels_0)
+                        loss += args.lambda_n * loss_n
+                    else:
+                        loss_n = 0.0
+
+                elif "layer" in name:
+                    if not args.use_fpn:
+                        raise ValueError("FPN is not used here.")
+                    if args.lambda_b != 0:
+                        ### here using 'layer1'~'layer4' is the default setting, you can change it to your own
+                        loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels)
+                        loss += args.lambda_b * loss_b
+                    else:
+                        loss_b = 0.0
+
+                elif "comb_outs" in name:
+                    if not args.use_combiner:
+                        raise ValueError("Combiner is not used here.")
+
+                    if args.lambda_c != 0:
+                        loss_c = nn.CrossEntropyLoss()(outs[name], labels)
+                        loss += args.lambda_c * loss_c
+
+                elif "ori_out" in name:
+                    loss_ori = F.cross_entropy(outs[name], labels)
+                    loss += loss_ori
+
+            loss /= args.update_freq
+
+        """ = = = = calculate gradient = = = = """
+        if args.use_amp:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
+
+        """ = = = = update model = = = = """
+        if (batch_id + 1) % args.update_freq == 0:
+            if args.use_amp:
+                scaler.step(optimizer)
+                scaler.update() # next batch
+            else:
+                optimizer.step()
+            optimizer.zero_grad()
+
+        """ log (MISC) """
+        if args.use_wandb and ((batch_id + 1) % args.log_freq == 0):
+            model.eval()
+            msg = {}
+            msg['info/epoch'] = epoch + 1
+            msg['info/lr'] = get_lr(optimizer)
+            cal_train_metrics(args, msg, outs, labels, batch_size)
+            wandb.log(msg)
+
+        train_progress = (batch_id + 1) / total_batchs
+        # print(train_progress, show_progress[progress_i])
+        if train_progress > show_progress[progress_i]:
+            print(".."+str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True)
+            progress_i += 1
+
+
+def main(args, tlogger):
+    """
+    Saves the last.pt and best.pt model checkpoints.
+    """
+
+    train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, tlogger)
+
+    best_acc = 0.0
+    best_eval_name = "null"
+
+    if args.use_wandb:
+        wandb.init(entity=args.wandb_entity,
+                   project=args.project_name,
+                   name=args.exp_name,
+                   config=args)
+        wandb.run.summary["best_acc"] = best_acc
+        wandb.run.summary["best_eval_name"] = best_eval_name
+        wandb.run.summary["best_epoch"] = 0
+
+    for epoch in range(start_epoch, args.max_epochs):
+
+        """
+        Train
+        """
+        if train_loader is not None:
+            tlogger.print("Start Training {} Epoch".format(epoch+1))
+            train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader)
+            tlogger.print()
+        else:
+            from eval import eval_and_save
+            eval_and_save(args, model, val_loader, tlogger)  # eval_and_save expects tlogger as its 4th argument
+            break
+
+        eval_freq_schedule(args, epoch)
+
+        model_to_save = model.module if hasattr(model, "module") else model
+        checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch":epoch}
+        torch.save(checkpoint, args.save_dir + "backup/last.pt")
+
+        if epoch == 0 or (epoch + 1) % args.eval_freq == 0:
+            """
+            Evaluation
+            """
+            acc = -1
+            if 
val_loader is not None: + tlogger.print("Start Evaluating {} Epoch".format(epoch + 1)) + acc, eval_name, accs = evaluate(args, model, val_loader) + tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc)) + tlogger.print() + + if args.use_wandb: + wandb.log(accs) + + if acc > best_acc: + best_acc = acc + best_eval_name = eval_name + torch.save(checkpoint, args.save_dir + "backup/best.pt") + if args.use_wandb: + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = epoch + 1 + + +if __name__ == "__main__": + + tlogger = timeLogger() + + tlogger.print("Reading Config...") + args = get_args() + assert args.c != "", "Please provide config file (.yaml)" + load_yaml(args, args.c) + build_record_folder(args) + tlogger.print() + + main(args, tlogger) \ No newline at end of file diff --git a/.ipynb_checkpoints/order-checkpoint.txt b/.ipynb_checkpoints/order-checkpoint.txt new file mode 100644 index 0000000..5e2dd47 --- /dev/null +++ b/.ipynb_checkpoints/order-checkpoint.txt @@ -0,0 +1,4 @@ +python vit_pim_main.py --c ./configs/fgvc_vit_smuaic.yaml +python infer_test.py --c ./configs/fgvc_vit_smuaic.yaml +python vit_pim_main_sgd.py --c ./configs/fgvc_vit_smuaic-strong.yaml +python infer_test.py --c ./configs/fgvc_vit_smuaic-strong.yaml \ No newline at end of file diff --git a/.ipynb_checkpoints/vit_fim_main-checkpoint.py b/.ipynb_checkpoints/vit_fim_main-checkpoint.py new file mode 100644 index 0000000..c92677f --- /dev/null +++ b/.ipynb_checkpoints/vit_fim_main-checkpoint.py @@ -0,0 +1,429 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import load_yaml, build_record_folder, get_args +from utils.lr_schedule import cosine_decay, adjust_lr, get_lr +from eval import evaluate, cal_train_metrics + +warnings.simplefilter("ignore") + + +def eval_freq_schedule(args, epoch: int): + """ + 根据当前训练的 epoch 调整验证频率(eval_freq)。 + 在训练接近尾声时更频繁地进行验证,以便更好地监控模型性能。 + + 参数: + args: 包含训练配置参数的对象,其中包括 eval_freq 和 max_epochs。 + epoch: 当前训练的 epoch 数。 + """ + # 如果当前 epoch 大于等于最大训练轮次的 95%,则将验证频率设为 1(每个 epoch 都验证) + if epoch >= args.max_epochs * 0.95: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 90% 但小于 95%,同样将验证频率设为 1 + elif epoch >= args.max_epochs * 0.9: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 80% 但小于 90%,将验证频率设为 2(每两个 epoch 验证一次) + elif epoch >= args.max_epochs * 0.8: + args.eval_freq = 2 + + +def set_environment(args, tlogger): + """ + 设置训练环境,包括设备、数据加载器、模型、优化器等。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + + 返回: + train_loader: 训练数据加载器。 + val_loader: 验证数据加载器。 + model: 构建并初始化的模型。 + optimizer: 优化器(如果仅评估则为None)。 + schedule: 学习率调度器(如果仅评估则为None)。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + start_epoch: 训练开始的epoch数(如果有预训练模型,则从该模型的epoch开始)。 + """ + + print("Setting Environment...") + + # 设置训练设备:如果CUDA可用则使用GPU,否则使用CPU + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### = = = = Dataset and Data Loader = = = = + # 构建训练和验证数据加载器 + tlogger.print("Building Dataloader....") + + train_loader, val_loader = build_loader(args) + + # 检查是否成功构建了数据加载器 + if train_loader is None and val_loader is None: + raise ValueError("Find nothing to train or evaluate.") + + # 打印训练集信息 + if train_loader is not None: + print(" Train Samples: {} (batch: 
{})".format(len(train_loader.dataset), len(train_loader))) + else: + # raise ValueError("Build train loader fail, please provide legal path.") + print(" Train Samples: 0 ~~~~~> [Only Evaluation]") + + # 打印验证集信息 + if val_loader is not None: + print(" Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader))) + else: + print(" Validation Samples: 0 ~~~~~> [Only Training]") + tlogger.print() + + ### = = = = Model = = = = + # 构建模型 + tlogger.print("Building Model....") + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner, + ) # about return_nodes, we use our default setting + + # 如果提供了预训练模型,则加载权重 + if args.pretrained is not None: + checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model']) + start_epoch = checkpoint['epoch'] + print(start_epoch) + else: + start_epoch = 0 + + # 将模型移动到指定设备 + model.to(args.device) + tlogger.print() + + """ + 如果你有多GPU设备,可以在单机多GPU情况下使用torch.nn.DataParallel, + 或者使用torch.nn.parallel.DistributedDataParallel实现多进程并行。 + 更多详情:https://pytorch.org/tutorials/beginner/dist_overview.html + """ + + # 如果没有训练数据加载器,只进行评估,返回部分对象 + if train_loader is None: + return train_loader, val_loader, model, None, None, None, None, start_epoch + + ### = = = = Optimizer = = = = + # 构建优化器 + tlogger.print("Building Optimizer....") + if args.optimizer == "SGD": + optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9, + weight_decay=args.wdecay) + elif args.optimizer == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr) + + # 如果有预训练模型,加载优化器状态 + if args.pretrained is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + tlogger.print() + + # 构建学习率调度器 + schedule = cosine_decay(args, len(train_loader)) + + # 如果使用混合精度训练(AMP),设置相关的组件 + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + amp_context = torch.cuda.amp.autocast + else: + scaler = None + amp_context = contextlib.nullcontext + + # 返回所有构建的组件 + return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch + + +def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader): + """ + 训练函数,在每个epoch中迭代训练数据并更新模型参数。 + + 参数: + args: 包含训练配置参数的对象。 + epoch: 当前训练的 epoch 数。 + model: 要训练的模型。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + optimizer: 优化器。 + schedule: 学习率调度器。 + train_loader: 训练数据加载器。 + """ + + # 清空优化器的梯度 + optimizer.zero_grad() + + # 获取总批次数,仅用于日志记录 + total_batchs = len(train_loader) + + # 定义训练进度显示点(0%, 10%, ..., 100%) + show_progress = [x / 10 for x in range(11)] + progress_i = 0 + + # 遍历训练数据加载器中的每个批次 + for batch_id, (ids, datas, labels) in enumerate(train_loader): + # 设置模型为训练模式 + model.train() + + """ = = = = adjust learning rate = = = = """ + # 计算当前迭代次数 + iterations = epoch * len(train_loader) + batch_id + # 调整学习率 + adjust_lr(iterations, optimizer, schedule) + + # 获取当前批次的样本数量 + batch_size = labels.size(0) + + """ = = = = forward and calculate loss = = = = """ + # 将数据和标签移动到指定设备 + datas, labels = datas.to(args.device), labels.to(args.device) + + # 使用AMP上下文进行前向传播(如果启用AMP) + with amp_context(): + """ + [Model Return] + FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1', 'comb_outs' + FPN + Selector --> return 'layer1', 'layer2', 'layer3', 
'layer4', ...(depend on your setting) + 'preds_0', 'preds_1' + FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting) + ~ --> return 'ori_out' + + [Retuen Tensor] + 'preds_0': logit has not been selected by Selector. + 'preds_1': logit has been selected by Selector. + 'comb_outs': The prediction of combiner. + """ + # 前向传播获取输出 + outs = model(datas) + + # 初始化总损失 + loss = 0. + + # 遍历模型输出的各个部分,计算相应的损失 + for name in outs: + # 处理选择器的输出 + if "select_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + if args.lambda_s != 0: + # 计算选择器损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, + labels.unsqueeze(1).repeat(1, S).flatten(0)) + loss += args.lambda_s * loss_s + else: + loss_s = 0.0 + + # 处理丢弃部分的输出 + elif "drop_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + + if args.lambda_n != 0: + # 计算负样本损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 + labels_0 = labels_0.to(args.device) + loss_n = nn.MSELoss()(n_preds, labels_0) + loss += args.lambda_n * loss_n + else: + loss_n = 0.0 + + # 处理FPN层的输出 + elif "layer" in name: + if not args.use_fpn: + raise ValueError("FPN not use here.") + if args.lambda_b != 0: + # 计算FPN基础损失 + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) + loss += args.lambda_b * loss_b + else: + loss_b = 0.0 + + # 处理组合器的输出 + elif "comb_outs" in name: + if not args.use_combiner: + raise ValueError("Combiner not use here.") + + if args.lambda_c != 0: + # 计算组合器损失 + loss_c = nn.CrossEntropyLoss()(outs[name], labels) + loss += args.lambda_c * loss_c + + # 处理原始输出 + elif "ori_out" in name: + # 计算原始输出损失 + loss_ori = F.cross_entropy(outs[name], labels) + loss += loss_ori + + # 对损失进行平均化处理 + loss /= args.update_freq + + """ = = = = calculate gradient = = = = """ + # 计算梯度(根据是否使用AMP选择不同的方式) + if args.use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + """ = = = = update model = = = = """ + # 更新模型参数(每隔update_freq个批次更新一次) + if (batch_id + 1) % args.update_freq == 0: + if args.use_amp: + # 使用AMP更新模型 + scaler.step(optimizer) + scaler.update() # next batch + else: + # 正常更新模型 + optimizer.step() + # 清空梯度 + optimizer.zero_grad() + + """ log (MISC) """ + # 记录训练日志(如果启用wandb且达到记录频率) + if args.use_wandb and ((batch_id + 1) % args.log_freq == 0): + # 切换到评估模式进行日志记录 + model.eval() + msg = {} + msg['info/epoch'] = epoch + 1 + msg['info/lr'] = get_lr(optimizer) + # 计算并记录训练指标 + cal_train_metrics(args, msg, outs, labels, batch_size) + # 将日志信息发送到wandb + wandb.log(msg) + + # 显示训练进度 + train_progress = (batch_id + 1) / total_batchs + # print(train_progress, show_progress[progress_i]) + if train_progress > show_progress[progress_i]: + print(".." 
+ str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + + +def main(args, tlogger): + """ + 主训练循环函数,负责整个训练和验证过程,包括模型保存(last.pt 和 best.pt)。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + """ + + # 调用set_environment函数设置训练环境,获取数据加载器、模型、优化器等 + train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, + tlogger) + + # 初始化最佳准确率和最佳评估名称 + best_acc = 0.0 + best_eval_name = "null" + + # 如果启用wandb,则初始化wandb项目并设置初始摘要信息 + if args.use_wandb: + wandb.init(entity=args.wandb_entity, + project=args.project_name, + name=args.exp_name, + config=args) + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = 0 + + # 开始训练循环,从start_epoch到max_epochs + for epoch in range(start_epoch, args.max_epochs): + + """ + 训练阶段 + """ + # 如果存在训练数据加载器,则进行训练 + if train_loader is not None: + tlogger.print("Start Training {} Epoch".format(epoch + 1)) + # 调用train函数进行一个epoch的训练 + train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader) + tlogger.print() + else: + # 如果没有训练数据加载器(仅评估模式),则调用eval_and_save进行评估并保存结果,然后退出循环 + from eval import eval_and_save + eval_and_save(args, model, val_loader) + break + + # 根据当前epoch调整验证频率 + eval_freq_schedule(args, epoch) + + # 准备要保存的模型检查点(处理多GPU情况) + model_to_save = model.module if hasattr(model, "module") else model + checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch} + # 保存最新的模型检查点 + torch.save(checkpoint, args.save_dir + "backup/last.pt") + + # 根据评估频率进行验证(每个epoch或每隔几个epoch) + if epoch == 0 or (epoch + 1) % args.eval_freq == 0: + """ + 验证阶段 + """ + acc = -1 + # 如果存在验证数据加载器,则进行验证 + if val_loader is not None: + tlogger.print("Start Evaluating {} Epoch".format(epoch + 1)) + # 调用evaluate函数进行验证,获取准确率等信息 + acc, eval_name, accs = evaluate(args, model, val_loader) + # 打印当前验证结果和历史最佳准确率 + tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc)) + tlogger.print() + + # 如果启用wandb,则记录验证指标 + if args.use_wandb: + wandb.log(accs) + + # 如果当前准确率优于历史最佳准确率,则更新最佳准确率并保存最佳模型 + if acc > best_acc: + best_acc = acc + best_eval_name = eval_name + torch.save(checkpoint, args.save_dir + "backup/best.pt") + # 如果启用wandb,则更新wandb中的最佳指标摘要 + if args.use_wandb: + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = epoch + 1 + + +if __name__ == "__main__": + # 创建一个时间记录器实例,用于记录和打印时间相关的日志 + tlogger = timeLogger() + + # 打印正在读取配置文件的信息 + tlogger.print("Reading Config...") + + # 获取命令行参数,这些参数包括配置文件路径等 + args = get_args() + + # 断言确保提供了配置文件(.yaml格式),如果没有提供则抛出错误信息 + assert args.c != "", "Please provide config file (.yaml)" + + # 加载指定的YAML配置文件,并将配置内容存入args对象中 + load_yaml(args, args.c) + + # 根据配置创建记录文件夹,用于保存训练过程中的日志、模型等文件 + build_record_folder(args) + + # 打印空行,起到分隔日志的作用 + tlogger.print() + + # 调用main函数开始执行主要的训练或评估流程,传入解析后的参数和时间记录器 + main(args, tlogger) \ No newline at end of file diff --git a/.ipynb_checkpoints/vit_pim_main-checkpoint.py b/.ipynb_checkpoints/vit_pim_main-checkpoint.py new file mode 100644 index 0000000..1065083 --- /dev/null +++ b/.ipynb_checkpoints/vit_pim_main-checkpoint.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import load_yaml, 
build_record_folder, get_args +from utils.lr_schedule import cosine_decay, adjust_lr, get_lr +from eval import evaluate, cal_train_metrics + +warnings.simplefilter("ignore") + + +def eval_freq_schedule(args, epoch: int): + """ + 根据当前训练的 epoch 调整验证频率(eval_freq)。 + 在训练接近尾声时更频繁地进行验证,以便更好地监控模型性能。 + + 参数: + args: 包含训练配置参数的对象,其中包括 eval_freq 和 max_epochs。 + epoch: 当前训练的 epoch 数。 + """ + # 如果当前 epoch 大于等于最大训练轮次的 95%,则将验证频率设为 1(每个 epoch 都验证) + if epoch >= args.max_epochs * 0.95: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 90% 但小于 95%,同样将验证频率设为 1 + elif epoch >= args.max_epochs * 0.9: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 80% 但小于 90%,将验证频率设为 2(每两个 epoch 验证一次) + elif epoch >= args.max_epochs * 0.8: + args.eval_freq = 2 + + +def set_environment(args, tlogger): + """ + 设置训练环境,包括设备、数据加载器、模型、优化器等。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + + 返回: + train_loader: 训练数据加载器。 + val_loader: 验证数据加载器。 + model: 构建并初始化的模型。 + optimizer: 优化器(如果仅评估则为None)。 + schedule: 学习率调度器(如果仅评估则为None)。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + start_epoch: 训练开始的epoch数(如果有预训练模型,则从该模型的epoch开始)。 + """ + + print("Setting Environment...") + + # 设置训练设备:如果CUDA可用则使用GPU,否则使用CPU + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### = = = = Dataset and Data Loader = = = = + # 构建训练和验证数据加载器 + tlogger.print("Building Dataloader....") + + train_loader, val_loader = build_loader(args) + + # 检查是否成功构建了数据加载器 + if train_loader is None and val_loader is None: + raise ValueError("Find nothing to train or evaluate.") + + # 打印训练集信息 + if train_loader is not None: + print(" Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader))) + else: + # raise ValueError("Build train loader fail, please provide legal path.") + print(" Train Samples: 0 ~~~~~> [Only Evaluation]") + + # 打印验证集信息 + if val_loader is not None: + print(" Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader))) + else: + print(" Validation Samples: 0 ~~~~~> [Only Training]") + tlogger.print() + + ### = = = = Model = = = = + # 构建模型 + tlogger.print("Building Model....") + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner, + ) # about return_nodes, we use our default setting + + # 如果提供了预训练模型,则加载权重 + if args.pretrained is not None: + checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model']) + start_epoch = checkpoint['epoch'] + print(start_epoch) + else: + start_epoch = 0 + + # 将模型移动到指定设备 + model.to(args.device) + tlogger.print() + + """ + 如果你有多GPU设备,可以在单机多GPU情况下使用torch.nn.DataParallel, + 或者使用torch.nn.parallel.DistributedDataParallel实现多进程并行。 + 更多详情:https://pytorch.org/tutorials/beginner/dist_overview.html + """ + + # 如果没有训练数据加载器,只进行评估,返回部分对象 + if train_loader is None: + return train_loader, val_loader, model, None, None, None, None, start_epoch + + ### = = = = Optimizer = = = = + # 构建优化器 + tlogger.print("Building Optimizer....") + if args.optimizer == "SGD": + optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9, + weight_decay=args.wdecay) + elif args.optimizer == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr) + + # 如果有预训练模型,加载优化器状态 + if args.pretrained is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + tlogger.print() + + # 
Build the learning-rate schedule
+    schedule = cosine_decay(args, len(train_loader))
+
+    # If mixed-precision training (AMP) is enabled, set up the related components
+    if args.use_amp:
+        scaler = torch.cuda.amp.GradScaler()
+        amp_context = torch.cuda.amp.autocast
+    else:
+        scaler = None
+        amp_context = contextlib.nullcontext
+
+    # Return all of the constructed components
+    return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch
+
+
+def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader):
+    """
+    Training function: iterates over the training data for one epoch and updates the model parameters.
+
+    Args:
+        args: object holding the training configuration.
+        epoch: index of the current training epoch.
+        model: the model to train.
+        scaler: AMP GradScaler (None when AMP is disabled).
+        amp_context: AMP context manager (nullcontext when AMP is disabled).
+        optimizer: the optimizer.
+        schedule: the learning-rate schedule.
+        train_loader: the training data loader.
+    """
+
+    # Clear any stale gradients in the optimizer
+    optimizer.zero_grad()
+
+    # Total number of batches, used only for progress logging
+    total_batchs = len(train_loader)
+
+    # Progress milestones to display (0%, 10%, ..., 100%)
+    show_progress = [x / 10 for x in range(11)]
+    progress_i = 0
+
+    # Iterate over every batch in the training loader
+    for batch_id, (ids, datas, labels) in enumerate(train_loader):
+        # Put the model in training mode
+        model.train()
+
+        """ = = = = adjust learning rate = = = = """
+        # Compute the global iteration count
+        iterations = epoch * len(train_loader) + batch_id
+        # Adjust the learning rate
+        adjust_lr(iterations, optimizer, schedule)
+
+        # Number of samples in the current batch
+        batch_size = labels.size(0)
+
+        """ = = = = forward and calculate loss = = = = """
+        # Move data and labels to the target device
+        datas, labels = datas.to(args.device), labels.to(args.device)
+
+        # Forward pass under the AMP context (when AMP is enabled)
+        with amp_context():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depending on your setting)
+                                          'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depending on your setting)
+                               'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depending on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that have not been selected by the Selector.
+            'preds_1': logits that have been selected by the Selector.
+            'comb_outs': the prediction of the Combiner.
+            """
+            # Forward pass to obtain the outputs
+            outs = model(datas)
+
+            # Initialize the total loss
+            loss = 0.
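+            # Illustrative summary (assumed weights taken from configs/AIC_PIM_SwinT.yaml
+            # elsewhere in this diff: lambda_b=0.5, lambda_s=0.0, lambda_n=5.0, lambda_c=1.0);
+            # the loop below then assembles, roughly:
+            #     loss = 0.5 * sum_i CE(outs['layer_i'].mean(1), labels)   # FPN branches
+            #          + 5.0 * MSE(tanh(drop_logits), -1)                  # dropped regions
+            #          + 1.0 * CE(outs['comb_outs'], labels)               # combiner head
+            # (lambda_s = 0.0 disables the selector cross-entropy term in that config.)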
+ + # 遍历模型输出的各个部分,计算相应的损失 + for name in outs: + # 处理选择器的输出 + if "select_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + if args.lambda_s != 0: + # 计算选择器损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, + labels.unsqueeze(1).repeat(1, S).flatten(0)) + loss += args.lambda_s * loss_s + else: + loss_s = 0.0 + + # 处理丢弃部分的输出 + elif "drop_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + + if args.lambda_n != 0: + # 计算负样本损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 + labels_0 = labels_0.to(args.device) + loss_n = nn.MSELoss()(n_preds, labels_0) + loss += args.lambda_n * loss_n + else: + loss_n = 0.0 + + # 处理FPN层的输出 + elif "layer" in name: + if not args.use_fpn: + raise ValueError("FPN not use here.") + if args.lambda_b != 0: + # 计算FPN基础损失 + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) + loss += args.lambda_b * loss_b + else: + loss_b = 0.0 + + # 处理组合器的输出 + elif "comb_outs" in name: + if not args.use_combiner: + raise ValueError("Combiner not use here.") + + if args.lambda_c != 0: + # 计算组合器损失 + loss_c = nn.CrossEntropyLoss()(outs[name], labels) + loss += args.lambda_c * loss_c + + # 处理原始输出 + elif "ori_out" in name: + # 计算原始输出损失 + loss_ori = F.cross_entropy(outs[name], labels) + loss += loss_ori + + # 对损失进行平均化处理 + loss /= args.update_freq + + """ = = = = calculate gradient = = = = """ + # 计算梯度(根据是否使用AMP选择不同的方式) + if args.use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + """ = = = = update model = = = = """ + # 更新模型参数(每隔update_freq个批次更新一次) + if (batch_id + 1) % args.update_freq == 0: + if args.use_amp: + # 使用AMP更新模型 + scaler.step(optimizer) + scaler.update() # next batch + else: + # 正常更新模型 + optimizer.step() + # 清空梯度 + optimizer.zero_grad() + + """ log (MISC) """ + # 记录训练日志(如果启用wandb且达到记录频率) + if args.use_wandb and ((batch_id + 1) % args.log_freq == 0): + # 切换到评估模式进行日志记录 + model.eval() + msg = {} + msg['info/epoch'] = epoch + 1 + msg['info/lr'] = get_lr(optimizer) + # 计算并记录训练指标 + cal_train_metrics(args, msg, outs, labels, batch_size) + # 将日志信息发送到wandb + wandb.log(msg) + + # 显示训练进度 + train_progress = (batch_id + 1) / total_batchs + # print(train_progress, show_progress[progress_i]) + if train_progress > show_progress[progress_i]: + print(".." 
+ str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + +def main(args, tlogger): + """ + 主训练循环函数,负责整个训练和验证过程,包括模型保存(last.pt 和 best.pt)。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + """ + + # 调用set_environment函数设置训练环境,获取数据加载器、模型、优化器等 + train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, + tlogger) + + # 初始化最佳准确率和最佳评估名称 + best_acc = 0.0 + best_eval_name = "null" + + # 如果启用wandb,则初始化wandb项目并设置初始摘要信息 + if args.use_wandb: + wandb.init(entity=args.wandb_entity, + project=args.project_name, + name=args.exp_name, + config=args) + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = 0 + + # 开始训练循环,从start_epoch到max_epochs + for epoch in range(start_epoch, args.max_epochs): + + """ + 训练阶段 + """ + # 如果存在训练数据加载器,则进行训练 + if train_loader is not None: + tlogger.print("Start Training {} Epoch".format(epoch + 1)) + # 调用train函数进行一个epoch的训练 + train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader) + tlogger.print() + else: + # 如果没有训练数据加载器(仅评估模式),则调用eval_and_save进行评估并保存结果,然后退出循环 + from eval import eval_and_save + eval_and_save(args, model, val_loader) + break + + # 根据当前epoch调整验证频率 + eval_freq_schedule(args, epoch) + + # 准备要保存的模型检查点(处理多GPU情况) + model_to_save = model.module if hasattr(model, "module") else model + checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch} + # 保存最新的模型检查点 + torch.save(checkpoint, args.save_dir + "backup/last.pt") + + # 根据评估频率进行验证(每个epoch或每隔几个epoch) + if epoch == 0 or (epoch + 1) % args.eval_freq == 0: + """ + 验证阶段 + """ + acc = -1 + # 如果存在验证数据加载器,则进行验证 + if val_loader is not None: + tlogger.print("Start Evaluating {} Epoch".format(epoch + 1)) + # 调用evaluate函数进行验证,获取准确率等信息 + acc, eval_name, accs = evaluate(args, model, val_loader) + # 打印当前验证结果和历史最佳准确率 + tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc)) + tlogger.print() + + # 如果启用wandb,则记录验证指标 + if args.use_wandb: + wandb.log(accs) + + # 如果当前准确率优于历史最佳准确率,则更新最佳准确率并保存最佳模型 + if acc > best_acc: + best_acc = acc + best_eval_name = eval_name + torch.save(checkpoint, args.save_dir + "backup/best.pt") + # 如果启用wandb,则更新wandb中的最佳指标摘要 + if args.use_wandb: + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = epoch + 1 + + +if __name__ == "__main__": + # 创建一个时间记录器实例,用于记录和打印时间相关的日志 + tlogger = timeLogger() + + # 打印正在读取配置文件的信息 + tlogger.print("Reading Config...") + + # 获取命令行参数,这些参数包括配置文件路径等 + args = get_args() + + # 断言确保提供了配置文件(.yaml格式),如果没有提供则抛出错误信息 + assert args.c != "", "Please provide config file (.yaml)" + + # 加载指定的YAML配置文件,并将配置内容存入args对象中 + load_yaml(args, args.c) + + # 根据配置创建记录文件夹,用于保存训练过程中的日志、模型等文件 + build_record_folder(args) + + # 打印空行,起到分隔日志的作用 + tlogger.print() + + # 调用main函数开始执行主要的训练或评估流程,传入解析后的参数和时间记录器 + main(args, tlogger) \ No newline at end of file diff --git a/.ipynb_checkpoints/vit_pim_main_adamw-checkpoint.py b/.ipynb_checkpoints/vit_pim_main_adamw-checkpoint.py new file mode 100644 index 0000000..e2bafea --- /dev/null +++ b/.ipynb_checkpoints/vit_pim_main_adamw-checkpoint.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import 
load_yaml, build_record_folder, get_args +from utils.lr_schedule import cosine_decay, adjust_lr, get_lr +from eval import evaluate, cal_train_metrics, _average_top_k_result # 确保导入_average_top_k_result + +warnings.simplefilter("ignore") + + +def eval_freq_schedule(args, epoch: int): + if epoch >= args.max_epochs * 0.95: + args.eval_freq = 1 + elif epoch >= args.max_epochs * 0.9: + args.eval_freq = 1 + elif epoch >= args.max_epochs * 0.8: + args.eval_freq = 2 + + +def set_environment(args, tlogger): + print("Setting Environment...") + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### Dataset and Data Loader + tlogger.print("Building Dataloader....") + train_loader, val_loader = build_loader(args) + + if train_loader is None and val_loader is None: + raise ValueError("Find nothing to train or evaluate.") + + if train_loader is not None: + print(" Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader))) + else: + print(" Train Samples: 0 ~~~~~> [Only Evaluation]") + if val_loader is not None: + print(" Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader))) + else: + print(" Validation Samples: 0 ~~~~~> [Only Training]") + tlogger.print() + + ### Model + tlogger.print("Building Model....") + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner, + ) + + start_epoch = 0 + if args.pretrained is not None: + checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model'], strict=False) + start_epoch = checkpoint['epoch'] + print(f"Loaded pretrained model from epoch {start_epoch}") + + model.to(args.device) + tlogger.print() + + if train_loader is None: + return train_loader, val_loader, model, None, None, None, None, start_epoch + + ### Optimizer + tlogger.print("Building Optimizer....") + if args.optimizer == "SGD": + optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9, + weight_decay=args.wdecay) + elif args.optimizer == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr) + + if args.pretrained is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + tlogger.print() + schedule = cosine_decay(args, len(train_loader)) + + ### AMP + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + amp_context = torch.cuda.amp.autocast + else: + scaler = None + amp_context = contextlib.nullcontext + + return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch + + +def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader): + """修正训练Acc计算逻辑,打印训练Loss和Acc""" + optimizer.zero_grad() + total_batchs = len(train_loader) + show_progress = [x / 10 for x in range(11)] + progress_i = 0 + + # 记录整个epoch的训练指标 + epoch_total_loss = 0.0 # 总训练Loss + epoch_correct = 0 # 总正确样本数(用于计算平均Acc) + total_samples = 0 # 总训练样本数 + + for batch_id, (ids, datas, labels) in enumerate(train_loader): + model.train() + batch_size = labels.size(0) + total_samples += batch_size + labels = labels.to(args.device) # 移动标签到设备 + + ### 调整学习率 + iterations = epoch * len(train_loader) + batch_id + adjust_lr(iterations, optimizer, schedule) + + ### 前向传播与损失计算 + datas = datas.to(args.device) + batch_loss = 0.0 # 当前batch的总Loss + + with amp_context(): + outs = model(datas) + loss = 0. 
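+            # Gradient-accumulation arithmetic (illustrative; batch_size=16 and
+            # update_freq=2 are the values used in configs/AIC_PIM_SwinT.yaml):
+            #     loss /= update_freq   -> gradients of the accumulated steps average out
+            #     effective batch size  =  batch_size * update_freq  =  16 * 2  =  32
+            #     batch_real_loss below undoes that scaling (loss.item() * update_freq),
+            #     so the logged loss stays comparable across update_freq settings.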
+ + # 计算各部分Loss(与原有逻辑一致) + for name in outs: + if "select_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + if args.lambda_s != 0: + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, labels.unsqueeze(1).repeat(1, S).flatten(0)) + loss += args.lambda_s * loss_s + batch_loss += loss_s.item() * args.lambda_s + + elif "drop_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + if args.lambda_n != 0: + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = torch.zeros([batch_size * S, args.num_classes]).to(args.device) - 1 + loss_n = nn.MSELoss()(n_preds, labels_0) + loss += args.lambda_n * loss_n + batch_loss += loss_n.item() * args.lambda_n + + elif "layer" in name: + if not args.use_fpn: + raise ValueError("FPN not use here.") + if args.lambda_b != 0: + loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) + loss += args.lambda_b * loss_b + batch_loss += loss_b.item() * args.lambda_b + + elif "comb_outs" in name: + if not args.use_combiner: + raise ValueError("Combiner not use here.") + if args.lambda_c != 0: + loss_c = nn.CrossEntropyLoss()(outs[name], labels) + loss += args.lambda_c * loss_c + batch_loss += loss_c.item() * args.lambda_c + + elif "ori_out" in name: + loss_ori = F.cross_entropy(outs[name], labels) + loss += loss_ori + batch_loss += loss_ori.item() + + # 梯度累积:还原真实Loss + loss /= args.update_freq + batch_real_loss = loss.item() * args.update_freq # 当前batch的真实Loss + epoch_total_loss += batch_real_loss + + ### 反向传播与参数更新 + if args.use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + if (batch_id + 1) % args.update_freq == 0: + if args.use_amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + + ### 计算并打印当前batch的训练Acc和Loss + if (batch_id + 1) % args.log_freq == 0: + # 计算当前batch的Acc(优先取combiner,无则取原始输出) + if args.use_combiner and "comb_outs" in outs: + pred = torch.argmax(outs["comb_outs"], dim=1) + elif "ori_out" in outs: + pred = torch.argmax(outs["ori_out"], dim=1) + else: + # 若以上都没有,取最后一个FPN层的输出 + pred = torch.argmax(outs["layer4"].mean(1), dim=1) + + # 计算当前batch的正确数和Acc + batch_correct = (pred == labels).sum().item() + batch_acc = (batch_correct / batch_size) * 100 # 转换为百分比 + epoch_correct += batch_correct # 累加至总正确数 + + # 打印batch级指标 + print(f"[Train] Epoch {epoch+1:2d} | Batch {batch_id+1:4d}/{total_batchs:4d} | " + f"Loss: {batch_real_loss:.4f} | Acc: {batch_acc:.2f}%") + + ### 显示训练进度 + train_progress = (batch_id + 1) / total_batchs + if train_progress > show_progress[progress_i]: + print(".." 
+ str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + + ### 打印当前epoch的训练汇总 + avg_train_loss = epoch_total_loss / total_batchs # 平均Loss(按batch数) + avg_train_acc = (epoch_correct / total_samples) * 100 # 平均Acc(按样本数) + print(f"\n[Train Summary] Epoch {epoch+1:2d} | Avg Loss: {avg_train_loss:.4f} | Avg Acc: {avg_train_acc:.2f}%") + + +def main(args, tlogger): + train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, tlogger) + + best_acc = 0.0 + best_eval_name = "null" + + if args.use_wandb: + wandb.init(entity=args.wandb_entity, project=args.project_name, name=args.exp_name, config=args) + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_epoch"] = 0 + + for epoch in range(start_epoch, args.max_epochs): + ### 训练阶段 + if train_loader is not None: + tlogger.print("Start Training {} Epoch".format(epoch + 1)) + train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader) + tlogger.print() + else: + from eval import eval_and_save + eval_and_save(args, model, val_loader) + break + + ### 调整验证频率 + eval_freq_schedule(args, epoch) + + ### 保存最新模型 + model_to_save = model.module if hasattr(model, "module") else model + checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch} + torch.save(checkpoint, args.save_dir + "backup/last.pt") + + ### 验证阶段(恢复原始逻辑,不计算验证Loss) + if epoch == 0 or (epoch + 1) % args.eval_freq == 0: + acc = -1 + if val_loader is not None: + tlogger.print("Start Evaluating {} Epoch".format(epoch + 1)) + # 恢复原始evaluate调用(仅返回3个值) + acc, eval_name, accs = evaluate(args, model, val_loader) + # 打印验证Acc(不含Loss) + print(f"[Val] Epoch {epoch+1:2d} | Best Acc: {max(acc, best_acc):.2f}% (Current Acc: {acc:.2f}%)") + tlogger.print() + + ### 更新wandb日志 + if args.use_wandb: + wandb.log(accs) + + ### 更新最佳模型 + if acc > best_acc: + best_acc = acc + best_eval_name = eval_name + torch.save(checkpoint, args.save_dir + "backup/best.pt") + print(f"[Update Best Model] Epoch {epoch+1:2d} | Best Acc: {best_acc:.2f}%") + + ### 更新wandb摘要 + if args.use_wandb: + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_epoch"] = epoch + 1 + + +if __name__ == "__main__": + tlogger = timeLogger() + tlogger.print("Reading Config...") + args = get_args() + assert args.c != "", "Please provide config file (.yaml)" + load_yaml(args, args.c) + if not hasattr(args, "log_freq"): + args.log_freq = 10 # 默认每10个batch打印一次 + build_record_folder(args) + tlogger.print() + main(args, tlogger) diff --git a/.ipynb_checkpoints/vit_pim_main_sgd-checkpoint.py b/.ipynb_checkpoints/vit_pim_main_sgd-checkpoint.py new file mode 100644 index 0000000..1065083 --- /dev/null +++ b/.ipynb_checkpoints/vit_pim_main_sgd-checkpoint.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import load_yaml, build_record_folder, get_args +from utils.lr_schedule import cosine_decay, adjust_lr, get_lr +from eval import evaluate, cal_train_metrics + +warnings.simplefilter("ignore") + + +def eval_freq_schedule(args, epoch: int): + """ + 根据当前训练的 epoch 调整验证频率(eval_freq)。 + 在训练接近尾声时更频繁地进行验证,以便更好地监控模型性能。 + + 参数: + args: 包含训练配置参数的对象,其中包括 eval_freq 和 max_epochs。 + epoch: 当前训练的 epoch 数。 + """ + # 如果当前 epoch 大于等于最大训练轮次的 95%,则将验证频率设为 1(每个 epoch 都验证) + if 
epoch >= args.max_epochs * 0.95: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 90% 但小于 95%,同样将验证频率设为 1 + elif epoch >= args.max_epochs * 0.9: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 80% 但小于 90%,将验证频率设为 2(每两个 epoch 验证一次) + elif epoch >= args.max_epochs * 0.8: + args.eval_freq = 2 + + +def set_environment(args, tlogger): + """ + 设置训练环境,包括设备、数据加载器、模型、优化器等。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + + 返回: + train_loader: 训练数据加载器。 + val_loader: 验证数据加载器。 + model: 构建并初始化的模型。 + optimizer: 优化器(如果仅评估则为None)。 + schedule: 学习率调度器(如果仅评估则为None)。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + start_epoch: 训练开始的epoch数(如果有预训练模型,则从该模型的epoch开始)。 + """ + + print("Setting Environment...") + + # 设置训练设备:如果CUDA可用则使用GPU,否则使用CPU + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### = = = = Dataset and Data Loader = = = = + # 构建训练和验证数据加载器 + tlogger.print("Building Dataloader....") + + train_loader, val_loader = build_loader(args) + + # 检查是否成功构建了数据加载器 + if train_loader is None and val_loader is None: + raise ValueError("Find nothing to train or evaluate.") + + # 打印训练集信息 + if train_loader is not None: + print(" Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader))) + else: + # raise ValueError("Build train loader fail, please provide legal path.") + print(" Train Samples: 0 ~~~~~> [Only Evaluation]") + + # 打印验证集信息 + if val_loader is not None: + print(" Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader))) + else: + print(" Validation Samples: 0 ~~~~~> [Only Training]") + tlogger.print() + + ### = = = = Model = = = = + # 构建模型 + tlogger.print("Building Model....") + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner, + ) # about return_nodes, we use our default setting + + # 如果提供了预训练模型,则加载权重 + if args.pretrained is not None: + checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model']) + start_epoch = checkpoint['epoch'] + print(start_epoch) + else: + start_epoch = 0 + + # 将模型移动到指定设备 + model.to(args.device) + tlogger.print() + + """ + 如果你有多GPU设备,可以在单机多GPU情况下使用torch.nn.DataParallel, + 或者使用torch.nn.parallel.DistributedDataParallel实现多进程并行。 + 更多详情:https://pytorch.org/tutorials/beginner/dist_overview.html + """ + + # 如果没有训练数据加载器,只进行评估,返回部分对象 + if train_loader is None: + return train_loader, val_loader, model, None, None, None, None, start_epoch + + ### = = = = Optimizer = = = = + # 构建优化器 + tlogger.print("Building Optimizer....") + if args.optimizer == "SGD": + optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9, + weight_decay=args.wdecay) + elif args.optimizer == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr) + + # 如果有预训练模型,加载优化器状态 + if args.pretrained is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + tlogger.print() + + # 构建学习率调度器 + schedule = cosine_decay(args, len(train_loader)) + + # 如果使用混合精度训练(AMP),设置相关的组件 + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + amp_context = torch.cuda.amp.autocast + else: + scaler = None + amp_context = contextlib.nullcontext + + # 返回所有构建的组件 + return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch + + +def train(args, epoch, model, scaler, amp_context, optimizer, schedule, 
train_loader): + """ + 训练函数,在每个epoch中迭代训练数据并更新模型参数。 + + 参数: + args: 包含训练配置参数的对象。 + epoch: 当前训练的 epoch 数。 + model: 要训练的模型。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + optimizer: 优化器。 + schedule: 学习率调度器。 + train_loader: 训练数据加载器。 + """ + + # 清空优化器的梯度 + optimizer.zero_grad() + + # 获取总批次数,仅用于日志记录 + total_batchs = len(train_loader) + + # 定义训练进度显示点(0%, 10%, ..., 100%) + show_progress = [x / 10 for x in range(11)] + progress_i = 0 + + # 遍历训练数据加载器中的每个批次 + for batch_id, (ids, datas, labels) in enumerate(train_loader): + # 设置模型为训练模式 + model.train() + + """ = = = = adjust learning rate = = = = """ + # 计算当前迭代次数 + iterations = epoch * len(train_loader) + batch_id + # 调整学习率 + adjust_lr(iterations, optimizer, schedule) + + # 获取当前批次的样本数量 + batch_size = labels.size(0) + + """ = = = = forward and calculate loss = = = = """ + # 将数据和标签移动到指定设备 + datas, labels = datas.to(args.device), labels.to(args.device) + + # 使用AMP上下文进行前向传播(如果启用AMP) + with amp_context(): + """ + [Model Return] + FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1', 'comb_outs' + FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1' + FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting) + ~ --> return 'ori_out' + + [Retuen Tensor] + 'preds_0': logit has not been selected by Selector. + 'preds_1': logit has been selected by Selector. + 'comb_outs': The prediction of combiner. + """ + # 前向传播获取输出 + outs = model(datas) + + # 初始化总损失 + loss = 0. + + # 遍历模型输出的各个部分,计算相应的损失 + for name in outs: + # 处理选择器的输出 + if "select_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + if args.lambda_s != 0: + # 计算选择器损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, + labels.unsqueeze(1).repeat(1, S).flatten(0)) + loss += args.lambda_s * loss_s + else: + loss_s = 0.0 + + # 处理丢弃部分的输出 + elif "drop_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + + if args.lambda_n != 0: + # 计算负样本损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 + labels_0 = labels_0.to(args.device) + loss_n = nn.MSELoss()(n_preds, labels_0) + loss += args.lambda_n * loss_n + else: + loss_n = 0.0 + + # 处理FPN层的输出 + elif "layer" in name: + if not args.use_fpn: + raise ValueError("FPN not use here.") + if args.lambda_b != 0: + # 计算FPN基础损失 + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) + loss += args.lambda_b * loss_b + else: + loss_b = 0.0 + + # 处理组合器的输出 + elif "comb_outs" in name: + if not args.use_combiner: + raise ValueError("Combiner not use here.") + + if args.lambda_c != 0: + # 计算组合器损失 + loss_c = nn.CrossEntropyLoss()(outs[name], labels) + loss += args.lambda_c * loss_c + + # 处理原始输出 + elif "ori_out" in name: + # 计算原始输出损失 + loss_ori = F.cross_entropy(outs[name], labels) + loss += loss_ori + + # 对损失进行平均化处理 + loss /= args.update_freq + + """ = = = = calculate gradient = = = = """ + # 计算梯度(根据是否使用AMP选择不同的方式) + if args.use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + """ = = = = update model = = = = """ + # 更新模型参数(每隔update_freq个批次更新一次) + if (batch_id + 1) % args.update_freq == 0: + if args.use_amp: + # 
使用AMP更新模型 + scaler.step(optimizer) + scaler.update() # next batch + else: + # 正常更新模型 + optimizer.step() + # 清空梯度 + optimizer.zero_grad() + + """ log (MISC) """ + # 记录训练日志(如果启用wandb且达到记录频率) + if args.use_wandb and ((batch_id + 1) % args.log_freq == 0): + # 切换到评估模式进行日志记录 + model.eval() + msg = {} + msg['info/epoch'] = epoch + 1 + msg['info/lr'] = get_lr(optimizer) + # 计算并记录训练指标 + cal_train_metrics(args, msg, outs, labels, batch_size) + # 将日志信息发送到wandb + wandb.log(msg) + + # 显示训练进度 + train_progress = (batch_id + 1) / total_batchs + # print(train_progress, show_progress[progress_i]) + if train_progress > show_progress[progress_i]: + print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + +def main(args, tlogger): + """ + 主训练循环函数,负责整个训练和验证过程,包括模型保存(last.pt 和 best.pt)。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + """ + + # 调用set_environment函数设置训练环境,获取数据加载器、模型、优化器等 + train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, + tlogger) + + # 初始化最佳准确率和最佳评估名称 + best_acc = 0.0 + best_eval_name = "null" + + # 如果启用wandb,则初始化wandb项目并设置初始摘要信息 + if args.use_wandb: + wandb.init(entity=args.wandb_entity, + project=args.project_name, + name=args.exp_name, + config=args) + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = 0 + + # 开始训练循环,从start_epoch到max_epochs + for epoch in range(start_epoch, args.max_epochs): + + """ + 训练阶段 + """ + # 如果存在训练数据加载器,则进行训练 + if train_loader is not None: + tlogger.print("Start Training {} Epoch".format(epoch + 1)) + # 调用train函数进行一个epoch的训练 + train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader) + tlogger.print() + else: + # 如果没有训练数据加载器(仅评估模式),则调用eval_and_save进行评估并保存结果,然后退出循环 + from eval import eval_and_save + eval_and_save(args, model, val_loader) + break + + # 根据当前epoch调整验证频率 + eval_freq_schedule(args, epoch) + + # 准备要保存的模型检查点(处理多GPU情况) + model_to_save = model.module if hasattr(model, "module") else model + checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch} + # 保存最新的模型检查点 + torch.save(checkpoint, args.save_dir + "backup/last.pt") + + # 根据评估频率进行验证(每个epoch或每隔几个epoch) + if epoch == 0 or (epoch + 1) % args.eval_freq == 0: + """ + 验证阶段 + """ + acc = -1 + # 如果存在验证数据加载器,则进行验证 + if val_loader is not None: + tlogger.print("Start Evaluating {} Epoch".format(epoch + 1)) + # 调用evaluate函数进行验证,获取准确率等信息 + acc, eval_name, accs = evaluate(args, model, val_loader) + # 打印当前验证结果和历史最佳准确率 + tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc)) + tlogger.print() + + # 如果启用wandb,则记录验证指标 + if args.use_wandb: + wandb.log(accs) + + # 如果当前准确率优于历史最佳准确率,则更新最佳准确率并保存最佳模型 + if acc > best_acc: + best_acc = acc + best_eval_name = eval_name + torch.save(checkpoint, args.save_dir + "backup/best.pt") + # 如果启用wandb,则更新wandb中的最佳指标摘要 + if args.use_wandb: + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = epoch + 1 + + +if __name__ == "__main__": + # 创建一个时间记录器实例,用于记录和打印时间相关的日志 + tlogger = timeLogger() + + # 打印正在读取配置文件的信息 + tlogger.print("Reading Config...") + + # 获取命令行参数,这些参数包括配置文件路径等 + args = get_args() + + # 断言确保提供了配置文件(.yaml格式),如果没有提供则抛出错误信息 + assert args.c != "", "Please provide config file (.yaml)" + + # 加载指定的YAML配置文件,并将配置内容存入args对象中 + load_yaml(args, args.c) + + # 根据配置创建记录文件夹,用于保存训练过程中的日志、模型等文件 + build_record_folder(args) + + # 打印空行,起到分隔日志的作用 + 
tlogger.print() + + # 调用main函数开始执行主要的训练或评估流程,传入解析后的参数和时间记录器 + main(args, tlogger) \ No newline at end of file diff --git a/SwinT_pim_main_sgd.py b/SwinT_pim_main_sgd.py new file mode 100644 index 0000000..1065083 --- /dev/null +++ b/SwinT_pim_main_sgd.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import load_yaml, build_record_folder, get_args +from utils.lr_schedule import cosine_decay, adjust_lr, get_lr +from eval import evaluate, cal_train_metrics + +warnings.simplefilter("ignore") + + +def eval_freq_schedule(args, epoch: int): + """ + 根据当前训练的 epoch 调整验证频率(eval_freq)。 + 在训练接近尾声时更频繁地进行验证,以便更好地监控模型性能。 + + 参数: + args: 包含训练配置参数的对象,其中包括 eval_freq 和 max_epochs。 + epoch: 当前训练的 epoch 数。 + """ + # 如果当前 epoch 大于等于最大训练轮次的 95%,则将验证频率设为 1(每个 epoch 都验证) + if epoch >= args.max_epochs * 0.95: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 90% 但小于 95%,同样将验证频率设为 1 + elif epoch >= args.max_epochs * 0.9: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 80% 但小于 90%,将验证频率设为 2(每两个 epoch 验证一次) + elif epoch >= args.max_epochs * 0.8: + args.eval_freq = 2 + + +def set_environment(args, tlogger): + """ + 设置训练环境,包括设备、数据加载器、模型、优化器等。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + + 返回: + train_loader: 训练数据加载器。 + val_loader: 验证数据加载器。 + model: 构建并初始化的模型。 + optimizer: 优化器(如果仅评估则为None)。 + schedule: 学习率调度器(如果仅评估则为None)。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + start_epoch: 训练开始的epoch数(如果有预训练模型,则从该模型的epoch开始)。 + """ + + print("Setting Environment...") + + # 设置训练设备:如果CUDA可用则使用GPU,否则使用CPU + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### = = = = Dataset and Data Loader = = = = + # 构建训练和验证数据加载器 + tlogger.print("Building Dataloader....") + + train_loader, val_loader = build_loader(args) + + # 检查是否成功构建了数据加载器 + if train_loader is None and val_loader is None: + raise ValueError("Find nothing to train or evaluate.") + + # 打印训练集信息 + if train_loader is not None: + print(" Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader))) + else: + # raise ValueError("Build train loader fail, please provide legal path.") + print(" Train Samples: 0 ~~~~~> [Only Evaluation]") + + # 打印验证集信息 + if val_loader is not None: + print(" Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader))) + else: + print(" Validation Samples: 0 ~~~~~> [Only Training]") + tlogger.print() + + ### = = = = Model = = = = + # 构建模型 + tlogger.print("Building Model....") + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner, + ) # about return_nodes, we use our default setting + + # 如果提供了预训练模型,则加载权重 + if args.pretrained is not None: + checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu')) + model.load_state_dict(checkpoint['model']) + start_epoch = checkpoint['epoch'] + print(start_epoch) + else: + start_epoch = 0 + + # 将模型移动到指定设备 + model.to(args.device) + tlogger.print() + + """ + 如果你有多GPU设备,可以在单机多GPU情况下使用torch.nn.DataParallel, + 或者使用torch.nn.parallel.DistributedDataParallel实现多进程并行。 + 更多详情:https://pytorch.org/tutorials/beginner/dist_overview.html + """ + + # 如果没有训练数据加载器,只进行评估,返回部分对象 + if train_loader is 
None: + return train_loader, val_loader, model, None, None, None, None, start_epoch + + ### = = = = Optimizer = = = = + # 构建优化器 + tlogger.print("Building Optimizer....") + if args.optimizer == "SGD": + optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9, + weight_decay=args.wdecay) + elif args.optimizer == "AdamW": + optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr) + + # 如果有预训练模型,加载优化器状态 + if args.pretrained is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + tlogger.print() + + # 构建学习率调度器 + schedule = cosine_decay(args, len(train_loader)) + + # 如果使用混合精度训练(AMP),设置相关的组件 + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + amp_context = torch.cuda.amp.autocast + else: + scaler = None + amp_context = contextlib.nullcontext + + # 返回所有构建的组件 + return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch + + +def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader): + """ + 训练函数,在每个epoch中迭代训练数据并更新模型参数。 + + 参数: + args: 包含训练配置参数的对象。 + epoch: 当前训练的 epoch 数。 + model: 要训练的模型。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + optimizer: 优化器。 + schedule: 学习率调度器。 + train_loader: 训练数据加载器。 + """ + + # 清空优化器的梯度 + optimizer.zero_grad() + + # 获取总批次数,仅用于日志记录 + total_batchs = len(train_loader) + + # 定义训练进度显示点(0%, 10%, ..., 100%) + show_progress = [x / 10 for x in range(11)] + progress_i = 0 + + # 遍历训练数据加载器中的每个批次 + for batch_id, (ids, datas, labels) in enumerate(train_loader): + # 设置模型为训练模式 + model.train() + + """ = = = = adjust learning rate = = = = """ + # 计算当前迭代次数 + iterations = epoch * len(train_loader) + batch_id + # 调整学习率 + adjust_lr(iterations, optimizer, schedule) + + # 获取当前批次的样本数量 + batch_size = labels.size(0) + + """ = = = = forward and calculate loss = = = = """ + # 将数据和标签移动到指定设备 + datas, labels = datas.to(args.device), labels.to(args.device) + + # 使用AMP上下文进行前向传播(如果启用AMP) + with amp_context(): + """ + [Model Return] + FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1', 'comb_outs' + FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ...(depend on your setting) + 'preds_0', 'preds_1' + FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depend on your setting) + ~ --> return 'ori_out' + + [Retuen Tensor] + 'preds_0': logit has not been selected by Selector. + 'preds_1': logit has been selected by Selector. + 'comb_outs': The prediction of combiner. + """ + # 前向传播获取输出 + outs = model(datas) + + # 初始化总损失 + loss = 0. 
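+            # A note on the "drop_" targets built below: nn.Tanh squashes the dropped-
+            # region logits into (-1, 1), and the target tensor labels_0 is all -1
+            # (torch.zeros(...) - 1), so the MSE term pushes every class score of a
+            # dropped region toward its minimum, i.e. "confidently none of the classes".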
+ + # 遍历模型输出的各个部分,计算相应的损失 + for name in outs: + # 处理选择器的输出 + if "select_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + if args.lambda_s != 0: + # 计算选择器损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + loss_s = nn.CrossEntropyLoss()(logit, + labels.unsqueeze(1).repeat(1, S).flatten(0)) + loss += args.lambda_s * loss_s + else: + loss_s = 0.0 + + # 处理丢弃部分的输出 + elif "drop_" in name: + if not args.use_selection: + raise ValueError("Selector not use here.") + + if args.lambda_n != 0: + # 计算负样本损失 + S = outs[name].size(1) + logit = outs[name].view(-1, args.num_classes).contiguous() + n_preds = nn.Tanh()(logit) + labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1 + labels_0 = labels_0.to(args.device) + loss_n = nn.MSELoss()(n_preds, labels_0) + loss += args.lambda_n * loss_n + else: + loss_n = 0.0 + + # 处理FPN层的输出 + elif "layer" in name: + if not args.use_fpn: + raise ValueError("FPN not use here.") + if args.lambda_b != 0: + # 计算FPN基础损失 + ### here using 'layer1'~'layer4' is default setting, you can change to your own + loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels) + loss += args.lambda_b * loss_b + else: + loss_b = 0.0 + + # 处理组合器的输出 + elif "comb_outs" in name: + if not args.use_combiner: + raise ValueError("Combiner not use here.") + + if args.lambda_c != 0: + # 计算组合器损失 + loss_c = nn.CrossEntropyLoss()(outs[name], labels) + loss += args.lambda_c * loss_c + + # 处理原始输出 + elif "ori_out" in name: + # 计算原始输出损失 + loss_ori = F.cross_entropy(outs[name], labels) + loss += loss_ori + + # 对损失进行平均化处理 + loss /= args.update_freq + + """ = = = = calculate gradient = = = = """ + # 计算梯度(根据是否使用AMP选择不同的方式) + if args.use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + """ = = = = update model = = = = """ + # 更新模型参数(每隔update_freq个批次更新一次) + if (batch_id + 1) % args.update_freq == 0: + if args.use_amp: + # 使用AMP更新模型 + scaler.step(optimizer) + scaler.update() # next batch + else: + # 正常更新模型 + optimizer.step() + # 清空梯度 + optimizer.zero_grad() + + """ log (MISC) """ + # 记录训练日志(如果启用wandb且达到记录频率) + if args.use_wandb and ((batch_id + 1) % args.log_freq == 0): + # 切换到评估模式进行日志记录 + model.eval() + msg = {} + msg['info/epoch'] = epoch + 1 + msg['info/lr'] = get_lr(optimizer) + # 计算并记录训练指标 + cal_train_metrics(args, msg, outs, labels, batch_size) + # 将日志信息发送到wandb + wandb.log(msg) + + # 显示训练进度 + train_progress = (batch_id + 1) / total_batchs + # print(train_progress, show_progress[progress_i]) + if train_progress > show_progress[progress_i]: + print(".." 
+ str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True) + progress_i += 1 + +def main(args, tlogger): + """ + 主训练循环函数,负责整个训练和验证过程,包括模型保存(last.pt 和 best.pt)。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + """ + + # 调用set_environment函数设置训练环境,获取数据加载器、模型、优化器等 + train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, + tlogger) + + # 初始化最佳准确率和最佳评估名称 + best_acc = 0.0 + best_eval_name = "null" + + # 如果启用wandb,则初始化wandb项目并设置初始摘要信息 + if args.use_wandb: + wandb.init(entity=args.wandb_entity, + project=args.project_name, + name=args.exp_name, + config=args) + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = 0 + + # 开始训练循环,从start_epoch到max_epochs + for epoch in range(start_epoch, args.max_epochs): + + """ + 训练阶段 + """ + # 如果存在训练数据加载器,则进行训练 + if train_loader is not None: + tlogger.print("Start Training {} Epoch".format(epoch + 1)) + # 调用train函数进行一个epoch的训练 + train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader) + tlogger.print() + else: + # 如果没有训练数据加载器(仅评估模式),则调用eval_and_save进行评估并保存结果,然后退出循环 + from eval import eval_and_save + eval_and_save(args, model, val_loader) + break + + # 根据当前epoch调整验证频率 + eval_freq_schedule(args, epoch) + + # 准备要保存的模型检查点(处理多GPU情况) + model_to_save = model.module if hasattr(model, "module") else model + checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch} + # 保存最新的模型检查点 + torch.save(checkpoint, args.save_dir + "backup/last.pt") + + # 根据评估频率进行验证(每个epoch或每隔几个epoch) + if epoch == 0 or (epoch + 1) % args.eval_freq == 0: + """ + 验证阶段 + """ + acc = -1 + # 如果存在验证数据加载器,则进行验证 + if val_loader is not None: + tlogger.print("Start Evaluating {} Epoch".format(epoch + 1)) + # 调用evaluate函数进行验证,获取准确率等信息 + acc, eval_name, accs = evaluate(args, model, val_loader) + # 打印当前验证结果和历史最佳准确率 + tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc)) + tlogger.print() + + # 如果启用wandb,则记录验证指标 + if args.use_wandb: + wandb.log(accs) + + # 如果当前准确率优于历史最佳准确率,则更新最佳准确率并保存最佳模型 + if acc > best_acc: + best_acc = acc + best_eval_name = eval_name + torch.save(checkpoint, args.save_dir + "backup/best.pt") + # 如果启用wandb,则更新wandb中的最佳指标摘要 + if args.use_wandb: + wandb.run.summary["best_acc"] = best_acc + wandb.run.summary["best_eval_name"] = best_eval_name + wandb.run.summary["best_epoch"] = epoch + 1 + + +if __name__ == "__main__": + # 创建一个时间记录器实例,用于记录和打印时间相关的日志 + tlogger = timeLogger() + + # 打印正在读取配置文件的信息 + tlogger.print("Reading Config...") + + # 获取命令行参数,这些参数包括配置文件路径等 + args = get_args() + + # 断言确保提供了配置文件(.yaml格式),如果没有提供则抛出错误信息 + assert args.c != "", "Please provide config file (.yaml)" + + # 加载指定的YAML配置文件,并将配置内容存入args对象中 + load_yaml(args, args.c) + + # 根据配置创建记录文件夹,用于保存训练过程中的日志、模型等文件 + build_record_folder(args) + + # 打印空行,起到分隔日志的作用 + tlogger.print() + + # 调用main函数开始执行主要的训练或评估流程,传入解析后的参数和时间记录器 + main(args, tlogger) \ No newline at end of file diff --git a/__pycache__/eval.cpython-38.pyc b/__pycache__/eval.cpython-38.pyc new file mode 100644 index 0000000..53b1785 Binary files /dev/null and b/__pycache__/eval.cpython-38.pyc differ diff --git a/configs/AICUP_SwinT.yaml b/configs/AICUP_SwinT.yaml deleted file mode 100644 index 14284f4..0000000 --- a/configs/AICUP_SwinT.yaml +++ /dev/null @@ -1,33 +0,0 @@ -project_name: AICUP22-Flowers -exp_name: T01 -use_wandb: True -wandb_entity: poyung -train_root: ./aicup_data -val_root: ~ -data_size: 384 -num_workers: 2 -batch_size: 64 
-model_name: swin-t -optimizer: SGD -max_lr: 0.0005 -wdecay: 0.0005 -max_epochs: 50 -warmup_batchs: 0 -use_amp: True -use_fpn: True -fpn_size: 512 -use_selection: True -num_classes: 10 -num_selects: - layer1: 32 - layer2: 32 - layer3: 32 - layer4: 32 -use_combiner: True -lambda_b: 0.5 -lambda_s: 0.0 -lambda_n: 5.0 -lambda_c: 1.0 -update_freq: 1 -log_freq: 100 -eval_freq: 10 \ No newline at end of file diff --git a/configs/AIC_PIM_SwinT.yaml b/configs/AIC_PIM_SwinT.yaml new file mode 100644 index 0000000..ce4dc6a --- /dev/null +++ b/configs/AIC_PIM_SwinT.yaml @@ -0,0 +1,37 @@ +project_name: aic-pim-swinT +exp_name: swinT_sgd +use_wandb: False +wandb_entity: professional_team +train_root: /root/autodl-tmp/webfg400_train/train/ +val_root: ~ +val_split: 0.2 +test_root: /root/autodl-tmp/RemoteSMUAIC/data/test_images/test_images +data_size: 384 +num_workers: 8 +batch_size: 16 +model_name: swin-t +pretrained: ~ +optimizer: SGD +max_lr: 0.0005 +wdecay: 0.0005 +max_epochs: 70 +warmup_batchs: 800 +use_amp: True +use_fpn: True +fpn_size: 512 +use_selection: True +num_classes: 200 +num_selects: + layer1: 256 + layer2: 128 + layer3: 64 + layer4: 32 +use_combiner: True +lambda_b: 0.5 +lambda_s: 0.0 +lambda_n: 5.0 +lambda_c: 1.0 +update_freq: 2 +log_freq: 100 +eval_freq: 1 +save_dir: \ No newline at end of file diff --git a/configs/SMUAIC.yaml b/configs/SMUAIC.yaml new file mode 100644 index 0000000..da64341 --- /dev/null +++ b/configs/SMUAIC.yaml @@ -0,0 +1,37 @@ +project_name: SMUAIC25 +exp_name: PrTrTst +use_wandb: False +wandb_entity: professional_team +train_root: /root/autodl-tmp/RemoteSMUAIC/data/train_images/train_images +val_root: ~ +val_split: 0.1 +test_root: /root/autodl-tmp/RemoteSMUAIC/data/test_images/test_images +data_size: 224 +num_workers: 2 +batch_size: 64 +model_name: resnet50 +pretrained: /root/autodl-tmp/records/SMUAIC25/PrTrTst/backup/best.pt +optimizer: SGD +max_lr: 0.0005 +wdecay: 0.0005 +max_epochs: 70 +warmup_batchs: 0 +use_amp: True +use_fpn: True +fpn_size: 512 +use_selection: True +num_classes: 400 +num_selects: + layer1: 32 + layer2: 32 + layer3: 32 + layer4: 32 +use_combiner: True +lambda_b: 0.5 +lambda_s: 0.0 +lambda_n: 5.0 +lambda_c: 1.0 +update_freq: 1 +log_freq: 100 +eval_freq: 1 +save_dir: /root/autodl-tmp/FGVC-PIM-master/records/ \ No newline at end of file diff --git a/data/.ipynb_checkpoints/dataset-checkpoint.py b/data/.ipynb_checkpoints/dataset-checkpoint.py new file mode 100644 index 0000000..de96aae --- /dev/null +++ b/data/.ipynb_checkpoints/dataset-checkpoint.py @@ -0,0 +1,160 @@ + +import os +import numpy as np +import cv2 +import torch +import torchvision.transforms as transforms +from PIL import Image +import copy +import torch +import warnings +from PIL import ImageFile +warnings.filterwarnings("ignore", category=UserWarning, module="PIL.Image") +ImageFile.LOAD_TRUNCATED_IMAGES = True +from .randaug import RandAugment + + +def build_loader(args): + train_set, train_loader, val_set, val_loader = None, None, None, None + + # 1. 加载完整的训练集(包含待划分的所有样本) + if args.train_root is not None: + full_train_set = ImageDataset(istrain=True, root=args.train_root, data_size=args.data_size, return_index=True) + + # 2. 
Decide whether a random validation split is needed
+    if args.val_root is not None:
+        # Case 1: a separate val folder exists, load it directly
+        train_set = full_train_set
+        val_set = ImageDataset(istrain=False, root=args.val_root, data_size=args.data_size, return_index=True)
+    else:
+        # Case 2: no separate val folder, randomly split the training set
+        # (args.val_split must be defined in the yaml, e.g. 0.2 for a 20% split)
+        val_size = int(len(full_train_set) * args.val_split)
+        train_size = len(full_train_set) - val_size
+        # Split with random_split; a fixed seed keeps the split reproducible
+        train_set, val_set = torch.utils.data.random_split(
+            full_train_set, [train_size, val_size],
+            generator=torch.Generator().manual_seed(42)  # fixed seed, reproducible
+        )
+        # Key point: the validation subset must use test/validation-mode augmentation
+        # (no random crop or flip). random_split returns two Subsets sharing one
+        # underlying dataset object, so give the validation subset its own copy
+        # before mutating it; otherwise the training subset would silently lose
+        # its augmentation as well.
+        val_set.dataset = copy.deepcopy(val_set.dataset)
+        val_set.dataset.istrain = False  # bookkeeping only; ImageDataset reads istrain just in __init__
+        val_set.dataset.transforms = transforms.Compose([  # reassign the validation pipeline
+            transforms.Resize((510, 510), Image.BILINEAR),
+            transforms.CenterCrop((args.data_size, args.data_size)),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    # 3. Build the DataLoaders
+    if train_set is not None:
+        train_loader = torch.utils.data.DataLoader(
+            train_set, num_workers=args.num_workers, shuffle=True, batch_size=args.batch_size
+        )
+    if val_set is not None:
+        val_loader = torch.utils.data.DataLoader(
+            val_set, num_workers=1, shuffle=False, batch_size=args.batch_size  # no shuffling for validation
+        )
+
+    return train_loader, val_loader
+
+def get_dataset(args):
+    if args.train_root is not None:
+        train_set = ImageDataset(istrain=True, root=args.train_root, data_size=args.data_size, return_index=True)
+        return train_set
+    return None
+
+# Robust image loader (kept outside the ImageDataset class)
+_CORRUPTED_WARNED = set()
+def robust_pil_loader(path, data_size=384):
+    # data_size now defaults to 384 (the size used in the configs in this diff)
+    # because the original body referenced an undefined data_size; the exact value
+    # is not critical, since the transforms resize the placeholder anyway.
+    try:
+        with Image.open(path) as img:
+            if img.mode == 'P':
+                img = img.convert('RGBA')
+            if img.mode != 'RGB':
+                img = img.convert('RGB')
+            img.load()
+            return img.copy()
+    except Exception as e:
+        if path not in _CORRUPTED_WARNED:
+            print(f"Warning: Skipping corrupted image {path}. Error: {e}")
+            _CORRUPTED_WARNED.add(path)
+        # Return a grey placeholder image (same size as data_size)
+        return Image.new('RGB', (data_size, data_size), color=(128, 128, 128))
+
+class ImageDataset(torch.utils.data.Dataset):
+
+    def __init__(self,
+                 istrain: bool,
+                 root: str,
+                 data_size: int,
+                 return_index: bool = False):
+        # notice that:
+        # sub_data_size means the sub-image's width and height.
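+        # Usage sketch (hypothetical path; mirrors how build_loader above constructs
+        # this class):
+        #     train_set = ImageDataset(istrain=True, root="./data/train/",
+        #                              data_size=384, return_index=True)
+        #     idx, img, label = train_set[0]  # with return_index=True, __getitem__
+        #                                     # prepends the sample index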
+ """ basic information """ + self.root = root + self.data_size = data_size + self.return_index = return_index + + """ declare data augmentation """ + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + + # 448:600 + # 384:510 + # 768: + if istrain: + # transforms.RandomApply([RandAugment(n=2, m=3, img_size=data_size)], p=0.1) + # RandAugment(n=2, m=3, img_size=sub_data_size) + self.transforms = transforms.Compose([ + transforms.Resize((510, 510), Image.BILINEAR), + transforms.RandomCrop((data_size, data_size)), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5))], p=0.1), + transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.1), + transforms.ToTensor(), + normalize + ]) + else: + self.transforms = transforms.Compose([ + transforms.Resize((510, 510), Image.BILINEAR), + transforms.CenterCrop((data_size, data_size)), + transforms.ToTensor(), + normalize + ]) + + """ read all data information """ + self.data_infos = self.getDataInfo(root) + + def getDataInfo(self, root): + data_infos = [] + folders = os.listdir(root) + folders.sort() # 保证类别ID稳定 + print("[dataset] class number:", len(folders)) # 应输出400 + for class_id, folder in enumerate(folders): + # 用os.path.join拼接文件夹路径,自动处理斜杠 + folder_path = os.path.join(root, folder) + # 跳过非文件夹(避免误将文件当作类别文件夹) + if not os.path.isdir(folder_path): + continue + files = os.listdir(folder_path) + for file in files: + # 拼接图像完整路径 + data_path = os.path.join(folder_path, file) + data_infos.append({"path": data_path, "label": class_id}) + return data_infos + + def __len__(self): + return len(self.data_infos) + + def __getitem__(self, index): + image_path = self.data_infos[index]["path"] + label = self.data_infos[index]["label"] + + # 用健壮加载器读取图像(替换cv2.imread) + img = robust_pil_loader(image_path) # 自动处理损坏图像,返回RGB格式PIL图像 + img = self.transforms(img) # 直接应用增强(无需BGR→RGB转换,PIL默认RGB) + + if self.return_index: + return index, img, label + return img, label \ No newline at end of file diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc index 2a94b30..47b3f32 100644 Binary files a/data/__pycache__/__init__.cpython-38.pyc and b/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/data/__pycache__/dataset.cpython-38.pyc b/data/__pycache__/dataset.cpython-38.pyc index 55fb1c5..6a3027f 100644 Binary files a/data/__pycache__/dataset.cpython-38.pyc and b/data/__pycache__/dataset.cpython-38.pyc differ diff --git a/data/__pycache__/randaug.cpython-38.pyc b/data/__pycache__/randaug.cpython-38.pyc index ba85054..f349c9a 100644 Binary files a/data/__pycache__/randaug.cpython-38.pyc and b/data/__pycache__/randaug.cpython-38.pyc differ diff --git a/data/dataset.py b/data/dataset.py index c456617..de96aae 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,3 +1,4 @@ + import os import numpy as np import cv2 @@ -6,20 +7,52 @@ from PIL import Image import copy import torch - +import warnings +from PIL import ImageFile +warnings.filterwarnings("ignore", category=UserWarning, module="PIL.Image") +ImageFile.LOAD_TRUNCATED_IMAGES = True from .randaug import RandAugment def build_loader(args): - train_set, train_loader = None, None + train_set, train_loader, val_set, val_loader = None, None, None, None + + # 1. 
加载完整的训练集(包含待划分的所有样本) if args.train_root is not None: - train_set = ImageDataset(istrain=True, root=args.train_root, data_size=args.data_size, return_index=True) - train_loader = torch.utils.data.DataLoader(train_set, num_workers=args.num_workers, shuffle=True, batch_size=args.batch_size) + full_train_set = ImageDataset(istrain=True, root=args.train_root, data_size=args.data_size, return_index=True) - val_set, val_loader = None, None - if args.val_root is not None: - val_set = ImageDataset(istrain=False, root=args.val_root, data_size=args.data_size, return_index=True) - val_loader = torch.utils.data.DataLoader(val_set, num_workers=1, shuffle=True, batch_size=args.batch_size) + # 2. 判断是否需要随机划分验证集 + if args.val_root is not None: + # 情况1:有独立的val文件夹,直接加载 + train_set = full_train_set + val_set = ImageDataset(istrain=False, root=args.val_root, data_size=args.data_size, return_index=True) + else: + # 情况2:无独立val文件夹,从训练集中随机划分(默认20%为验证集,可在yaml中加val_split参数) + val_size = int(len(full_train_set) * args.val_split) # args.val_split需在yaml中定义,如0.2 + train_size = len(full_train_set) - val_size + # 用random_split划分,固定种子确保每次划分一致 + train_set, val_set = torch.utils.data.random_split( + full_train_set, [train_size, val_size], + generator=torch.Generator().manual_seed(42) # 固定种子,可复现 + ) + # 关键:验证集需用“测试/验证模式”的增强(无随机裁剪、翻转) + val_set.dataset.istrain = False # 把验证集的增强模式改成istrain=False + val_set.dataset.transforms = transforms.Compose([ # 重新赋值验证集的增强逻辑 + transforms.Resize((510, 510), Image.BILINEAR), + transforms.CenterCrop((args.data_size, args.data_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + # 3. 构建DataLoader + if train_set is not None: + train_loader = torch.utils.data.DataLoader( + train_set, num_workers=args.num_workers, shuffle=True, batch_size=args.batch_size + ) + if val_set is not None: + val_loader = torch.utils.data.DataLoader( + val_set, num_workers=1, shuffle=False, batch_size=args.batch_size # 验证集不shuffle + ) return train_loader, val_loader @@ -29,6 +62,23 @@ def get_dataset(args): return train_set return None +# 定义健壮加载器(放在ImageDataset类外) +_CORRUPTED_WARNED = set() +def robust_pil_loader(path): + try: + with Image.open(path) as img: + if img.mode == 'P': + img = img.convert('RGBA') + if img.mode != 'RGB': + img = img.convert('RGB') + img.load() + return img.copy() + except Exception as e: + if path not in _CORRUPTED_WARNED: + print(f"Warning: Skipping corrupted image {path}. 
Error: {e}") + _CORRUPTED_WARNED.add(path) + # 返回灰色占位图(尺寸与data_size一致) + return Image.new('RGB', (data_size, data_size), color=(128, 128, 128)) class ImageDataset(torch.utils.data.Dataset): @@ -76,37 +126,35 @@ def __init__(self, """ read all data information """ self.data_infos = self.getDataInfo(root) - def getDataInfo(self, root): data_infos = [] folders = os.listdir(root) - folders.sort() # sort by alphabet - print("[dataset] class number:", len(folders)) + folders.sort() # 保证类别ID稳定 + print("[dataset] class number:", len(folders)) # 应输出400 for class_id, folder in enumerate(folders): - files = os.listdir(root+folder) + # 用os.path.join拼接文件夹路径,自动处理斜杠 + folder_path = os.path.join(root, folder) + # 跳过非文件夹(避免误将文件当作类别文件夹) + if not os.path.isdir(folder_path): + continue + files = os.listdir(folder_path) for file in files: - data_path = root+folder+"/"+file - data_infos.append({"path":data_path, "label":class_id}) + # 拼接图像完整路径 + data_path = os.path.join(folder_path, file) + data_infos.append({"path": data_path, "label": class_id}) return data_infos def __len__(self): return len(self.data_infos) def __getitem__(self, index): - # get data information. image_path = self.data_infos[index]["path"] label = self.data_infos[index]["label"] - # read image by opencv. - img = cv2.imread(image_path) - img = img[:, :, ::-1] # BGR to RGB. - - # to PIL.Image - img = Image.fromarray(img) - img = self.transforms(img) - + + # 用健壮加载器读取图像(替换cv2.imread) + img = robust_pil_loader(image_path) # 自动处理损坏图像,返回RGB格式PIL图像 + img = self.transforms(img) # 直接应用增强(无需BGR→RGB转换,PIL默认RGB) + if self.return_index: - # return index, img, sub_imgs, label, sub_boundarys return index, img, label - - # return img, sub_imgs, label, sub_boundarys - return img, label + return img, label \ No newline at end of file diff --git a/data/dataset_guanfang.py b/data/dataset_guanfang.py new file mode 100644 index 0000000..c456617 --- /dev/null +++ b/data/dataset_guanfang.py @@ -0,0 +1,112 @@ +import os +import numpy as np +import cv2 +import torch +import torchvision.transforms as transforms +from PIL import Image +import copy +import torch + +from .randaug import RandAugment + + +def build_loader(args): + train_set, train_loader = None, None + if args.train_root is not None: + train_set = ImageDataset(istrain=True, root=args.train_root, data_size=args.data_size, return_index=True) + train_loader = torch.utils.data.DataLoader(train_set, num_workers=args.num_workers, shuffle=True, batch_size=args.batch_size) + + val_set, val_loader = None, None + if args.val_root is not None: + val_set = ImageDataset(istrain=False, root=args.val_root, data_size=args.data_size, return_index=True) + val_loader = torch.utils.data.DataLoader(val_set, num_workers=1, shuffle=True, batch_size=args.batch_size) + + return train_loader, val_loader + +def get_dataset(args): + if args.train_root is not None: + train_set = ImageDataset(istrain=True, root=args.train_root, data_size=args.data_size, return_index=True) + return train_set + return None + + +class ImageDataset(torch.utils.data.Dataset): + + def __init__(self, + istrain: bool, + root: str, + data_size: int, + return_index: bool = False): + # notice that: + # sub_data_size mean sub-image's width and height. 
+ """ basic information """ + self.root = root + self.data_size = data_size + self.return_index = return_index + + """ declare data augmentation """ + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + + # 448:600 + # 384:510 + # 768: + if istrain: + # transforms.RandomApply([RandAugment(n=2, m=3, img_size=data_size)], p=0.1) + # RandAugment(n=2, m=3, img_size=sub_data_size) + self.transforms = transforms.Compose([ + transforms.Resize((510, 510), Image.BILINEAR), + transforms.RandomCrop((data_size, data_size)), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5))], p=0.1), + transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.1), + transforms.ToTensor(), + normalize + ]) + else: + self.transforms = transforms.Compose([ + transforms.Resize((510, 510), Image.BILINEAR), + transforms.CenterCrop((data_size, data_size)), + transforms.ToTensor(), + normalize + ]) + + """ read all data information """ + self.data_infos = self.getDataInfo(root) + + + def getDataInfo(self, root): + data_infos = [] + folders = os.listdir(root) + folders.sort() # sort by alphabet + print("[dataset] class number:", len(folders)) + for class_id, folder in enumerate(folders): + files = os.listdir(root+folder) + for file in files: + data_path = root+folder+"/"+file + data_infos.append({"path":data_path, "label":class_id}) + return data_infos + + def __len__(self): + return len(self.data_infos) + + def __getitem__(self, index): + # get data information. + image_path = self.data_infos[index]["path"] + label = self.data_infos[index]["label"] + # read image by opencv. + img = cv2.imread(image_path) + img = img[:, :, ::-1] # BGR to RGB. + + # to PIL.Image + img = Image.fromarray(img) + img = self.transforms(img) + + if self.return_index: + # return index, img, sub_imgs, label, sub_boundarys + return index, img, label + + # return img, sub_imgs, label, sub_boundarys + return img, label diff --git a/eval.py b/eval.py index 66c5d5c..7896f93 100644 --- a/eval.py +++ b/eval.py @@ -13,7 +13,7 @@ def cal_train_metrics(args, msg: dict, outs: dict, labels: torch.Tensor, batch_s """ only present top-1 training accuracy """ - + total_loss = 0.0 if args.use_fpn: diff --git a/infer_test.py b/infer_test.py new file mode 100644 index 0000000..08166a9 --- /dev/null +++ b/infer_test.py @@ -0,0 +1,218 @@ +import os +import torch +import pandas as pd +import torchvision.transforms as transforms +from PIL import Image, ImageFile +import warnings +import contextlib +from argparse import Namespace + +# -------------------------- 1. 导入项目核心模块(需确保路径正确)-------------------------- +# 若脚本不在项目根目录,需添加项目路径到Python环境(例如:sys.path.append("../")) +from models.builder import MODEL_GETTER +from utils.config_utils import load_yaml, get_args +from utils.costom_logger import timeLogger + +# -------------------------- 2. 工具函数:健壮图像加载器(处理损坏/截断图像)-------------------------- +warnings.filterwarnings("ignore", category=UserWarning, module="PIL.Image") +ImageFile.LOAD_TRUNCATED_IMAGES = True +_CORRUPTED_WARNED = set() + + +def robust_pil_loader(path, data_size): + """健壮的图像加载器:处理损坏图像,返回RGB格式PIL图像""" + try: + with Image.open(path) as img: + if img.mode == 'P': + img = img.convert('RGBA') + if img.mode != 'RGB': + img = img.convert('RGB') + img.load() + return img.copy() + except Exception as e: + if path not in _CORRUPTED_WARNED: + print(f"Warning: Skipping corrupted image {path}. 
Error: {e}") + _CORRUPTED_WARNED.add(path) + # 损坏图像返回灰色占位图(尺寸与模型输入一致) + return Image.new('RGB', (data_size, data_size), color=(128, 128, 128)) + + +# -------------------------- 3. 测试集Dataset(适配无标签单文件夹图像)-------------------------- +class TestDataset(torch.utils.data.Dataset): + def __init__(self, test_root: str, data_size: int): + self.test_root = test_root + self.data_size = data_size + # 测试集数据增强(与验证集一致,无随机操作,确保结果稳定) + self.transform = transforms.Compose([ + transforms.Resize((510, 510), Image.BILINEAR), # 与训练时的Resize一致 + transforms.CenterCrop((data_size, data_size)), # 固定中心裁剪 + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet归一化(与训练一致) + std=[0.229, 0.224, 0.225]) + ]) + # 读取测试集所有图像的路径和文件名("id"为文件名) + self.image_infos = self._get_image_infos() + + def _get_image_infos(self): + """遍历测试集文件夹,收集图像路径和文件名""" + image_infos = [] + for filename in os.listdir(self.test_root): + file_path = os.path.join(self.test_root, filename) + if os.path.isfile(file_path): # 只保留文件(跳过子文件夹) + # "id"使用文件名(如"test_img_001.jpg"),后续直接存入CSV + image_infos.append({"path": file_path, "id": filename}) + print(f"[Test Dataset] Total images loaded: {len(image_infos)}") + return image_infos + + def __len__(self): + return len(self.image_infos) + + def __getitem__(self, index): + info = self.image_infos[index] + img_path = info["path"] + img_id = info["id"] # CSV的"id"列内容(文件名) + + # 加载图像并应用增强 + img = robust_pil_loader(img_path, self.data_size) + img_tensor = self.transform(img) + + return img_tensor, img_id # 返回:图像张量、图像ID(文件名) + + +# -------------------------- 4. 核心推理函数-------------------------- +def load_best_model(args, device): + """加载训练保存的best.pt模型""" + # 1. 构建与训练一致的模型结构 + model = MODEL_GETTER[args.model_name]( + use_fpn=args.use_fpn, + fpn_size=args.fpn_size, + use_selection=args.use_selection, + num_classes=args.num_classes, + num_selects=args.num_selects, + use_combiner=args.use_combiner + ) + # 2. 加载best.pt权重(路径:save_dir/backup/best.pt) + best_model_path = os.path.join(args.save_dir, "backup", "best.pt") +# best_model_path = os.path.join(args.save_dir, "backup", "last.pt") + if not os.path.exists(best_model_path): + raise FileNotFoundError(f"Best model not found! Path: {best_model_path}") + + checkpoint = torch.load(best_model_path, map_location=device) + model.load_state_dict(checkpoint["model"]) # 加载模型权重 + model.to(device) + model.eval() # 设为评估模式(关闭Dropout、BatchNorm固定) + print(f"Successfully loaded best model from: {best_model_path}") + return model + + +def get_class_names(train_root): + """获取类别名列表(与训练时的class_id对应,确保预测类别名正确)""" + # 训练时类别ID按文件夹排序生成,此处需保持一致 + class_names = sorted(os.listdir(train_root)) + print(f"[Class Info] Total classes: {len(class_names)}") + return class_names + + +def infer_test_set(args, device, tlogger): + """批量推理测试集,保存结果到CSV""" + # 1. 加载测试集 + tlogger.print("Loading test dataset...") + test_dataset = TestDataset( + test_root=args.test_root, + data_size=args.data_size + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=args.batch_size, # 从配置文件取批量大小(适配GPU显存) + shuffle=False, # 测试集无需打乱 + num_workers=args.num_workers, # 从配置文件取线程数 + pin_memory=True # 加速数据传输到GPU + ) + + # 2. 加载best模型和类别名 + model = load_best_model(args, device) + class_names = get_class_names(args.train_root) # 从训练集路径获取类别名(确保ID对应) + + # 3. 初始化AMP上下文(与训练一致,若未启用则用nullcontext) + if args.use_amp: + amp_context = torch.cuda.amp.autocast + else: + amp_context = contextlib.nullcontext + + # 4. 
+    # 4. Batch inference
+    tlogger.print("Start inferring test set...")
+    infer_results = []  # final results: [{"id": "...", "class": "..."}]
+    total_batches = len(test_loader)
+
+    with torch.no_grad():  # no gradients (saves memory and speeds up inference)
+        for batch_idx, (img_tensors, img_ids) in enumerate(test_loader):
+            # Move the data to the device (GPU/CPU)
+            img_tensors = img_tensors.to(device)
+
+            # Forward pass (get the model outputs)
+            with amp_context():
+                outs = model(img_tensors)
+
+            # Extract the final prediction (for the PIM model, prefer comb_outs; otherwise ori_out)
+            if "comb_outs" in outs:
+                pred_logits = outs["comb_outs"]  # combiner output (the main prediction during training)
+            elif "ori_out" in outs:
+                pred_logits = outs["ori_out"]    # raw backbone output (no PIM modules)
+            else:
+                raise KeyError("Model output has no 'comb_outs' or 'ori_out'! Check model structure.")
+
+            # Predicted class IDs (index of the max logit)
+            pred_class_ids = torch.argmax(pred_logits, dim=1).cpu().numpy()  # to CPU + NumPy
+
+            # Map class IDs to class names and assemble the results
+            for img_id, pred_id in zip(img_ids, pred_class_ids):
+                pred_class_name = class_names[pred_id]  # ID -> class name (e.g. 0 -> "Brewer_Blackbird")
+                infer_results.append({
+                    "id": img_id,             # column 1: image ID (filename)
+                    "class": pred_class_name  # column 2: predicted class name
+                })
+
+            # Print the inference progress
+            if (batch_idx + 1) % args.log_freq == 0 or (batch_idx + 1) == total_batches:
+                progress = (batch_idx + 1) / total_batches * 100
+                tlogger.print(f"Infer Progress: {progress:.1f}% (Batch {batch_idx + 1}/{total_batches})")
+
+    # 5. Save the results to a CSV (path: save_dir/test_infer_results.csv)
+    result_df = pd.DataFrame(infer_results)
+    csv_save_path = os.path.join(args.save_dir, "test_infer_results_adamw25.csv")
+    result_df.to_csv(csv_save_path, index=False, encoding="utf-8")  # no index column; UTF-8 encoding
+    tlogger.print(f"Infer completed! Results saved to: {csv_save_path}")
+
+
+# -------------------------- 5. Script entry point (parse config + start inference) --------------------------
+if __name__ == "__main__":
+    # Initialize the time logger (tracks inference wall time)
+    tlogger = timeLogger()
+    tlogger.print("=" * 50)
+    tlogger.print("Starting Test Set Inference Script")
+    tlogger.print("=" * 50)
+
+    # 1. Parse the command-line arguments (get the config file path)
+    args = get_args()
+    # A yaml config file is required
+    assert args.c != "", "Please provide config file via '-c your_config.yaml'!"
+
+    # 2. Load the yaml config (all parameters come from the config file, nothing hard-coded)
+    tlogger.print(f"Loading config file: {args.c}")
+    load_yaml(args, args.c)
+
+    # 3. Check the key paths (make sure the test set and save directory exist)
+    if not os.path.exists(args.test_root):
+        raise ValueError(f"Test set path does not exist! Check 'test_root' in config: {args.test_root}")
+    if not os.path.exists(args.save_dir):
+        raise ValueError(f"Save directory does not exist! Check 'save_dir' in config: {args.save_dir}")
+
+    # 4. Select the device (GPU if available, otherwise CPU)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    tlogger.print(f"Using device: {device}")
+
+    # 5. Run inference on the test set
+    infer_test_set(args, device, tlogger)
+    tlogger.print("Inference Script Finished!")
+    tlogger.print("=" * 50)
\ No newline at end of file
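Aside: a quick sanity check of the saved CSV (a sketch; the "records/..." path is a hypothetical save_dir, adjust to your config):

import pandas as pd

df = pd.read_csv("records/test_infer_results_adamw25.csv")  # hypothetical path
assert list(df.columns) == ["id", "class"]
assert df["id"].is_unique                 # one prediction per test image
print(df["class"].value_counts().head())  # rough class distribution of the predictions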
diff --git a/models/.ipynb_checkpoints/builder-checkpoint.py b/models/.ipynb_checkpoints/builder-checkpoint.py
new file mode 100644
index 0000000..cb4dae5
--- /dev/null
+++ b/models/.ipynb_checkpoints/builder-checkpoint.py
@@ -0,0 +1,309 @@
+import torch
+from typing import Union
+from torchvision.models.feature_extraction import get_graph_node_names
+
+from .pim_module import pim_module
+
+"""
+[Default Return]
+Set return_nodes to None to use the default return type; every model in this script
+returns features from four layers.
+
+[Model Configuration]
+If you are not using the FPN module but are using the Selector and Combiner, you need to give
+the Combiner a projection dimension ('proj_size' of GCNCombiner in pim_module.py), because the
+graph convolution layers need their input features to have the same dimension.
+
+[Combiner]
+You must use the Selector in order to use the Combiner.
+
+[About Custom Model]
+build_swintransformer builds the Swin Transformer. timm's swin-transformer + torch.fx.proxy.Proxy
+can cause an error, so we set return_nodes to None and change the swin-transformer model script to
+return features directly.
+Please check 'timm/models/swin_transformer.py' line 541 to see how to change the model if your custom
+model also fails at the create_feature_extractor or get_graph_node_names step.
+"""
+
+def load_model_weights(model, model_path):
+    ### reference https://github.com/TACJu/TransFG
+    ### thanks a lot.
+    state = torch.load(model_path, map_location='cpu')
+    for key in model.state_dict():
+        if 'num_batches_tracked' in key:
+            continue
+        p = model.state_dict()[key]
+        if key in state['state_dict']:
+            ip = state['state_dict'][key]
+            if p.shape == ip.shape:
+                p.data.copy_(ip.data)  # copy the parameter data
+            else:
+                print('could not load layer: {}, mismatch shape {} ,{}'.format(key, (p.shape), (ip.shape)))
+        else:
+            print('could not load layer: {}, not in checkpoint'.format(key))
+    return model
+
+
+def build_resnet50(pretrained: str = "./resnet50_miil_21k.pth",
+                   return_nodes: Union[dict, None] = None,
+                   num_selects: Union[dict, None] = None,
+                   img_size: int = 448,
+                   use_fpn: bool = True,
+                   fpn_size: int = 512,
+                   proj_type: str = "Conv",
+                   upsample_type: str = "Bilinear",
+                   use_selection: bool = True,
+                   num_classes: int = 200,
+                   use_combiner: bool = True,
+                   comb_proj_size: Union[int, None] = None):
+
+    import timm
+
+    if return_nodes is None:
+        return_nodes = {
+            'layer1.2.act3': 'layer1',
+            'layer2.3.act3': 'layer2',
+            'layer3.5.act3': 'layer3',
+            'layer4.2.act3': 'layer4',
+        }
+    if num_selects is None:
+        num_selects = {
+            'layer1':32,
+            'layer2':32,
+            'layer3':32,
+            'layer4':32
+        }
+
+    backbone = timm.create_model('resnet50', pretrained=False, num_classes=11221)
+    ### original pretrained path "./models/resnet50_miil_21k.pth"
+    if pretrained != "":
+        backbone = load_model_weights(backbone, pretrained)
+
+    # print(backbone)
+    # print(get_graph_node_names(backbone))
+
+    return pim_module.PluginMoodel(backbone = backbone,
+                                   return_nodes = return_nodes,
+                                   img_size = img_size,
+                                   use_fpn = use_fpn,
+                                   fpn_size = fpn_size,
+                                   proj_type = proj_type,
+                                   upsample_type = upsample_type,
+                                   use_selection = use_selection,
+                                   num_classes = num_classes,
+                                   num_selects = num_selects,
+                                   use_combiner = use_combiner,
+                                   comb_proj_size = comb_proj_size)
+
+
+def build_efficientnet(pretrained: bool = True,
+                       return_nodes: Union[dict, None] = None,
+                       num_selects: Union[dict, None] = None,
+                       img_size: int = 448,
+                       use_fpn: bool = True,
+                       fpn_size: int = 512,
+                       proj_type: str = "Conv",
+                       upsample_type: str = "Bilinear",
+                       use_selection: bool = True,
+                       num_classes: int = 200,
+                       use_combiner: bool = True,
+                       comb_proj_size: Union[int, None] = None):
+
+    import torchvision.models as models
+
+    if return_nodes is None:
+        return_nodes = {
+            'features.4': 'layer1',
+            'features.5': 'layer2',
+            'features.6': 'layer3',
+            'features.7': 'layer4',
+        }
+    if num_selects is None:
+        num_selects = {
+            'layer1':32,
+            'layer2':32,
+            'layer3':32,
+            'layer4':32
+        }
+
+    backbone = models.efficientnet_b7(pretrained=pretrained)
+    backbone.train()
+
+    # print(backbone)
+    # print(get_graph_node_names(backbone))
+    ## features.1~features.7
+
+    return pim_module.PluginMoodel(backbone = backbone,
+                                   return_nodes = return_nodes,
+                                   img_size = img_size,
+                                   use_fpn = use_fpn,
+                                   fpn_size = fpn_size,
+                                   proj_type = proj_type,
+                                   upsample_type = upsample_type,
+                                   use_selection = use_selection,
+                                   num_classes = num_classes,
+                                   num_selects = num_selects,
+                                   use_combiner = use_combiner,
+                                   comb_proj_size = comb_proj_size)
+
+
+
+
+def build_vit16(pretrained: str = "./vit_base_patch16_224_miil_21k.pth",
+                return_nodes: Union[dict, None] = None,
+                num_selects: Union[dict, None] = None,
+                img_size: int = 384,
+                use_fpn: bool = True,
+                fpn_size: int = 512,
+                proj_type: str = "Linear",
+                upsample_type: str = "Conv",
+                use_selection: bool = True,
+                num_classes: int = 200,
+                use_combiner: bool = True,
+                comb_proj_size: Union[int, None] = None):
+
+    import timm
+
+    backbone = timm.create_model('vit_base_patch16_224_miil_in21k', pretrained=False)
+    ### original pretrained path "./models/vit_base_patch16_224_miil_21k.pth"
+    if pretrained != "":
+        backbone = load_model_weights(backbone, pretrained)
+
+    backbone.train()
+
+    # print(backbone)
+    # print(get_graph_node_names(backbone))
+    # 0~11 under blocks
+
+    if return_nodes is None:
+        return_nodes = {
+            'blocks.8': 'layer1',
+            'blocks.9': 'layer2',
+            'blocks.10': 'layer3',
+            'blocks.11': 'layer4',
+        }
+    if num_selects is None:
+        num_selects = {
+            'layer1':32,
+            'layer2':32,
+            'layer3':32,
+            'layer4':32
+        }
+
+    ### The ViT input size can be changed from 224 to another size; we interpolate the position embedding.
+    ### thanks: https://github.com/TACJu/TransFG/blob/master/models/modeling.py
+    import math
+    from scipy import ndimage
+
+    posemb_tok, posemb_grid = backbone.pos_embed[:, :1], backbone.pos_embed[0, 1:]
+    posemb_grid = posemb_grid.detach().numpy()
+    gs_old = int(math.sqrt(len(posemb_grid)))
+    gs_new = img_size//16
+    posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
+    zoom = (gs_new / gs_old, gs_new / gs_old, 1)
+    posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
+    posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
+    posemb_grid = torch.from_numpy(posemb_grid)
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+    backbone.pos_embed = torch.nn.Parameter(posemb)
+
+    return pim_module.PluginMoodel(backbone = backbone,
+                                   return_nodes = return_nodes,
+                                   img_size = img_size,
+                                   use_fpn = use_fpn,
+                                   fpn_size = fpn_size,
+                                   proj_type = proj_type,
+                                   upsample_type = upsample_type,
+                                   use_selection = use_selection,
+                                   num_classes = num_classes,
+                                   num_selects = num_selects,
+                                   use_combiner = use_combiner,
+                                   comb_proj_size = comb_proj_size)
+
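Aside: the pos-embed block above resizes the 14x14 positional grid pretrained at 224 to the grid the new input size needs. A shape-only sketch (numbers assume ViT-B/16 and img_size=384):

import numpy as np
from scipy import ndimage

gs_old, gs_new, dim = 14, 384 // 16, 768          # 14x14 grid at 224 -> 24x24 at 384
grid = np.random.randn(gs_old * gs_old, dim)      # stand-in for the pretrained grid
grid = grid.reshape(gs_old, gs_old, -1)
zoom = (gs_new / gs_old, gs_new / gs_old, 1)
grid = ndimage.zoom(grid, zoom, order=1)          # bilinear resize of the token grid
assert grid.reshape(1, gs_new * gs_new, -1).shape == (1, 576, 768)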
+ """ + + import timm + + if num_selects is None: + num_selects = { + 'layer1':32, + 'layer2':32, + 'layer3':32, + 'layer4':32 + } + + backbone = timm.create_model('swin_large_patch4_window12_384_in22k', pretrained=pretrained) + + # print(backbone) + # print(get_graph_node_names(backbone)) + backbone.train() + + print("Building...") + return pim_module.PluginMoodel(backbone = backbone, + return_nodes = None, + img_size = img_size, + use_fpn = use_fpn, + fpn_size = fpn_size, + proj_type = proj_type, + upsample_type = upsample_type, + use_selection = use_selection, + num_classes = num_classes, + num_selects = num_selects, + use_combiner = num_selects, + comb_proj_size = comb_proj_size) + + + + +if __name__ == "__main__": + ### ==== resnet50 ==== + # model = build_resnet50(pretrained='./resnet50_miil_21k.pth') + # t = torch.randn(1, 3, 448, 448) + + ### ==== swin-t ==== + # model = build_swintransformer(False) + # t = torch.randn(1, 3, 384, 384) + + ### ==== vit ==== + # model = build_vit16(pretrained='./vit_base_patch16_224_miil_21k.pth') + # t = torch.randn(1, 3, 448, 448) + + ### ==== efficientNet ==== + model = build_efficientnet(pretrained=False) + t = torch.randn(1, 3, 448, 448) + + model.cuda() + + t = t.cuda() + outs = model(t) + for out in outs: + print(type(out)) + print(" " , end="") + if type(out) == dict: + print([name for name in out]) + + +MODEL_GETTER = { + "resnet50":build_resnet50, + "swin-t":build_swintransformer, + "vit":build_vit16, + "efficient":build_efficientnet +} diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..8788019 Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/builder.cpython-38.pyc b/models/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000..b7c9323 Binary files /dev/null and b/models/__pycache__/builder.cpython-38.pyc differ diff --git a/models/builder.py b/models/builder.py index cb56d3c..2ef0c0b 100644 --- a/models/builder.py +++ b/models/builder.py @@ -152,7 +152,7 @@ def build_efficientnet(pretrained: bool = True, def build_vit16(pretrained: str = "./vit_base_patch16_224_miil_21k.pth", return_nodes: Union[dict, None] = None, num_selects: Union[dict, None] = None, - img_size: int = 448, + img_size: int = 384, use_fpn: bool = True, fpn_size: int = 512, proj_type: str = "Linear", @@ -217,7 +217,7 @@ def build_vit16(pretrained: str = "./vit_base_patch16_224_miil_21k.pth", use_selection = use_selection, num_classes = num_classes, num_selects = num_selects, - use_combiner = num_selects, + use_combiner = use_combiner, comb_proj_size = comb_proj_size) diff --git a/models/pim_module/.ipynb_checkpoints/pim_module-checkpoint.py b/models/pim_module/.ipynb_checkpoints/pim_module-checkpoint.py new file mode 100644 index 0000000..9f7601f --- /dev/null +++ b/models/pim_module/.ipynb_checkpoints/pim_module-checkpoint.py @@ -0,0 +1,433 @@ +import torch +import torch.nn as nn +import torchvision.models as models +import torch.nn.functional as F +from torchvision.models.feature_extraction import get_graph_node_names +from torchvision.models.feature_extraction import create_feature_extractor +from typing import Union +import copy + +class GCNCombiner(nn.Module): + + def __init__(self, + total_num_selects: int, + num_classes: int, + inputs: Union[dict, None] = None, + proj_size: Union[int, None] = None, + fpn_size: Union[int, None] = None): + """ + If building backbone without FPN, set fpn_size to None and 
[Binary diffs for models/__pycache__/__init__.cpython-38.pyc and models/__pycache__/builder.cpython-38.pyc omitted.]
diff --git a/models/builder.py b/models/builder.py
index cb56d3c..2ef0c0b 100644
--- a/models/builder.py
+++ b/models/builder.py
@@ -152,7 +152,7 @@ def build_efficientnet(pretrained: bool = True,
 def build_vit16(pretrained: str = "./vit_base_patch16_224_miil_21k.pth",
                 return_nodes: Union[dict, None] = None,
                 num_selects: Union[dict, None] = None,
-                img_size: int = 448,
+                img_size: int = 384,
                 use_fpn: bool = True,
                 fpn_size: int = 512,
                 proj_type: str = "Linear",
@@ -217,7 +217,7 @@ def build_vit16(pretrained: str = "./vit_base_patch16_224_miil_21k.pth",
                                    use_selection = use_selection,
                                    num_classes = num_classes,
                                    num_selects = num_selects,
-                                   use_combiner = num_selects,
+                                   use_combiner = use_combiner,
                                    comb_proj_size = comb_proj_size)
diff --git a/models/pim_module/.ipynb_checkpoints/pim_module-checkpoint.py b/models/pim_module/.ipynb_checkpoints/pim_module-checkpoint.py
new file mode 100644
index 0000000..9f7601f
--- /dev/null
+++ b/models/pim_module/.ipynb_checkpoints/pim_module-checkpoint.py
@@ -0,0 +1,433 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+import torch.nn.functional as F
+from torchvision.models.feature_extraction import get_graph_node_names
+from torchvision.models.feature_extraction import create_feature_extractor
+from typing import Union
+import copy
+
+class GCNCombiner(nn.Module):
+
+    def __init__(self,
+                 total_num_selects: int,
+                 num_classes: int,
+                 inputs: Union[dict, None] = None,
+                 proj_size: Union[int, None] = None,
+                 fpn_size: Union[int, None] = None):
+        """
+        If building the backbone without FPN, set fpn_size to None; you then MUST give
+        'inputs' and 'proj_size'. These settings constrain the input dimension of the
+        graph convolutional network.
+        """
+        super(GCNCombiner, self).__init__()
+
+        assert inputs is not None or fpn_size is not None, \
+            "To build GCN combiner, you must give one features dimension."
+
+        ### auto-proj
+        self.fpn_size = fpn_size
+        if fpn_size is None:
+            for name in inputs:
+                if len(inputs[name].size()) == 4:
+                    in_size = inputs[name].size(1)
+                elif len(inputs[name].size()) == 3:
+                    in_size = inputs[name].size(2)
+                else:
+                    raise ValueError("The dimension of the previous output must be 3 or 4.")
+                m = nn.Sequential(
+                    nn.Linear(in_size, proj_size),
+                    nn.ReLU(),
+                    nn.Linear(proj_size, proj_size)
+                )
+                self.add_module("proj_"+name, m)
+            self.proj_size = proj_size
+        else:
+            self.proj_size = fpn_size
+
+        ### build one layer structure (with adaptive module)
+        num_joints = total_num_selects // 32
+
+        self.param_pool0 = nn.Linear(total_num_selects, num_joints)
+
+        A = torch.eye(num_joints)/100 + 1/100
+        self.adj1 = nn.Parameter(copy.deepcopy(A))
+        self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1)
+        self.batch_norm1 = nn.BatchNorm1d(self.proj_size)
+
+        self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
+        self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1)
+        self.alpha1 = nn.Parameter(torch.zeros(1))
+
+        ### merge information
+        self.param_pool1 = nn.Linear(num_joints, 1)
+
+        #### class predict
+        self.dropout = nn.Dropout(p=0.1)
+        self.classifier = nn.Linear(self.proj_size, num_classes)
+
+        self.tanh = nn.Tanh()
+
+    def forward(self, x):
+        """
+        """
+        hs = []
+        for name in x:
+            if self.fpn_size is None:
+                hs.append(getattr(self, "proj_"+name)(x[name]))
+            else:
+                hs.append(x[name])
+        hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous() # B, S', C --> B, C, S
+        hs = self.param_pool0(hs)
+        ### adaptive adjacency
+        q1 = self.conv_q1(hs).mean(1)
+        k1 = self.conv_k1(hs).mean(1)
+        A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1))
+        A1 = self.adj1 + A1 * self.alpha1
+        ### graph convolution
+        hs = self.conv1(hs)
+        hs = torch.matmul(hs, A1)
+        hs = self.batch_norm1(hs)
+        ### predict
+        hs = self.param_pool1(hs)
+        hs = self.dropout(hs)
+        hs = hs.flatten(1)
+        hs = self.classifier(hs)
+
+        return hs
+
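Aside: a shape-only sketch of the adaptive adjacency computed in GCNCombiner.forward (proj_size=512 and num_joints=4 are illustrative; num_joints = total_num_selects // 32):

import torch

B, C, S = 2, 512, 4                                   # batch, proj_size, num_joints
q1 = torch.randn(B, C // 4, S).mean(1)                # stand-in for conv_q1(hs).mean(1)
k1 = torch.randn(B, C // 4, S).mean(1)                # stand-in for conv_k1(hs).mean(1)
A1 = torch.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1))   # pairwise differences -> [B, S, S]
adj1 = torch.eye(S) / 100 + 1 / 100                   # static prior, as in __init__
alpha1 = torch.zeros(1)                               # learned gate, starts at 0
A = adj1 + A1 * alpha1                                # adjacency actually used
out = torch.matmul(torch.randn(B, C, S), A)           # graph-convolution step
assert out.shape == (B, C, S)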
+class WeaklySelector(nn.Module):
+
+    def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None):
+        """
+        inputs: dictionary of torch.Tensors that come from the backbone
+            [Tensor1(hidden feature1), Tensor2(hidden feature2)...]
+        Please note that if len(features.size()) equals 3, the dimension order must be [B, S, C]
+        (S is the spatial domain), and if len(features.size()) equals 4, the order must be [B, C, H, W].
+        """
+        super(WeaklySelector, self).__init__()
+
+        self.num_select = num_select
+
+        self.fpn_size = fpn_size
+        ### build classifier
+        if self.fpn_size is None:
+            self.num_classes = num_classes
+            for name in inputs:
+                fs_size = inputs[name].size()
+                if len(fs_size) == 3:
+                    in_size = fs_size[2]
+                elif len(fs_size) == 4:
+                    in_size = fs_size[1]
+                m = nn.Linear(in_size, num_classes)
+                self.add_module("classifier_l_"+name, m)
+
+    # def select(self, logits, l_name):
+    #     """
+    #     logits: [B, S, num_classes]
+    #     """
+    #     probs = torch.softmax(logits, dim=-1)
+    #     scores, _ = torch.max(probs, dim=-1)
+    #     _, ids = torch.sort(scores, -1, descending=True)
+    #     sn = self.num_select[l_name]
+    #     s_ids = ids[:, :sn]
+    #     not_s_ids = ids[:, sn:]
+    #     return s_ids.unsqueeze(-1), not_s_ids.unsqueeze(-1)
+
+    def forward(self, x, logits=None):
+        """
+        x :
+            dictionary containing the feature maps of your chosen layers.
+            Each size must be [B, HxW, C] ([B, S, C]) or [B, C, H, W];
+            [B, C, H, W] is transposed to [B, HxW, C] automatically.
+        """
+        if self.fpn_size is None:
+            logits = {}
+        selections = {}
+        for name in x:
+            if len(x[name].size()) == 4:
+                B, C, H, W = x[name].size()
+                x[name] = x[name].view(B, C, H*W).permute(0, 2, 1).contiguous()
+            C = x[name].size(-1)
+            if self.fpn_size is None:
+                logits[name] = getattr(self, "classifier_l_"+name)(x[name])
+
+            probs = torch.softmax(logits[name], dim=-1)
+            selections[name] = []
+            preds_1 = []
+            preds_0 = []
+            num_select = self.num_select[name]
+            for bi in range(logits[name].size(0)):
+                max_ids, _ = torch.max(probs[bi], dim=-1)
+                confs, ranks = torch.sort(max_ids, descending=True)
+                sf = x[name][bi][ranks[:num_select]]
+                nf = x[name][bi][ranks[num_select:]]  # dropped (non-selected) features
+                selections[name].append(sf)  # [num_selected, C]
+                preds_1.append(logits[name][bi][ranks[:num_select]])
+                preds_0.append(logits[name][bi][ranks[num_select:]])
+
+            selections[name] = torch.stack(selections[name])
+            preds_1 = torch.stack(preds_1)
+            preds_0 = torch.stack(preds_0)
+
+            logits["select_"+name] = preds_1
+            logits["drop_"+name] = preds_0
+
+        return selections
+
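Aside: a minimal sketch of the confidence-based token selection in WeaklySelector.forward (num_select=2 and the tensor sizes are illustrative):

import torch

S, C, num_classes, num_select = 6, 8, 4, 2
feats = torch.randn(S, C)                   # one sample's tokens
logits = torch.randn(S, num_classes)
probs = torch.softmax(logits, dim=-1)
max_ids, _ = torch.max(probs, dim=-1)       # per-token top confidence
confs, ranks = torch.sort(max_ids, descending=True)
selected = feats[ranks[:num_select]]        # kept tokens -> "select_" logits
dropped = feats[ranks[num_select:]]         # rest        -> "drop_" logits
assert selected.shape == (num_select, C)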
+
+class FPN(nn.Module):
+
+    def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str):
+        """
+        inputs : dictionary of torch.Tensors that come from the backbone output
+        fpn_size: integer, FPN projection dimension
+        proj_type:
+            in ["Conv", "Linear"]
+        upsample_type:
+            in ["Bilinear", "Conv", "Fc"]
+            for convolutional networks (e.g. ResNet, EfficientNet) we recommend "Bilinear";
+            for ViT, "Fc"; and for Swin-T, "Conv".
+        """
+        super(FPN, self).__init__()
+        assert proj_type in ["Conv", "Linear"], \
+            "FPN projection type {} is not supported yet; please choose 'Conv' or 'Linear'".format(proj_type)
+        assert upsample_type in ["Bilinear", "Conv"], \
+            "FPN upsample type {} is not supported yet; please choose 'Bilinear' or 'Conv'".format(upsample_type)
+
+        self.fpn_size = fpn_size
+        self.upsample_type = upsample_type
+        inp_names = [name for name in inputs]
+
+        for i, node_name in enumerate(inputs):
+            ### projection module
+            if proj_type == "Conv":
+                m = nn.Sequential(
+                    nn.Conv2d(inputs[node_name].size(1), inputs[node_name].size(1), 1),
+                    nn.ReLU(),
+                    nn.Conv2d(inputs[node_name].size(1), fpn_size, 1)
+                )
+            elif proj_type == "Linear":
+                m = nn.Sequential(
+                    nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
+                    nn.ReLU(),
+                    nn.Linear(inputs[node_name].size(-1), fpn_size),
+                )
+            self.add_module("Proj_"+node_name, m)
+
+            ### upsample module
+            if upsample_type == "Conv" and i != 0:
+                assert len(inputs[node_name].size()) == 3 # B, S, C
+                in_dim = inputs[node_name].size(1)
+                out_dim = inputs[inp_names[i-1]].size(1)
+                if in_dim != out_dim:
+                    m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain
+                else:
+                    m = nn.Identity()
+                self.add_module("Up_"+node_name, m)
+
+        if upsample_type == "Bilinear":
+            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
+
+    def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str):
+        """
+        return Upsample(x1) + x0
+        """
+        if self.upsample_type == "Bilinear":
+            if x1.size(-1) != x0.size(-1):
+                x1 = self.upsample(x1)
+        else:
+            x1 = getattr(self, "Up_"+x1_name)(x1)
+        return x1 + x0
+
+    def forward(self, x):
+        """
+        x : dictionary
+            {
+                "node_name1": feature1,
+                "node_name2": feature2, ...
+            }
+        """
+        ### project to the same dimension
+        hs = []
+        for i, name in enumerate(x):
+            x[name] = getattr(self, "Proj_"+name)(x[name])
+            hs.append(name)
+
+        for i in range(len(hs)-1, 0, -1):
+            x1_name = hs[i]
+            x0_name = hs[i-1]
+            x[x0_name] = self.upsample_add(x[x0_name],
+                                           x[x1_name],
+                                           x1_name)
+        return x
+
+
+class PluginMoodel(nn.Module):
+
+    def __init__(self,
+                 backbone: torch.nn.Module,
+                 return_nodes: Union[dict, None],
+                 img_size: int,
+                 use_fpn: bool,
+                 fpn_size: Union[int, None],
+                 proj_type: str,
+                 upsample_type: str,
+                 use_selection: bool,
+                 num_classes: int,
+                 num_selects: dict,
+                 use_combiner: bool,
+                 comb_proj_size: Union[int, None]
+                 ):
+        """
+        * backbone:
+            torch.nn.Module class (recommended: pretrained on ImageNet or IG-3.5B-17k, provided by FAIR)
+        * return_nodes:
+            e.g.
+            return_nodes = {
+                # node_name: user-specified key for the output dict
+                'layer1.2.relu_2': 'layer1',
+                'layer2.3.relu_2': 'layer2',
+                'layer3.5.relu_2': 'layer3',
+                'layer4.2.relu_2': 'layer4',
+            } # see the example at https://pytorch.org/vision/main/feature_extraction.html
+            !!! if using 'Swin-Transformer', please set return_nodes to None
+            !!! and please set use_fpn to True
+        * feat_sizes:
+            tuple or list containing the feature-map size of each layer,
+            ((C, H, W)), e.g. ((1024, 14, 14), (2048, 7, 7))
+        * use_fpn:
+            boolean, whether to use a feature pyramid network
+        * fpn_size:
+            integer, feature pyramid network projection dimension
+        * num_selects:
+            num_selects = {
+                # keys match the user-specified keys in return_nodes
+                "layer1": 2048,
+                "layer2": 512,
+                "layer3": 128,
+                "layer4": 32,
+            }
+
+        Note: after the selector module (WeaklySelector), the feature maps have size [B, S', C] and are
+        contained in the 'logits' or 'selections' dictionary (S' is the number of selections and can
+        differ between layers).
+        """
+        super(PluginMoodel, self).__init__()
+
+        ### = = = = = Backbone = = = = =
+        self.return_nodes = return_nodes
+        if return_nodes is not None:
+            self.backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
+        else:
+            self.backbone = backbone
+
+        ### get hidden feature sizes
+        rand_in = torch.randn(1, 3, img_size, img_size)
+        outs = self.backbone(rand_in)
+
+        ### just the original backbone
+        if not use_fpn and (not use_selection and not use_combiner):
+            for name in outs:
+                fs_size = outs[name].size()
+                if len(fs_size) == 3:
+                    out_size = fs_size[-1]
+                elif len(fs_size) == 4:
+                    out_size = fs_size[1]
+                else:
+                    raise ValueError("The dimension of the backbone output must be 3 or 4.")
+            self.classifier = nn.Linear(out_size, num_classes)
+
+        ### = = = = = FPN = = = = =
+        self.use_fpn = use_fpn
+        if self.use_fpn:
+            self.fpn = FPN(outs, fpn_size, proj_type, upsample_type)
+            self.build_fpn_classifier(outs, fpn_size, num_classes)
+
+        self.fpn_size = fpn_size
+
+        ### = = = = = Selector = = = = =
+        self.use_selection = use_selection
+        if self.use_selection:
+            w_fpn_size = self.fpn_size if self.use_fpn else None # if not using fpn, build the classifier in the weakly selector
+            self.selector = WeaklySelector(outs, num_classes, num_selects, w_fpn_size)
+
+        ### = = = = = Combiner = = = = =
+        self.use_combiner = use_combiner
+        if self.use_combiner:
+            assert self.use_selection, "Please use the selection module before the combiner"
+            if self.use_fpn:
+                gcn_inputs, gcn_proj_size = None, None
+            else:
+                gcn_inputs, gcn_proj_size = outs, comb_proj_size # redundant, fix in future
+            total_num_selects = sum([num_selects[name] for name in num_selects]) # sum
+            self.combiner = GCNCombiner(total_num_selects, num_classes, gcn_inputs, gcn_proj_size, self.fpn_size)
+
+    def build_fpn_classifier(self, inputs: dict, fpn_size: int, num_classes: int):
+        """
+        The results of our experiments show that a linear classifier in this case may cause some problems.
+ """ + for name in inputs: + m = nn.Sequential( + nn.Conv1d(fpn_size, fpn_size, 1), + nn.BatchNorm1d(fpn_size), + nn.ReLU(), + nn.Conv1d(fpn_size, num_classes, 1) + ) + self.add_module("fpn_classifier_"+name, m) + + def forward_backbone(self, x): + return self.backbone(x) + + def fpn_predict(self, x: dict, logits: dict): + """ + x: [B, C, H, W] or [B, S, C] + [B, C, H, W] --> [B, H*W, C] + """ + for name in x: + ### predict on each features point + if len(x[name].size()) == 4: + B, C, H, W = x[name].size() + logit = x[name].view(B, C, H*W) + elif len(x[name].size()) == 3: + logit = x[name].transpose(1, 2).contiguous() + logits[name] = getattr(self, "fpn_classifier_"+name)(logit) + logits[name] = logits[name].transpose(1, 2).contiguous() # transpose + + def forward(self, x: torch.Tensor): + + logits = {} + + x = self.forward_backbone(x) + + if self.use_fpn: + x = self.fpn(x) + self.fpn_predict(x, logits) + + if self.use_selection: + selects = self.selector(x, logits) + + if self.use_combiner: + comb_outs = self.combiner(selects) + logits['comb_outs'] = comb_outs + return logits + + if self.use_selection or self.fpn: + return logits + + ### original backbone (only predict final selected layer) + for name in x: + hs = x[name] + + if len(hs.size()) == 4: + hs = F.adaptive_avg_pool2d(hs, (1, 1)) + hs = hs.flatten(1) + else: + hs = hs.mean(1) + out = self.classifier(hs) + logits['ori_out'] = logits + + return logits diff --git a/models/pim_module/__pycache__/pim_module.cpython-38.pyc b/models/pim_module/__pycache__/pim_module.cpython-38.pyc index 8a78de4..ad3c11a 100644 Binary files a/models/pim_module/__pycache__/pim_module.cpython-38.pyc and b/models/pim_module/__pycache__/pim_module.cpython-38.pyc differ diff --git a/order.txt b/order.txt new file mode 100644 index 0000000..804b524 --- /dev/null +++ b/order.txt @@ -0,0 +1,6 @@ +python vit_pim_main.py --c ./configs/fgvc_vit_smuaic.yaml +python infer_test.py --c ./configs/fgvc_vit_smuaic.yaml +python vit_pim_main_sgd.py --c ./configs/fgvc_vit_smuaic-strong.yaml +python infer_test.py --c ./configs/fgvc_vit_smuaic-strong.yaml + +python SwinT_pim_main_sgd.py --c ./configs/AIC_PIM_SwinT.yaml \ No newline at end of file diff --git a/timm/__pycache__/__init__.cpython-38.pyc b/timm/__pycache__/__init__.cpython-38.pyc index 7b1de3b..33108c3 100644 Binary files a/timm/__pycache__/__init__.cpython-38.pyc and b/timm/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/__pycache__/version.cpython-38.pyc b/timm/__pycache__/version.cpython-38.pyc index 596125e..8bcdc8a 100644 Binary files a/timm/__pycache__/version.cpython-38.pyc and b/timm/__pycache__/version.cpython-38.pyc differ diff --git a/timm/data/__pycache__/__init__.cpython-38.pyc b/timm/data/__pycache__/__init__.cpython-38.pyc index be7c681..6ee444c 100644 Binary files a/timm/data/__pycache__/__init__.cpython-38.pyc and b/timm/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/data/__pycache__/auto_augment.cpython-38.pyc b/timm/data/__pycache__/auto_augment.cpython-38.pyc index 5aea7df..9369a30 100644 Binary files a/timm/data/__pycache__/auto_augment.cpython-38.pyc and b/timm/data/__pycache__/auto_augment.cpython-38.pyc differ diff --git a/timm/data/__pycache__/config.cpython-38.pyc b/timm/data/__pycache__/config.cpython-38.pyc index cfa917d..1413507 100644 Binary files a/timm/data/__pycache__/config.cpython-38.pyc and b/timm/data/__pycache__/config.cpython-38.pyc differ diff --git a/timm/data/__pycache__/constants.cpython-38.pyc 
[Binary diffs for several dozen recompiled timm __pycache__/*.pyc files (timm, timm/data, timm/models, timm/models/layers, timm/utils) omitted; they contain no reviewable source changes.]
files a/timm/utils/__pycache__/metrics.cpython-38.pyc and b/timm/utils/__pycache__/metrics.cpython-38.pyc differ
diff --git a/timm/utils/__pycache__/misc.cpython-38.pyc b/timm/utils/__pycache__/misc.cpython-38.pyc
index 28ae191..3a73527 100644
Binary files a/timm/utils/__pycache__/misc.cpython-38.pyc and b/timm/utils/__pycache__/misc.cpython-38.pyc differ
diff --git a/timm/utils/__pycache__/model.cpython-38.pyc b/timm/utils/__pycache__/model.cpython-38.pyc
index 4b4fa4a..5ea2211 100644
Binary files a/timm/utils/__pycache__/model.cpython-38.pyc and b/timm/utils/__pycache__/model.cpython-38.pyc differ
diff --git a/timm/utils/__pycache__/model_ema.cpython-38.pyc b/timm/utils/__pycache__/model_ema.cpython-38.pyc
index e6e4879..31288e6 100644
Binary files a/timm/utils/__pycache__/model_ema.cpython-38.pyc and b/timm/utils/__pycache__/model_ema.cpython-38.pyc differ
diff --git a/timm/utils/__pycache__/random.cpython-38.pyc b/timm/utils/__pycache__/random.cpython-38.pyc
index a163bf9..1a8aa92 100644
Binary files a/timm/utils/__pycache__/random.cpython-38.pyc and b/timm/utils/__pycache__/random.cpython-38.pyc differ
diff --git a/timm/utils/__pycache__/summary.cpython-38.pyc b/timm/utils/__pycache__/summary.cpython-38.pyc
index 652410e..962f95e 100644
Binary files a/timm/utils/__pycache__/summary.cpython-38.pyc and b/timm/utils/__pycache__/summary.cpython-38.pyc differ
diff --git a/utils/.ipynb_checkpoints/config_utils-checkpoint.py b/utils/.ipynb_checkpoints/config_utils-checkpoint.py
new file mode 100644
index 0000000..1d090c6
--- /dev/null
+++ b/utils/.ipynb_checkpoints/config_utils-checkpoint.py
@@ -0,0 +1,72 @@
+import yaml
+import os
+import argparse
+
+def load_yaml(args, yml):
+    with open(yml, 'r', encoding='utf-8') as fyml:
+        dic = yaml.load(fyml.read(), Loader=yaml.Loader)
+        for k in dic:
+            setattr(args, k, dic[k])
+
+def build_record_folder(args):
+
+    if not os.path.isdir("./records/"):
+        os.mkdir("./records/")
+
+    args.save_dir = "./records/" + args.project_name + "/" + args.exp_name + "/"
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.save_dir + "backup/", exist_ok=True)
+
+def get_args(with_deepspeed: bool=False):
+
+    parser = argparse.ArgumentParser("Fine-Grained Visual Classification")
+
+    parser.add_argument("--project_name", default="")
+    parser.add_argument("--exp_name", default="")
+
+    parser.add_argument("--c", default="", type=str, help="config file path")
+
+    ### about dataset
+    parser.add_argument("--train_root", default="", type=str)  # "../NABirds/train/"
+    parser.add_argument("--val_root", default="", type=str)
+    parser.add_argument("--data_size", default=384, type=int)
+    parser.add_argument("--num_workers", default=2, type=int)
+    parser.add_argument("--batch_size", default=64, type=int)
+
+    ### model
+    parser.add_argument("--model_name", default="", type=str, help='["resnet50", "swin-t", "vit", "efficient"]')
+    parser.add_argument("--optimizer", default="", type=str, help='["SGD", "AdamW"]')
+    parser.add_argument("--max_lr", default=0.0003, type=float)
+    parser.add_argument("--wdecay", default=0.0005, type=float)
+
+    parser.add_argument("--max_epochs", default=50, type=int)
+    parser.add_argument("--warmup_batchs", default=0, type=int)
+
+    parser.add_argument("--use_fpn", default=True, type=bool)
+    parser.add_argument("--fpn_size", default=512, type=int)
+    parser.add_argument("--use_selection", default=True, type=bool)
+    parser.add_argument("--num_classes", default=10, type=int)
+    parser.add_argument("--num_selects", default={
+        "layer1":32,
+        "layer2":32,
"layer3":32, + "layer4":32 + }, type=dict) + parser.add_argument("--use_combiner", default=True, type=bool) + + ### loss + parser.add_argument("--lambda_b", default=0.5, type=float) + parser.add_argument("--lambda_s", default=0.0, type=float) + parser.add_argument("--lambda_n", default=5.0, type=float) + parser.add_argument("--lambda_c", default=1.0, type=float) + + parser.add_argument("--use_wandb", default=True, type=bool) + + if with_deepspeed: + import deepspeed + parser = deepspeed.add_config_arguments(parser) + + args = parser.parse_args() + + return args + diff --git a/utils/__pycache__/config_utils.cpython-38.pyc b/utils/__pycache__/config_utils.cpython-38.pyc index aa4f8f4..9c2999d 100644 Binary files a/utils/__pycache__/config_utils.cpython-38.pyc and b/utils/__pycache__/config_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/costom_logger.cpython-38.pyc b/utils/__pycache__/costom_logger.cpython-38.pyc index 34a0fa2..7411701 100644 Binary files a/utils/__pycache__/costom_logger.cpython-38.pyc and b/utils/__pycache__/costom_logger.cpython-38.pyc differ diff --git a/utils/__pycache__/lr_schedule.cpython-38.pyc b/utils/__pycache__/lr_schedule.cpython-38.pyc index 3e4e052..ebddf47 100644 Binary files a/utils/__pycache__/lr_schedule.cpython-38.pyc and b/utils/__pycache__/lr_schedule.cpython-38.pyc differ diff --git a/vit_pim_main.py b/vit_pim_main.py new file mode 100644 index 0000000..1065083 --- /dev/null +++ b/vit_pim_main.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import contextlib +import wandb +import warnings + +from models.builder import MODEL_GETTER +from data.dataset import build_loader +from utils.costom_logger import timeLogger +from utils.config_utils import load_yaml, build_record_folder, get_args +from utils.lr_schedule import cosine_decay, adjust_lr, get_lr +from eval import evaluate, cal_train_metrics + +warnings.simplefilter("ignore") + + +def eval_freq_schedule(args, epoch: int): + """ + 根据当前训练的 epoch 调整验证频率(eval_freq)。 + 在训练接近尾声时更频繁地进行验证,以便更好地监控模型性能。 + + 参数: + args: 包含训练配置参数的对象,其中包括 eval_freq 和 max_epochs。 + epoch: 当前训练的 epoch 数。 + """ + # 如果当前 epoch 大于等于最大训练轮次的 95%,则将验证频率设为 1(每个 epoch 都验证) + if epoch >= args.max_epochs * 0.95: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 90% 但小于 95%,同样将验证频率设为 1 + elif epoch >= args.max_epochs * 0.9: + args.eval_freq = 1 + # 如果当前 epoch 大于等于最大训练轮次的 80% 但小于 90%,将验证频率设为 2(每两个 epoch 验证一次) + elif epoch >= args.max_epochs * 0.8: + args.eval_freq = 2 + + +def set_environment(args, tlogger): + """ + 设置训练环境,包括设备、数据加载器、模型、优化器等。 + + 参数: + args: 包含训练配置参数的对象。 + tlogger: 用于记录时间日志的对象。 + + 返回: + train_loader: 训练数据加载器。 + val_loader: 验证数据加载器。 + model: 构建并初始化的模型。 + optimizer: 优化器(如果仅评估则为None)。 + schedule: 学习率调度器(如果仅评估则为None)。 + scaler: AMP缩放器(如果不使用AMP则为None)。 + amp_context: AMP上下文管理器(如果不使用AMP则是nullcontext)。 + start_epoch: 训练开始的epoch数(如果有预训练模型,则从该模型的epoch开始)。 + """ + + print("Setting Environment...") + + # 设置训练设备:如果CUDA可用则使用GPU,否则使用CPU + args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + ### = = = = Dataset and Data Loader = = = = + # 构建训练和验证数据加载器 + tlogger.print("Building Dataloader....") + + train_loader, val_loader = build_loader(args) + + # 检查是否成功构建了数据加载器 + if train_loader is None and val_loader is None: + raise ValueError("Find nothing to train or evaluate.") + + # 打印训练集信息 + if train_loader is not None: + print(" Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader))) + else: + # raise ValueError("Build train loader fail, 
diff --git a/utils/__pycache__/config_utils.cpython-38.pyc b/utils/__pycache__/config_utils.cpython-38.pyc
index aa4f8f4..9c2999d 100644
Binary files a/utils/__pycache__/config_utils.cpython-38.pyc and b/utils/__pycache__/config_utils.cpython-38.pyc differ
diff --git a/utils/__pycache__/costom_logger.cpython-38.pyc b/utils/__pycache__/costom_logger.cpython-38.pyc
index 34a0fa2..7411701 100644
Binary files a/utils/__pycache__/costom_logger.cpython-38.pyc and b/utils/__pycache__/costom_logger.cpython-38.pyc differ
diff --git a/utils/__pycache__/lr_schedule.cpython-38.pyc b/utils/__pycache__/lr_schedule.cpython-38.pyc
index 3e4e052..ebddf47 100644
Binary files a/utils/__pycache__/lr_schedule.cpython-38.pyc and b/utils/__pycache__/lr_schedule.cpython-38.pyc differ
diff --git a/vit_pim_main.py b/vit_pim_main.py
new file mode 100644
index 0000000..1065083
--- /dev/null
+++ b/vit_pim_main.py
@@ -0,0 +1,428 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import contextlib
+import wandb
+import warnings
+
+from models.builder import MODEL_GETTER
+from data.dataset import build_loader
+from utils.costom_logger import timeLogger
+from utils.config_utils import load_yaml, build_record_folder, get_args
+from utils.lr_schedule import cosine_decay, adjust_lr, get_lr
+from eval import evaluate, cal_train_metrics
+
+warnings.simplefilter("ignore")
+
+
+def eval_freq_schedule(args, epoch: int):
+    """
+    Adjust the evaluation frequency (eval_freq) according to the current epoch.
+    Evaluate more often near the end of training to monitor the model more closely.
+
+    Args:
+        args: training configuration object, including eval_freq and max_epochs.
+        epoch: the current training epoch.
+    """
+    # From 95% of max_epochs onwards, evaluate every epoch
+    if epoch >= args.max_epochs * 0.95:
+        args.eval_freq = 1
+    # From 90% (but below 95%) of max_epochs, also evaluate every epoch
+    elif epoch >= args.max_epochs * 0.9:
+        args.eval_freq = 1
+    # From 80% (but below 90%) of max_epochs, evaluate every two epochs
+    elif epoch >= args.max_epochs * 0.8:
+        args.eval_freq = 2
+
+
+def set_environment(args, tlogger):
+    """
+    Set up the training environment: device, data loaders, model, optimizer, etc.
+
+    Args:
+        args: training configuration object.
+        tlogger: time logger used for progress messages.
+
+    Returns:
+        train_loader: training data loader.
+        val_loader: validation data loader.
+        model: the constructed and initialized model.
+        optimizer: the optimizer (None in evaluation-only mode).
+        schedule: the learning-rate schedule (None in evaluation-only mode).
+        scaler: the AMP gradient scaler (None when AMP is disabled).
+        amp_context: the AMP autocast context (nullcontext when AMP is disabled).
+        start_epoch: the epoch to start from (taken from the checkpoint when one is loaded).
+    """
+
+    print("Setting Environment...")
+
+    # Use the GPU when CUDA is available, otherwise fall back to the CPU
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    ### = = = =  Dataset and Data Loader = = = =
+    # Build the training and validation data loaders
+    tlogger.print("Building Dataloader....")
+
+    train_loader, val_loader = build_loader(args)
+
+    # Make sure at least one loader was built
+    if train_loader is None and val_loader is None:
+        raise ValueError("Found nothing to train or evaluate.")
+
+    # Print training-set information
+    if train_loader is not None:
+        print("    Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader)))
+    else:
+        # raise ValueError("Build train loader fail, please provide legal path.")
+        print("    Train Samples: 0 ~~~~~> [Only Evaluation]")
+
+    # Print validation-set information
+    if val_loader is not None:
+        print("    Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader)))
+    else:
+        print("    Validation Samples: 0 ~~~~~> [Only Training]")
+    tlogger.print()
+
+    ### = = = =  Model = = = =
+    # Build the model
+    tlogger.print("Building Model....")
+    model = MODEL_GETTER[args.model_name](
+        use_fpn=args.use_fpn,
+        fpn_size=args.fpn_size,
+        use_selection=args.use_selection,
+        num_classes=args.num_classes,
+        num_selects=args.num_selects,
+        use_combiner=args.use_combiner,
+    )  # about return_nodes, we use our default setting
+
+    # Load weights when a pretrained checkpoint is provided
+    if args.pretrained is not None:
+        checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu'))
+        model.load_state_dict(checkpoint['model'])
+        start_epoch = checkpoint['epoch']
+        print(start_epoch)
+    else:
+        start_epoch = 0
+
+    # Move the model to the target device
+    model.to(args.device)
+    tlogger.print()
+
+    """
+    With multiple GPUs you can wrap the model in torch.nn.DataParallel (single
+    machine) or torch.nn.parallel.DistributedDataParallel (multi-process).
+    More details: https://pytorch.org/tutorials/beginner/dist_overview.html
+    """
+
+    # Without a train loader we are evaluation-only; return early
+    if train_loader is None:
+        return train_loader, val_loader, model, None, None, None, None, start_epoch
+
+    ### = = = =  Optimizer = = = =
+    # Build the optimizer
+    tlogger.print("Building Optimizer....")
+    if args.optimizer == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9,
+                                    weight_decay=args.wdecay)
+    elif args.optimizer == "AdamW":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr)
+
+    # Restore the optimizer state when resuming from a checkpoint
+    if args.pretrained is not None:
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+    tlogger.print()
+
+    # Build the learning-rate schedule
+    schedule = cosine_decay(args, len(train_loader))
+
+    # Set up mixed-precision (AMP) components when enabled
+    if args.use_amp:
+        scaler = torch.cuda.amp.GradScaler()
+        amp_context = torch.cuda.amp.autocast
+    else:
+        scaler = None
+        amp_context = contextlib.nullcontext
+
+    # Return all constructed components
+    return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch
+
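set_environment() hands the caller either torch.cuda.amp.autocast plus a GradScaler, or contextlib.nullcontext, so the training loop can use one code path for both modes. A minimal sketch of that pattern in isolation (the toy model and shapes are invented for illustration):

import contextlib
import torch

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler() if use_amp else None
amp_context = torch.cuda.amp.autocast if use_amp else contextlib.nullcontext

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)

with amp_context():  # the same call form works for autocast and nullcontext
    loss = torch.nn.functional.cross_entropy(model(x), y)

if scaler is not None:  # scaled backward + step under AMP
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:                   # plain backward + step otherwise
    loss.backward()
    optimizer.step()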
+
+def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader):
+    """
+    Run one training epoch: iterate over the training data and update the model.
+
+    Args:
+        args: training configuration object.
+        epoch: the current training epoch.
+        model: the model to train.
+        scaler: AMP gradient scaler (None when AMP is disabled).
+        amp_context: AMP autocast context (nullcontext when AMP is disabled).
+        optimizer: the optimizer.
+        schedule: the learning-rate schedule.
+        train_loader: the training data loader.
+    """
+
+    # Clear any stale gradients
+    optimizer.zero_grad()
+
+    # Total number of batches, used only for logging
+    total_batchs = len(train_loader)
+
+    # Progress milestones (0%, 10%, ..., 100%)
+    show_progress = [x / 10 for x in range(11)]
+    progress_i = 0
+
+    # Iterate over the training batches
+    for batch_id, (ids, datas, labels) in enumerate(train_loader):
+        # Put the model in training mode
+        model.train()
+
+        """ = = = = adjust learning rate = = = = """
+        # Current iteration index
+        iterations = epoch * len(train_loader) + batch_id
+        # Adjust the learning rate
+        adjust_lr(iterations, optimizer, schedule)
+
+        # Number of samples in this batch
+        batch_size = labels.size(0)
+
+        """ = = = = forward and calculate loss = = = =  """
+        # Move data and labels to the target device
+        datas, labels = datas.to(args.device), labels.to(args.device)
+
+        # Forward pass under the AMP context (when AMP is enabled)
+        with amp_context():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ... (depends on your setting),
+            'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ... (depends on your setting),
+            'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depends on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that have not been selected by the Selector.
+            'preds_1': logits that have been selected by the Selector.
+            'comb_outs': the prediction of the Combiner.
+            """
+            # Forward pass
+            outs = model(datas)
+
+            # Initialize the total loss
+            loss = 0.
+
+            # Walk over the model outputs and accumulate the corresponding losses
+            for name in outs:
+                # Selector outputs
+                if "select_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+                    if args.lambda_s != 0:
+                        # Selector loss
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        loss_s = nn.CrossEntropyLoss()(logit,
+                                                       labels.unsqueeze(1).repeat(1, S).flatten(0))
+                        loss += args.lambda_s * loss_s
+                    else:
+                        loss_s = 0.0
+
+                # Dropped (not selected) outputs
+                elif "drop_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+
+                    if args.lambda_n != 0:
+                        # Negative-sample loss: push dropped logits towards -1
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        n_preds = nn.Tanh()(logit)
+                        labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1
+                        labels_0 = labels_0.to(args.device)
+                        loss_n = nn.MSELoss()(n_preds, labels_0)
+                        loss += args.lambda_n * loss_n
+                    else:
+                        loss_n = 0.0
+
+                # FPN layer outputs
+                elif "layer" in name:
+                    if not args.use_fpn:
+                        raise ValueError("FPN is not used here.")
+                    if args.lambda_b != 0:
+                        # FPN base loss
+                        ### using 'layer1'~'layer4' is the default setting; change it to match your own return nodes
+                        loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels)
+                        loss += args.lambda_b * loss_b
+                    else:
+                        loss_b = 0.0
+
+                # Combiner output
+                elif "comb_outs" in name:
+                    if not args.use_combiner:
+                        raise ValueError("Combiner is not used here.")
+
+                    if args.lambda_c != 0:
+                        # Combiner loss
+                        loss_c = nn.CrossEntropyLoss()(outs[name], labels)
+                        loss += args.lambda_c * loss_c
+
+                # Original (backbone) output
+                elif "ori_out" in name:
+                    # Original-output loss
+                    loss_ori = F.cross_entropy(outs[name], labels)
+                    loss += loss_ori
+
+            # Scale the loss for gradient accumulation
+            loss /= args.update_freq
+
+        """ = = = = calculate gradient = = = = """
+        # Backward pass (AMP or plain)
+        if args.use_amp:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
+
+        """ = = = = update model = = = = """
+        # Step the optimizer every update_freq batches
+        if (batch_id + 1) % args.update_freq == 0:
+            if args.use_amp:
+                # AMP update
+                scaler.step(optimizer)
+                scaler.update()  # next batch
+            else:
+                # Plain update
+                optimizer.step()
+            # Clear the accumulated gradients
+            optimizer.zero_grad()
+
+        """ log (MISC) """
+        # Log training metrics (when wandb is enabled and the logging interval is reached)
+        if args.use_wandb and ((batch_id + 1) % args.log_freq == 0):
+            # Switch to eval mode for metric computation
+            model.eval()
+            msg = {}
+            msg['info/epoch'] = epoch + 1
+            msg['info/lr'] = get_lr(optimizer)
+            # Compute and record the training metrics
+            cal_train_metrics(args, msg, outs, labels, batch_size)
+            # Send the metrics to wandb
+            wandb.log(msg)
+
+        # Show training progress
+        train_progress = (batch_id + 1) / total_batchs
+        # print(train_progress, show_progress[progress_i])
+        if train_progress > show_progress[progress_i]:
+            print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True)
+            progress_i += 1
+
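train() divides the loss by update_freq and steps the optimizer only every update_freq batches, so gradients accumulate and emulate a batch that is update_freq times larger. A stripped-down sketch of the same accumulation scheme (toy model and random data, AMP omitted):

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update_freq = 4  # accumulate gradients over 4 micro-batches

optimizer.zero_grad()
for batch_id in range(16):
    x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / update_freq).backward()  # scale so the summed gradients average out
    if (batch_id + 1) % update_freq == 0:
        optimizer.step()       # one effective step per update_freq batches
        optimizer.zero_grad()  # start the next accumulation window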
+def main(args, tlogger):
+    """
+    Main training loop: runs training and validation, and saves checkpoints
+    (last.pt and best.pt).
+
+    Args:
+        args: training configuration object.
+        tlogger: time logger used for progress messages.
+    """
+
+    # Set up the environment: data loaders, model, optimizer, etc.
+    train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, tlogger)
+
+    # Initialize the best accuracy and the name of the best evaluation metric
+    best_acc = 0.0
+    best_eval_name = "null"
+
+    # Initialize the wandb run and its summary fields (when wandb is enabled)
+    if args.use_wandb:
+        wandb.init(entity=args.wandb_entity,
+                   project=args.project_name,
+                   name=args.exp_name,
+                   config=args)
+        wandb.run.summary["best_acc"] = best_acc
+        wandb.run.summary["best_eval_name"] = best_eval_name
+        wandb.run.summary["best_epoch"] = 0
+
+    # Training loop from start_epoch to max_epochs
+    for epoch in range(start_epoch, args.max_epochs):
+
+        """
+        Train
+        """
+        # Train for one epoch when a train loader exists
+        if train_loader is not None:
+            tlogger.print("Start Training {} Epoch".format(epoch + 1))
+            # Run one training epoch
+            train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader)
+            tlogger.print()
+        else:
+            # Evaluation-only mode: evaluate, save the results, and stop
+            from eval import eval_and_save
+            eval_and_save(args, model, val_loader)
+            break
+
+        # Adjust the evaluation frequency for the current epoch
+        eval_freq_schedule(args, epoch)
+
+        # Prepare the checkpoint to save (handle the multi-GPU wrapper)
+        model_to_save = model.module if hasattr(model, "module") else model
+        checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
+        # Save the latest checkpoint
+        torch.save(checkpoint, args.save_dir + "backup/last.pt")
+
+        # Validate every eval_freq epochs (and always after the first epoch)
+        if epoch == 0 or (epoch + 1) % args.eval_freq == 0:
+            """
+            Evaluate
+            """
+            acc = -1
+            # Run validation when a validation loader exists
+            if val_loader is not None:
+                tlogger.print("Start Evaluating {} Epoch".format(epoch + 1))
+                # Evaluate and collect the accuracy metrics
+                acc, eval_name, accs = evaluate(args, model, val_loader)
+                # Print the current result and the best accuracy so far
+                tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc))
+                tlogger.print()
+
+                # Log the validation metrics (when wandb is enabled)
+                if args.use_wandb:
+                    wandb.log(accs)
+
+                # Update the best accuracy and save the best checkpoint
+                if acc > best_acc:
+                    best_acc = acc
+                    best_eval_name = eval_name
+                    torch.save(checkpoint, args.save_dir + "backup/best.pt")
+                    # Update the wandb summary (when wandb is enabled)
+                    if args.use_wandb:
+                        wandb.run.summary["best_acc"] = best_acc
+                        wandb.run.summary["best_eval_name"] = best_eval_name
+                        wandb.run.summary["best_epoch"] = epoch + 1
+
+
+if __name__ == "__main__":
+    # Create a time logger for timestamped messages
+    tlogger = timeLogger()
+
+    # Announce that the config is being read
+    tlogger.print("Reading Config...")
+
+    # Parse the command-line arguments (including the config file path)
+    args = get_args()
+
+    # A .yaml config file must be provided
+    assert args.c != "", "Please provide config file (.yaml)"
+
+    # Load the YAML config and merge it into args
+    load_yaml(args, args.c)
+
+    # Create the record folder for logs and checkpoints
+    build_record_folder(args)
+
+    # Print a blank line as a separator
+    tlogger.print()
+
+    # Run the main training / evaluation flow
+    main(args, tlogger)
\ No newline at end of file
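main() writes {"model", "optimizer", "epoch"} dictionaries to backup/last.pt and backup/best.pt, and set_environment() resumes from them via args.pretrained. A self-contained sketch of that round trip (the toy model and file name are illustrative):

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Save: same keys as the checkpoints written above
checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": 7}
torch.save(checkpoint, "last.pt")

# Resume: load on CPU first, then restore both states
checkpoint = torch.load("last.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"]  # continue from the saved epoch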
diff --git a/vit_pim_main_adamw.py b/vit_pim_main_adamw.py
new file mode 100644
index 0000000..e2bafea
--- /dev/null
+++ b/vit_pim_main_adamw.py
@@ -0,0 +1,291 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import contextlib
+import wandb
+import warnings
+
+from models.builder import MODEL_GETTER
+from data.dataset import build_loader
+from utils.costom_logger import timeLogger
+from utils.config_utils import load_yaml, build_record_folder, get_args
+from utils.lr_schedule import cosine_decay, adjust_lr, get_lr
+from eval import evaluate, cal_train_metrics, _average_top_k_result  # make sure _average_top_k_result is imported
+
+warnings.simplefilter("ignore")
+
+
+def eval_freq_schedule(args, epoch: int):
+    if epoch >= args.max_epochs * 0.95:
+        args.eval_freq = 1
+    elif epoch >= args.max_epochs * 0.9:
+        args.eval_freq = 1
+    elif epoch >= args.max_epochs * 0.8:
+        args.eval_freq = 2
+
+
+def set_environment(args, tlogger):
+    print("Setting Environment...")
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    ### Dataset and Data Loader
+    tlogger.print("Building Dataloader....")
+    train_loader, val_loader = build_loader(args)
+
+    if train_loader is None and val_loader is None:
+        raise ValueError("Found nothing to train or evaluate.")
+
+    if train_loader is not None:
+        print("    Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader)))
+    else:
+        print("    Train Samples: 0 ~~~~~> [Only Evaluation]")
+    if val_loader is not None:
+        print("    Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader)))
+    else:
+        print("    Validation Samples: 0 ~~~~~> [Only Training]")
+    tlogger.print()
+
+    ### Model
+    tlogger.print("Building Model....")
+    model = MODEL_GETTER[args.model_name](
+        use_fpn=args.use_fpn,
+        fpn_size=args.fpn_size,
+        use_selection=args.use_selection,
+        num_classes=args.num_classes,
+        num_selects=args.num_selects,
+        use_combiner=args.use_combiner,
+    )
+
+    start_epoch = 0
+    if args.pretrained is not None:
+        checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu'))
+        model.load_state_dict(checkpoint['model'], strict=False)
+        start_epoch = checkpoint['epoch']
+        print(f"Loaded pretrained model from epoch {start_epoch}")
+
+    model.to(args.device)
+    tlogger.print()
+
+    if train_loader is None:
+        return train_loader, val_loader, model, None, None, None, None, start_epoch
+
+    ### Optimizer
+    tlogger.print("Building Optimizer....")
+    if args.optimizer == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9,
+                                    weight_decay=args.wdecay)
+    elif args.optimizer == "AdamW":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr)
+
+    if args.pretrained is not None:
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+    tlogger.print()
+    schedule = cosine_decay(args, len(train_loader))
+
+    ### AMP
+    if args.use_amp:
+        scaler = torch.cuda.amp.GradScaler()
+        amp_context = torch.cuda.amp.autocast
+    else:
+        scaler = None
+        amp_context = contextlib.nullcontext
+
+    return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch
+
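Unlike vit_pim_main.py, this variant passes strict=False when loading the checkpoint, so parameters whose names do not match are skipped rather than raising an error. A small sketch of the difference (toy modules; the extra head is invented for illustration):

import torch
import torch.nn as nn

src = nn.Sequential(nn.Linear(8, 8))
dst = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # has an extra head

state = src.state_dict()
# dst.load_state_dict(state)  # strict=True (the default) would raise: missing keys
result = dst.load_state_dict(state, strict=False)  # loads whatever matches
print(result.missing_keys)  # ['1.weight', '1.bias'] keep their initial values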
+
+def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader):
+    """Corrected training-accuracy bookkeeping; prints the training loss and accuracy."""
+    optimizer.zero_grad()
+    total_batchs = len(train_loader)
+    show_progress = [x / 10 for x in range(11)]
+    progress_i = 0
+
+    # Epoch-level training metrics
+    epoch_total_loss = 0.0  # cumulative training loss
+    epoch_correct = 0       # cumulative number of correct samples (for the average accuracy)
+    total_samples = 0       # total number of training samples
+
+    for batch_id, (ids, datas, labels) in enumerate(train_loader):
+        model.train()
+        batch_size = labels.size(0)
+        total_samples += batch_size
+        labels = labels.to(args.device)  # move the labels to the device
+
+        ### Adjust the learning rate
+        iterations = epoch * len(train_loader) + batch_id
+        adjust_lr(iterations, optimizer, schedule)
+
+        ### Forward pass and loss computation
+        datas = datas.to(args.device)
+        batch_loss = 0.0  # total weighted loss of the current batch
+
+        with amp_context():
+            outs = model(datas)
+            loss = 0.
+
+            # Per-head losses (same logic as before)
+            for name in outs:
+                if "select_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+                    if args.lambda_s != 0:
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        loss_s = nn.CrossEntropyLoss()(logit, labels.unsqueeze(1).repeat(1, S).flatten(0))
+                        loss += args.lambda_s * loss_s
+                        batch_loss += loss_s.item() * args.lambda_s
+
+                elif "drop_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+                    if args.lambda_n != 0:
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        n_preds = nn.Tanh()(logit)
+                        labels_0 = torch.zeros([batch_size * S, args.num_classes]).to(args.device) - 1
+                        loss_n = nn.MSELoss()(n_preds, labels_0)
+                        loss += args.lambda_n * loss_n
+                        batch_loss += loss_n.item() * args.lambda_n
+
+                elif "layer" in name:
+                    if not args.use_fpn:
+                        raise ValueError("FPN is not used here.")
+                    if args.lambda_b != 0:
+                        loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels)
+                        loss += args.lambda_b * loss_b
+                        batch_loss += loss_b.item() * args.lambda_b
+
+                elif "comb_outs" in name:
+                    if not args.use_combiner:
+                        raise ValueError("Combiner is not used here.")
+                    if args.lambda_c != 0:
+                        loss_c = nn.CrossEntropyLoss()(outs[name], labels)
+                        loss += args.lambda_c * loss_c
+                        batch_loss += loss_c.item() * args.lambda_c
+
+                elif "ori_out" in name:
+                    loss_ori = F.cross_entropy(outs[name], labels)
+                    loss += loss_ori
+                    batch_loss += loss_ori.item()
+
+            # Gradient accumulation: recover the unscaled loss for logging
+            loss /= args.update_freq
+            batch_real_loss = loss.item() * args.update_freq  # true loss of the current batch
+            epoch_total_loss += batch_real_loss
+
+        ### Backward pass and parameter update
+        if args.use_amp:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
+
+        if (batch_id + 1) % args.update_freq == 0:
+            if args.use_amp:
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                optimizer.step()
+            optimizer.zero_grad()
+
+        ### Batch accuracy: prefer the combiner output, then the original output
+        if args.use_combiner and "comb_outs" in outs:
+            pred = torch.argmax(outs["comb_outs"], dim=1)
+        elif "ori_out" in outs:
+            pred = torch.argmax(outs["ori_out"], dim=1)
+        else:
+            # Fall back to the last FPN layer when neither head exists
+            pred = torch.argmax(outs["layer4"].mean(1), dim=1)
+
+        # Count correct predictions on every batch so the epoch average is exact
+        batch_correct = (pred == labels).sum().item()
+        batch_acc = (batch_correct / batch_size) * 100  # as a percentage
+        epoch_correct += batch_correct
+
+        ### Print batch-level loss and accuracy at the logging interval
+        if (batch_id + 1) % args.log_freq == 0:
+            print(f"[Train] Epoch {epoch+1:2d} | Batch {batch_id+1:4d}/{total_batchs:4d} | "
+                  f"Loss: {batch_real_loss:.4f} | Acc: {batch_acc:.2f}%")
+
+        ### Show training progress
+        train_progress = (batch_id + 1) / total_batchs
+        if train_progress > show_progress[progress_i]:
+            print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True)
+            progress_i += 1
+
+    ### Print the epoch-level training summary
+    avg_train_loss = epoch_total_loss / total_batchs  # average loss (per batch)
+    avg_train_acc = (epoch_correct / total_samples) * 100  # average accuracy (per sample)
+    print(f"\n[Train Summary] Epoch {epoch+1:2d} | Avg Loss: {avg_train_loss:.4f} | Avg Acc: {avg_train_acc:.2f}%")
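The counters above track the epoch's loss and accuracy by hand; a small running-average meter is a common way to package that bookkeeping. A hedged sketch (AverageMeter is a conventional name, not something defined in this repo):

class AverageMeter:
    """Tracks a running sum and count, e.g. for loss or accuracy."""
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, value: float, n: int = 1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / max(self.count, 1)

loss_meter, acc_meter = AverageMeter(), AverageMeter()
loss_meter.update(0.52, n=64)   # batch loss, batch size
acc_meter.update(0.875, n=64)   # batch accuracy as a fraction
print(f"Avg Loss: {loss_meter.avg:.4f} | Avg Acc: {acc_meter.avg * 100:.2f}%")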
+
+
+def main(args, tlogger):
+    train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, tlogger)
+
+    best_acc = 0.0
+    best_eval_name = "null"
+
+    if args.use_wandb:
+        wandb.init(entity=args.wandb_entity, project=args.project_name, name=args.exp_name, config=args)
+        wandb.run.summary["best_acc"] = best_acc
+        wandb.run.summary["best_epoch"] = 0
+
+    for epoch in range(start_epoch, args.max_epochs):
+        ### Train
+        if train_loader is not None:
+            tlogger.print("Start Training {} Epoch".format(epoch + 1))
+            train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader)
+            tlogger.print()
+        else:
+            from eval import eval_and_save
+            eval_and_save(args, model, val_loader)
+            break
+
+        ### Adjust the evaluation frequency
+        eval_freq_schedule(args, epoch)
+
+        ### Save the latest checkpoint
+        model_to_save = model.module if hasattr(model, "module") else model
+        checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
+        torch.save(checkpoint, args.save_dir + "backup/last.pt")
+
+        ### Validation (original logic; the validation loss is not computed)
+        if epoch == 0 or (epoch + 1) % args.eval_freq == 0:
+            acc = -1
+            if val_loader is not None:
+                tlogger.print("Start Evaluating {} Epoch".format(epoch + 1))
+                # Original evaluate call (returns exactly three values)
+                acc, eval_name, accs = evaluate(args, model, val_loader)
+                # Print the validation accuracy (no loss)
+                print(f"[Val] Epoch {epoch+1:2d} | Best Acc: {max(acc, best_acc):.2f}% (Current Acc: {acc:.2f}%)")
+                tlogger.print()
+
+                ### Update the wandb logs
+                if args.use_wandb:
+                    wandb.log(accs)
+
+                ### Update the best checkpoint
+                if acc > best_acc:
+                    best_acc = acc
+                    best_eval_name = eval_name
+                    torch.save(checkpoint, args.save_dir + "backup/best.pt")
+                    print(f"[Update Best Model] Epoch {epoch+1:2d} | Best Acc: {best_acc:.2f}%")
+
+                    ### Update the wandb summary
+                    if args.use_wandb:
+                        wandb.run.summary["best_acc"] = best_acc
+                        wandb.run.summary["best_epoch"] = epoch + 1
+
+
+if __name__ == "__main__":
+    tlogger = timeLogger()
+    tlogger.print("Reading Config...")
+    args = get_args()
+    assert args.c != "", "Please provide config file (.yaml)"
+    load_yaml(args, args.c)
+    if not hasattr(args, "log_freq"):
+        args.log_freq = 10  # default: print every 10 batches
+    build_record_folder(args)
+    tlogger.print()
+    main(args, tlogger)
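The validation loop above relies on evaluate() returning an (accuracy, metric name, metrics dict) triple. A minimal sketch of a loop honoring that contract, assuming plain top-1 accuracy and the same (ids, datas, labels) batch layout as the train loader (the real evaluate() in eval.py also reports per-layer and top-k results):

import torch

@torch.no_grad()
def toy_evaluate(model, loader, device="cpu"):
    model.eval()
    correct, total = 0, 0
    for _, datas, labels in loader:  # (ids, datas, labels) batches assumed
        datas, labels = datas.to(device), labels.to(device)
        preds = torch.argmax(model(datas), dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    acc = round(100 * correct / total, 2)
    accs = {"acc/top-1": acc}
    return acc, "acc/top-1", accs  # (best acc, its name, all metrics)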
diff --git a/vit_pim_main_sgd.py b/vit_pim_main_sgd.py
new file mode 100644
index 0000000..1065083
--- /dev/null
+++ b/vit_pim_main_sgd.py
@@ -0,0 +1,428 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import contextlib
+import wandb
+import warnings
+
+from models.builder import MODEL_GETTER
+from data.dataset import build_loader
+from utils.costom_logger import timeLogger
+from utils.config_utils import load_yaml, build_record_folder, get_args
+from utils.lr_schedule import cosine_decay, adjust_lr, get_lr
+from eval import evaluate, cal_train_metrics
+
+warnings.simplefilter("ignore")
+
+
+def eval_freq_schedule(args, epoch: int):
+    """
+    Adjust the evaluation frequency (eval_freq) according to the current epoch.
+    Evaluate more often near the end of training to monitor the model more closely.
+
+    Args:
+        args: training configuration object, including eval_freq and max_epochs.
+        epoch: the current training epoch.
+    """
+    # From 95% of max_epochs onwards, evaluate every epoch
+    if epoch >= args.max_epochs * 0.95:
+        args.eval_freq = 1
+    # From 90% (but below 95%) of max_epochs, also evaluate every epoch
+    elif epoch >= args.max_epochs * 0.9:
+        args.eval_freq = 1
+    # From 80% (but below 90%) of max_epochs, evaluate every two epochs
+    elif epoch >= args.max_epochs * 0.8:
+        args.eval_freq = 2
+
+
+def set_environment(args, tlogger):
+    """
+    Set up the training environment: device, data loaders, model, optimizer, etc.
+
+    Args:
+        args: training configuration object.
+        tlogger: time logger used for progress messages.
+
+    Returns:
+        train_loader: training data loader.
+        val_loader: validation data loader.
+        model: the constructed and initialized model.
+        optimizer: the optimizer (None in evaluation-only mode).
+        schedule: the learning-rate schedule (None in evaluation-only mode).
+        scaler: the AMP gradient scaler (None when AMP is disabled).
+        amp_context: the AMP autocast context (nullcontext when AMP is disabled).
+        start_epoch: the epoch to start from (taken from the checkpoint when one is loaded).
+    """
+
+    print("Setting Environment...")
+
+    # Use the GPU when CUDA is available, otherwise fall back to the CPU
+    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    ### = = = =  Dataset and Data Loader = = = =
+    # Build the training and validation data loaders
+    tlogger.print("Building Dataloader....")
+
+    train_loader, val_loader = build_loader(args)
+
+    # Make sure at least one loader was built
+    if train_loader is None and val_loader is None:
+        raise ValueError("Found nothing to train or evaluate.")
+
+    # Print training-set information
+    if train_loader is not None:
+        print("    Train Samples: {} (batch: {})".format(len(train_loader.dataset), len(train_loader)))
+    else:
+        # raise ValueError("Build train loader fail, please provide legal path.")
+        print("    Train Samples: 0 ~~~~~> [Only Evaluation]")
+
+    # Print validation-set information
+    if val_loader is not None:
+        print("    Validation Samples: {} (batch: {})".format(len(val_loader.dataset), len(val_loader)))
+    else:
+        print("    Validation Samples: 0 ~~~~~> [Only Training]")
+    tlogger.print()
+
+    ### = = = =  Model = = = =
+    # Build the model
+    tlogger.print("Building Model....")
+    model = MODEL_GETTER[args.model_name](
+        use_fpn=args.use_fpn,
+        fpn_size=args.fpn_size,
+        use_selection=args.use_selection,
+        num_classes=args.num_classes,
+        num_selects=args.num_selects,
+        use_combiner=args.use_combiner,
+    )  # about return_nodes, we use our default setting
+
+    # Load weights when a pretrained checkpoint is provided
+    if args.pretrained is not None:
+        checkpoint = torch.load(args.pretrained, map_location=torch.device('cpu'))
+        model.load_state_dict(checkpoint['model'])
+        start_epoch = checkpoint['epoch']
+        print(start_epoch)
+    else:
+        start_epoch = 0
+
+    # Move the model to the target device
+    model.to(args.device)
+    tlogger.print()
+
+    """
+    With multiple GPUs you can wrap the model in torch.nn.DataParallel (single
+    machine) or torch.nn.parallel.DistributedDataParallel (multi-process).
+    More details: https://pytorch.org/tutorials/beginner/dist_overview.html
+    """
+
+    # Without a train loader we are evaluation-only; return early
+    if train_loader is None:
+        return train_loader, val_loader, model, None, None, None, None, start_epoch
+
+    ### = = = =  Optimizer = = = =
+    # Build the optimizer
+    tlogger.print("Building Optimizer....")
+    if args.optimizer == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.max_lr, nesterov=True, momentum=0.9,
+                                    weight_decay=args.wdecay)
+    elif args.optimizer == "AdamW":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr)
+
+    # Restore the optimizer state when resuming from a checkpoint
+    if args.pretrained is not None:
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+    tlogger.print()
+
+    # Build the learning-rate schedule
+    schedule = cosine_decay(args, len(train_loader))
+
+    # Set up mixed-precision (AMP) components when enabled
+    if args.use_amp:
+        scaler = torch.cuda.amp.GradScaler()
+        amp_context = torch.cuda.amp.autocast
+    else:
+        scaler = None
+        amp_context = contextlib.nullcontext
+
+    # Return all constructed components
+    return train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch
+
+
+def train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader):
+    """
+    Run one training epoch: iterate over the training data and update the model.
+
+    Args:
+        args: training configuration object.
+        epoch: the current training epoch.
+        model: the model to train.
+        scaler: AMP gradient scaler (None when AMP is disabled).
+        amp_context: AMP autocast context (nullcontext when AMP is disabled).
+        optimizer: the optimizer.
+        schedule: the learning-rate schedule.
+        train_loader: the training data loader.
+    """
+
+    # Clear any stale gradients
+    optimizer.zero_grad()
+
+    # Total number of batches, used only for logging
+    total_batchs = len(train_loader)
+
+    # Progress milestones (0%, 10%, ..., 100%)
+    show_progress = [x / 10 for x in range(11)]
+    progress_i = 0
+
+    # Iterate over the training batches
+    for batch_id, (ids, datas, labels) in enumerate(train_loader):
+        # Put the model in training mode
+        model.train()
+
+        """ = = = = adjust learning rate = = = = """
+        # Current iteration index
+        iterations = epoch * len(train_loader) + batch_id
+        # Adjust the learning rate
+        adjust_lr(iterations, optimizer, schedule)
+
+        # Number of samples in this batch
+        batch_size = labels.size(0)
+
+        """ = = = = forward and calculate loss = = = =  """
+        # Move data and labels to the target device
+        datas, labels = datas.to(args.device), labels.to(args.device)
+
+        # Forward pass under the AMP context (when AMP is enabled)
+        with amp_context():
+            """
+            [Model Return]
+            FPN + Selector + Combiner --> return 'layer1', 'layer2', 'layer3', 'layer4', ... (depends on your setting),
+            'preds_0', 'preds_1', 'comb_outs'
+            FPN + Selector --> return 'layer1', 'layer2', 'layer3', 'layer4', ... (depends on your setting),
+            'preds_0', 'preds_1'
+            FPN --> return 'layer1', 'layer2', 'layer3', 'layer4' (depends on your setting)
+            ~ --> return 'ori_out'
+
+            [Return Tensor]
+            'preds_0': logits that have not been selected by the Selector.
+            'preds_1': logits that have been selected by the Selector.
+            'comb_outs': the prediction of the Combiner.
+            """
+            # Forward pass
+            outs = model(datas)
+
+            # Initialize the total loss
+            loss = 0.
+
+            # Walk over the model outputs and accumulate the corresponding losses
+            for name in outs:
+                # Selector outputs
+                if "select_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+                    if args.lambda_s != 0:
+                        # Selector loss
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        loss_s = nn.CrossEntropyLoss()(logit,
+                                                       labels.unsqueeze(1).repeat(1, S).flatten(0))
+                        loss += args.lambda_s * loss_s
+                    else:
+                        loss_s = 0.0
+
+                # Dropped (not selected) outputs
+                elif "drop_" in name:
+                    if not args.use_selection:
+                        raise ValueError("Selector is not used here.")
+
+                    if args.lambda_n != 0:
+                        # Negative-sample loss: push dropped logits towards -1
+                        S = outs[name].size(1)
+                        logit = outs[name].view(-1, args.num_classes).contiguous()
+                        n_preds = nn.Tanh()(logit)
+                        labels_0 = torch.zeros([batch_size * S, args.num_classes]) - 1
+                        labels_0 = labels_0.to(args.device)
+                        loss_n = nn.MSELoss()(n_preds, labels_0)
+                        loss += args.lambda_n * loss_n
+                    else:
+                        loss_n = 0.0
+
+                # FPN layer outputs
+                elif "layer" in name:
+                    if not args.use_fpn:
+                        raise ValueError("FPN is not used here.")
+                    if args.lambda_b != 0:
+                        # FPN base loss
+                        ### using 'layer1'~'layer4' is the default setting; change it to match your own return nodes
+                        loss_b = nn.CrossEntropyLoss()(outs[name].mean(1), labels)
+                        loss += args.lambda_b * loss_b
+                    else:
+                        loss_b = 0.0
+
+                # Combiner output
+                elif "comb_outs" in name:
+                    if not args.use_combiner:
+                        raise ValueError("Combiner is not used here.")
+
+                    if args.lambda_c != 0:
+                        # Combiner loss
+                        loss_c = nn.CrossEntropyLoss()(outs[name], labels)
+                        loss += args.lambda_c * loss_c
+
+                # Original (backbone) output
+                elif "ori_out" in name:
+                    # Original-output loss
+                    loss_ori = F.cross_entropy(outs[name], labels)
+                    loss += loss_ori
+
+            # Scale the loss for gradient accumulation
+            loss /= args.update_freq
+
+        """ = = = = calculate gradient = = = = """
+        # Backward pass (AMP or plain)
+        if args.use_amp:
+            scaler.scale(loss).backward()
+        else:
+            loss.backward()
+
+        """ = = = = update model = = = = """
+        # Step the optimizer every update_freq batches
+        if (batch_id + 1) % args.update_freq == 0:
+            if args.use_amp:
+                # AMP update
+                scaler.step(optimizer)
+                scaler.update()  # next batch
+            else:
+                # Plain update
+                optimizer.step()
+            # Clear the accumulated gradients
+            optimizer.zero_grad()
+
+        """ log (MISC) """
+        # Log training metrics (when wandb is enabled and the logging interval is reached)
+        if args.use_wandb and ((batch_id + 1) % args.log_freq == 0):
+            # Switch to eval mode for metric computation
+            model.eval()
+            msg = {}
+            msg['info/epoch'] = epoch + 1
+            msg['info/lr'] = get_lr(optimizer)
+            # Compute and record the training metrics
+            cal_train_metrics(args, msg, outs, labels, batch_size)
+            # Send the metrics to wandb
+            wandb.log(msg)
+
+        # Show training progress
+        train_progress = (batch_id + 1) / total_batchs
+        # print(train_progress, show_progress[progress_i])
+        if train_progress > show_progress[progress_i]:
+            print(".." + str(int(show_progress[progress_i] * 100)) + "%", end='', flush=True)
+            progress_i += 1
+
+def main(args, tlogger):
+    """
+    Main training loop: runs training and validation, and saves checkpoints
+    (last.pt and best.pt).
+
+    Args:
+        args: training configuration object.
+        tlogger: time logger used for progress messages.
+    """
+
+    # Set up the environment: data loaders, model, optimizer, etc.
+    train_loader, val_loader, model, optimizer, schedule, scaler, amp_context, start_epoch = set_environment(args, tlogger)
+
+    # Initialize the best accuracy and the name of the best evaluation metric
+    best_acc = 0.0
+    best_eval_name = "null"
+
+    # Initialize the wandb run and its summary fields (when wandb is enabled)
+    if args.use_wandb:
+        wandb.init(entity=args.wandb_entity,
+                   project=args.project_name,
+                   name=args.exp_name,
+                   config=args)
+        wandb.run.summary["best_acc"] = best_acc
+        wandb.run.summary["best_eval_name"] = best_eval_name
+        wandb.run.summary["best_epoch"] = 0
+
+    # Training loop from start_epoch to max_epochs
+    for epoch in range(start_epoch, args.max_epochs):
+
+        """
+        Train
+        """
+        # Train for one epoch when a train loader exists
+        if train_loader is not None:
+            tlogger.print("Start Training {} Epoch".format(epoch + 1))
+            # Run one training epoch
+            train(args, epoch, model, scaler, amp_context, optimizer, schedule, train_loader)
+            tlogger.print()
+        else:
+            # Evaluation-only mode: evaluate, save the results, and stop
+            from eval import eval_and_save
+            eval_and_save(args, model, val_loader)
+            break
+
+        # Adjust the evaluation frequency for the current epoch
+        eval_freq_schedule(args, epoch)
+
+        # Prepare the checkpoint to save (handle the multi-GPU wrapper)
+        model_to_save = model.module if hasattr(model, "module") else model
+        checkpoint = {"model": model_to_save.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
+        # Save the latest checkpoint
+        torch.save(checkpoint, args.save_dir + "backup/last.pt")
+
+        # Validate every eval_freq epochs (and always after the first epoch)
+        if epoch == 0 or (epoch + 1) % args.eval_freq == 0:
+            """
+            Evaluate
+            """
+            acc = -1
+            # Run validation when a validation loader exists
+            if val_loader is not None:
+                tlogger.print("Start Evaluating {} Epoch".format(epoch + 1))
+                # Evaluate and collect the accuracy metrics
+                acc, eval_name, accs = evaluate(args, model, val_loader)
+                # Print the current result and the best accuracy so far
+                tlogger.print("....BEST_ACC: {}% ({}%)".format(max(acc, best_acc), acc))
+                tlogger.print()
+
+                # Log the validation metrics (when wandb is enabled)
+                if args.use_wandb:
+                    wandb.log(accs)
+
+                # Update the best accuracy and save the best checkpoint
+                if acc > best_acc:
+                    best_acc = acc
+                    best_eval_name = eval_name
+                    torch.save(checkpoint, args.save_dir + "backup/best.pt")
+                    # Update the wandb summary (when wandb is enabled)
+                    if args.use_wandb:
+                        wandb.run.summary["best_acc"] = best_acc
+                        wandb.run.summary["best_eval_name"] = best_eval_name
+                        wandb.run.summary["best_epoch"] = epoch + 1
+
+
+if __name__ == "__main__":
+    # Create a time logger for timestamped messages
+    tlogger = timeLogger()
+
+    # Announce that the config is being read
+    tlogger.print("Reading Config...")
+
+    # Parse the command-line arguments (including the config file path)
+    args = get_args()
+
+    # A .yaml config file must be provided
+    assert args.c != "", "Please provide config file (.yaml)"
+
+    # Load the YAML config and merge it into args
+    load_yaml(args, args.c)
+
+    # Create the record folder for logs and checkpoints
+    build_record_folder(args)
+
+    # Print a blank line as a separator
+    tlogger.print()
+
+    # Run the main training / evaluation flow
+    main(args, tlogger)
\ No newline at end of file
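All three entry scripts drive the learning rate through cosine_decay() and adjust_lr() from utils.lr_schedule, which this diff does not show. A hedged sketch of what such a per-iteration warmup-plus-cosine schedule typically looks like (the exact shape used by the repo may differ):

import math

def cosine_schedule(max_lr, warmup_iters, total_iters):
    # Linear warmup to max_lr, then cosine decay to 0 over the remaining iterations.
    lrs = []
    for it in range(total_iters):
        if it < warmup_iters:
            lrs.append(max_lr * (it + 1) / warmup_iters)
        else:
            progress = (it - warmup_iters) / max(total_iters - warmup_iters, 1)
            lrs.append(0.5 * max_lr * (1 + math.cos(math.pi * progress)))
    return lrs

schedule = cosine_schedule(max_lr=3e-4, warmup_iters=800, total_iters=20000)
# adjust_lr(iteration, optimizer, schedule) would then index into a table like this:
# for g in optimizer.param_groups: g["lr"] = schedule[iteration]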