mrT23 / TResNet

MODEL | TOP 1 ACCURACY (CODE) | TOP 1 ACCURACY (PAPER) | TOP 1 ε-REPR | TOP 5 ACCURACY (CODE) | TOP 5 ACCURACY (PAPER) | TOP 5 ε-REPR | PAPER | GLOBAL RANK
TResNet-M (224-Mean-Max-vazoom) | 80.0% | 80.7% | | 94.4% | 94.8% | | | #300
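Top 1 and Top 5 accuracy are the fraction of validation images whose true label is the single highest-scoring class, or is among the five highest-scoring classes. A minimal pure-Python sketch of that metric follows; the function name `top_k_accuracy` and the toy scores are hypothetical and are not taken from torchbench, whose internal implementation may differ.

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores (stable sort: ties keep the lower index first).
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


# Hypothetical toy data: 4 samples, 3 classes.
scores = [
    [0.1, 0.7, 0.2],  # best class 1
    [0.5, 0.3, 0.2],  # best class 0
    [0.2, 0.3, 0.5],  # best class 2
    [0.6, 0.1, 0.3],  # best class 0, second best class 2
]
labels = [1, 2, 2, 2]
print(top_k_accuracy(scores, labels, k=1))  # 0.5
print(top_k_accuracy(scores, labels, k=2))  # 0.75
```

With k=5 on ImageNet's 1000 classes this gives the Top 5 figure; k=1 gives Top 1.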
[![SotaBench](https://img.shields.io/endpoint.svg?url=https://sotabench.com/api/v0/badge/gh/Randl/TResNet)](https://sotabench.com/user/EvgeniiZh/repos/Randl/TResNet)

How the Repository is Evaluated

The full sotabench.py file (source):
import argparse
import gc

import torch
from torchbench.image_classification import ImageNet
from torchvision.transforms import transforms

from src.models import create_model
from src.models.tresnet.layers.avg_pool import TestTimePoolHead
from src.models.tresnet.tresnet import InplacABN_to_ABN
from src.models.utils.fuse_bn import fuse_bn_recursively

parser = argparse.ArgumentParser(description='PyTorch TResNet ImageNet Inference')
parser.add_argument('--val_dir')
parser.add_argument('--model_path')
parser.add_argument('--model_name', type=str, default='tresnet_m')
parser.add_argument('--num_classes', type=int, default=1000)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--val_zoom_factor', type=float, default=0.875)
parser.add_argument('--batch_size', type=int, default=48)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--remove_aa_jit', action='store_true', default=False)

# parsing args
args = parser.parse_args()

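# TResNet-M, 224-Mean-Max-vazoom evaluation setting

# Validation transforms: resize the shorter side to int(input_size / val_zoom_factor)
# (256 px with the defaults), center-crop to input_size, then convert to a tensor.
# ToTensor scales pixels to [0, 1]; no mean/std normalization is applied here.
val_tfms = transforms.Compose([
    transforms.Resize(int(args.input_size / args.val_zoom_factor)),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor(),
])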

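# Load the checkpoint; fall back to ./tresnet_m.pth when --model_path is not given.
model_path = args.model_path or './tresnet_m.pth'
model = create_model(args)
# The checkpoint stores the weights under the 'model' key.
state = torch.load(model_path, map_location='cpu')['model']
model.load_state_dict(state, strict=True)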

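# Wrap the network in the repository's test-time pooling head (the "Mean-Max" part
# of the setting name); 5 is passed as the head's pooling-size argument.
model = TestTimePoolHead(model, 5)

# Convert Inplace-ABN layers to standard BatchNorm + activation, then fold the
# BatchNorm layers into the preceding convolutions for faster inference.
model = InplacABN_to_ABN(model)
model = fuse_bn_recursively(model)
model = model.cuda()
model.eval()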
print('Benchmarking TResNet-M (224-Mean-Max-vazoom)')

# Run the benchmark
ImageNet.benchmark(
    model=model,
    paper_model_name='TResNet-M (224-Mean-Max-vazoom)',
    paper_arxiv_id='2003.13630',
    input_transform=val_tfms,
    batch_size=432,  # hard-coded; the --batch_size argument is not used here
    num_workers=args.num_workers,
    num_gpu=1,
    pin_memory=True,
    paper_results={'Top 1 Accuracy': 0.807, 'Top 5 Accuracy': 0.948},
    model_description="Official weights from the authors of the paper."
)

del model
gc.collect()
torch.cuda.empty_cache()
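The "vazoom" part of the setting is the validation zoom: with input_size=224 and val_zoom_factor=0.875 the shorter image side is resized to int(224 / 0.875) = 256 px and the central 224 x 224 region is kept, i.e. 87.5% of the shorter side. The sketch below reproduces that geometry in pure Python, assuming torchvision's behavior for an integer Resize size (scale the shorter edge, keep aspect ratio, truncate the longer edge) and the rounding its center crop uses; `eval_geometry` is a hypothetical helper, not part of the repository.

```python
from typing import Tuple


def eval_geometry(width: int, height: int, input_size: int = 224,
                  zoom: float = 0.875) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Return the resized (w, h) and the (left, top, right, bottom) center-crop box."""
    resize_to = int(input_size / zoom)  # 224 / 0.875 = 256
    # Resize(int): the shorter edge becomes resize_to, the longer edge keeps the aspect ratio.
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = resize_to, int(resize_to * long / short)
    new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
    # CenterCrop(input_size): offsets of the central input_size x input_size window.
    left = int(round((new_w - input_size) / 2.0))
    top = int(round((new_h - input_size) / 2.0))
    return (new_w, new_h), (left, top, left + input_size, top + input_size)


print(eval_geometry(500, 375))  # ((341, 256), (58, 16, 282, 240))
```

For a 500 x 375 landscape image, the resize yields 341 x 256 and the crop discards 58 px on the left and 16 px at the top.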
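The InplacABN_to_ABN and fuse_bn_recursively calls come from the TResNet repository, and their code is not shown on this page. Judging by the names, the second one folds each inference-mode BatchNorm layer into the layer that precedes it. Under that assumption, the folding algebra is w' = w * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta per output channel. The sketch below checks this in pure Python with a fully connected layer instead of a convolution for brevity (the per-output-channel scaling is the same); all names and numbers are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # weight[o][i]: one row per output channel.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm_eval(y: Vector, mean: Vector, var: Vector, gamma: Vector,
                    beta: Vector, eps: float = 1e-5) -> Vector:
    # Inference-mode BatchNorm: uses stored running statistics, not batch statistics.
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, m, s, g, bt in zip(y, mean, var, gamma, beta)]


def fold_bn(weight: Matrix, bias: Vector, mean: Vector, var: Vector, gamma: Vector,
            beta: Vector, eps: float = 1e-5) -> Tuple[Matrix, Vector]:
    # Scale each output channel's weights and shift its bias so one layer replaces two.
    new_w, new_b = [], []
    for row, b, m, s, g, bt in zip(weight, bias, mean, var, gamma, beta):
        scale = g / math.sqrt(s + eps)
        new_w.append([w * scale for w in row])
        new_b.append((b - m) * scale + bt)
    return new_w, new_b


# Hypothetical layer (2 inputs, 2 output channels) followed by BatchNorm.
weight, bias = [[0.5, -1.0], [2.0, 0.25]], [0.1, -0.3]
mean, var = [0.2, -0.5], [1.5, 0.04]
gamma, beta = [0.9, 1.1], [0.05, -0.2]
x = [1.0, 2.0]

unfused = batch_norm_eval(linear(x, weight, bias), mean, var, gamma, beta)
fused_w, fused_b = fold_bn(weight, bias, mean, var, gamma, beta)
fused = linear(x, fused_w, fused_b)
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(unfused, fused)))
```

Folding only holds once the running statistics are frozen, which is why it is a pure inference-time optimization.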
STATUS | BUILD | COMMIT MESSAGE | RUN TIME
 | Randl 669afd4 · May 04 2020 | Remove duplicates (https://github.com/paperswithcode/sotabench-a… | 0h:16m:36s
 | Randl c4a80ee · Apr 23 2020 | Add missing parameter | 0h:14m:35s
 | | | 0h:09m:02s