|
13 | 13 | help='link to google storage bucket') |
# Task / data-file selection and accelerator options.
# NOTE(review): for the nargs='+' options below, a value given on the command
# line arrives as a list of strings, while the defaults are a plain string (or
# None) -- consumers presumably handle both shapes; confirm.
parser.add_argument(
    '-tasknames',
    nargs='+',
    type=str,
    default="all-mix",
    help='name of the task',
)
# NOTE(review): '-gpus' defaults to None, so any code that takes
# len(args.gpus) must cope with the option being omitted.
parser.add_argument(
    '-gpus',
    nargs='+',
    type=str,
    default=None,
    help='available GPUs',
)
parser.add_argument(
    '-train',
    nargs='+',
    type=str,
    default="context-train.txt",
    help='train file',
)
parser.add_argument(
    '-val',
    nargs='+',
    type=str,
    default="context-val.txt",
    help='val file',
)
parser.add_argument(
    '-tpu_address',
    type=str,
    default=None,
    help='TPU ip address',
)
# TPU topology is optional: it defaults to None, and None stays in `choices`
# so the accepted set shown in --help is unchanged.
parser.add_argument('-tpu_topology', type=str, default=None, choices=["v2-8", "v3-8", None],
                    help='TPU topology')
# The help text used to be a copy-pasted 'train file'.
# NOTE(review): presumably the input sequence length fed to the model -- confirm.
parser.add_argument('-in_len', type=int, default=2048,
                    help='input length')
|
31 | 33 | help='train file') |
parser.add_argument('-compression', type=str, default=None, choices=[None, "ZLIB", "GZIP"],
                    help='compression the dataset is compressed with')
# Optional overrides; each stays None unless given on the command line.
# The value is assigned to train_batch_size, i.e. it is a batch size and not a
# count of batches as the old help text claimed.
parser.add_argument('-batch_size', type=int, default=None,
                    help='training batch size')
parser.add_argument('-max_checkpoints', type=int, default=None,
                    help='maximum number of checkpoints')
parser.add_argument('-storemode', type=str, default="gs", choices=["gs", "local"],
                    help='storemode')
# The flag keeps its historical (misspelled) name so existing command lines
# continue to work; only the help text is corrected.
parser.add_argument('-paralellism', type=int, default=None,
                    help='model parallelism')
args = parser.parse_args()
37 | 45 |
|
38 | 46 | from src.createtask import create_registry |
@@ -85,12 +93,17 @@ def tf_verbosity_level(level): |
85 | 93 | "large": (4, 128, 2), |
86 | 94 | "3B": (8, 16, 1), |
87 | 95 | "11B": (8, 4, 1)}[MODEL_SIZE] |
# Command-line overrides for the per-model-size defaults chosen above.
# The first assignment previously targeted the misspelled name
# `model_paralellism`, so the value later passed to the model as
# `model_parallelism` was never overridden.
# `is not None` is used instead of truthiness so an explicit 0 is passed
# through rather than silently ignored.
if args.paralellism is not None:
    model_parallelism = args.paralellism
if args.batch_size is not None:
    train_batch_size = args.batch_size
if args.max_checkpoints is not None:
    keep_checkpoint_max = args.max_checkpoints
88 | 99 |
|
89 | 100 | tf.io.gfile.makedirs(MODEL_DIR) |
90 | 101 | # The models from our paper are based on the Mesh Tensorflow Transformer. |
91 | 102 | model = t5.models.MtfModel( |
92 | 103 | model_dir=MODEL_DIR, |
93 | 104 | tpu=args.tpu_address, |
| 105 | + mesh_devices=args.gpus, |
| 106 | + mesh_shape=f'model:1,batch:{len(args.gpus)}', |
94 | 107 | tpu_topology=args.tpu_topology, |
95 | 108 | model_parallelism=model_parallelism, |
96 | 109 | batch_size=train_batch_size, |
|
0 commit comments