99from typing import Any , Dict , Optional , Tuple
1010
1111import torch
12- import torchvision .models as models
1312from torchtnt .utils .module_summary import (
1413 _get_human_readable_count ,
1514 get_module_summary ,
1615 ModuleSummary ,
1716 prune_module_summary ,
1817)
18+ from torchvision import models
1919
2020
2121def get_summary_and_prune (
@@ -155,7 +155,9 @@ def test_lazy_tensor_flops(self) -> None:
155155
156156 def test_resnet_max_depth (self ) -> None :
157157 """Test the behavior of max_depth on a layered model like ResNet"""
158- pretrained_model = models .resnet .resnet18 (pretrained = True )
158+ pretrained_model = models .resnet .resnet18 (
159+ weights = models .ResNet18_Weights .IMAGENET1K_V1 # pyre-ignore[16]
160+ )
159161
160162 # max_depth = None
161163 ms1 = get_module_summary (pretrained_model )
@@ -211,7 +213,9 @@ def test_module_summary_layer_print(self) -> None:
211213 self ._test_module_summary_text (summary_table , str (ms1 ))
212214
213215 def test_alexnet_print (self ) -> None :
214- pretrained_model = models .alexnet (pretrained = True )
216+ pretrained_model = models .alexnet (
217+ weights = models .AlexNet_Weights .IMAGENET1K_V1 # pyre-ignore[16]
218+ )
215219 ms1 = get_summary_and_prune (pretrained_model , max_depth = 1 )
216220 ms2 = get_summary_and_prune (pretrained_model , max_depth = 2 )
217221 ms3 = get_summary_and_prune (pretrained_model , max_depth = 3 )
@@ -236,7 +240,9 @@ def test_alexnet_print(self) -> None:
236240 self .assertEqual (str (ms3 ), str (ms4 ))
237241
238242 def test_alexnet_with_input_tensor (self ) -> None :
239- pretrained_model = models .alexnet (pretrained = True )
243+ pretrained_model = models .alexnet (
244+ weights = models .AlexNet_Weights .IMAGENET1K_V1 # pyre-ignore[16]
245+ )
240246 inp = torch .randn (1 , 3 , 224 , 224 )
241247 ms1 = get_summary_and_prune (pretrained_model , max_depth = 1 , module_args = (inp ,))
242248 ms2 = get_summary_and_prune (pretrained_model , max_depth = 2 , module_args = (inp ,))
@@ -344,7 +350,9 @@ def forward(self, x, y, offset=1):
344350 self .assertEqual (ms_classifier .out_size , [1 , 1 , 224 , 224 ])
345351
346352 def test_forward_elapsed_time (self ) -> None :
347- pretrained_model = models .alexnet (pretrained = True )
353+ pretrained_model = models .alexnet (
354+ weights = models .AlexNet_Weights .IMAGENET1K_V1 # pyre-ignore[16]
355+ )
348356 inp = torch .randn (1 , 3 , 224 , 224 )
349357 ms1 = get_summary_and_prune (pretrained_model , module_args = (inp ,), max_depth = 4 )
350358 stack = [ms1 ] + [
0 commit comments