Skip to content

Commit 419825f

Browse files
kit1980facebook-github-bot
authored andcommitted
Add TorchFix linter to CI (meta-pytorch#569)
Summary: Add TorchFix (https://github.com/pytorch/test-infra/tree/main/tools/torchfix) flake8 plugin. Pull Request resolved: meta-pytorch#569 Test Plan: Verify that buck2 test torchtnt/tests/utils:test_module_summary passes. Reviewed By: galrotem Differential Revision: D50438385 Pulled By: kit1980 fbshipit-source-id: 68d7d07d2ff3e4374ea5e5e91896720b85671ba6
1 parent c88c7a6 commit 419825f

3 files changed

Lines changed: 16 additions & 6 deletions

File tree

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[flake8]
22
# Suggested config from pytorch that we can adopt
3-
select = B,C,E,F,P,T4,W,B9
3+
select = B,C,E,F,P,T4,W,B9,TOR0,TOR1,TOR2
44
max-line-length = 120
55
# C408 ignored because we like the dict keyword argument syntax
66
# E501 is not flexible enough, we're using B950 instead

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ repos:
2727
- id: flake8
2828
args:
2929
- --config=.flake8
30+
additional_dependencies:
31+
- torchfix==0.1.1
3032

3133
- repo: https://github.com/omnilib/ufmt
3234
rev: v1.3.0

tests/utils/test_module_summary.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from typing import Any, Dict, Optional, Tuple
1010

1111
import torch
12-
import torchvision.models as models
1312
from torchtnt.utils.module_summary import (
1413
_get_human_readable_count,
1514
get_module_summary,
1615
ModuleSummary,
1716
prune_module_summary,
1817
)
18+
from torchvision import models
1919

2020

2121
def get_summary_and_prune(
@@ -155,7 +155,9 @@ def test_lazy_tensor_flops(self) -> None:
155155

156156
def test_resnet_max_depth(self) -> None:
157157
"""Test the behavior of max_depth on a layered model like ResNet"""
158-
pretrained_model = models.resnet.resnet18(pretrained=True)
158+
pretrained_model = models.resnet.resnet18(
159+
weights=models.ResNet18_Weights.IMAGENET1K_V1 # pyre-ignore[16]
160+
)
159161

160162
# max_depth = None
161163
ms1 = get_module_summary(pretrained_model)
@@ -211,7 +213,9 @@ def test_module_summary_layer_print(self) -> None:
211213
self._test_module_summary_text(summary_table, str(ms1))
212214

213215
def test_alexnet_print(self) -> None:
214-
pretrained_model = models.alexnet(pretrained=True)
216+
pretrained_model = models.alexnet(
217+
weights=models.AlexNet_Weights.IMAGENET1K_V1 # pyre-ignore[16]
218+
)
215219
ms1 = get_summary_and_prune(pretrained_model, max_depth=1)
216220
ms2 = get_summary_and_prune(pretrained_model, max_depth=2)
217221
ms3 = get_summary_and_prune(pretrained_model, max_depth=3)
@@ -236,7 +240,9 @@ def test_alexnet_print(self) -> None:
236240
self.assertEqual(str(ms3), str(ms4))
237241

238242
def test_alexnet_with_input_tensor(self) -> None:
239-
pretrained_model = models.alexnet(pretrained=True)
243+
pretrained_model = models.alexnet(
244+
weights=models.AlexNet_Weights.IMAGENET1K_V1 # pyre-ignore[16]
245+
)
240246
inp = torch.randn(1, 3, 224, 224)
241247
ms1 = get_summary_and_prune(pretrained_model, max_depth=1, module_args=(inp,))
242248
ms2 = get_summary_and_prune(pretrained_model, max_depth=2, module_args=(inp,))
@@ -344,7 +350,9 @@ def forward(self, x, y, offset=1):
344350
self.assertEqual(ms_classifier.out_size, [1, 1, 224, 224])
345351

346352
def test_forward_elapsed_time(self) -> None:
347-
pretrained_model = models.alexnet(pretrained=True)
353+
pretrained_model = models.alexnet(
354+
weights=models.AlexNet_Weights.IMAGENET1K_V1 # pyre-ignore[16]
355+
)
348356
inp = torch.randn(1, 3, 224, 224)
349357
ms1 = get_summary_and_prune(pretrained_model, module_args=(inp,), max_depth=4)
350358
stack = [ms1] + [

0 commit comments

Comments
 (0)