Skip to content

Commit ce3cd65

Browse files
committed
new train
1 parent de4e74e commit ce3cd65

2 files changed

Lines changed: 15 additions & 2 deletions

File tree

export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
import logging as py_logging
66
from src.createtask import create_registry
7-
create_registry(None, "src/temp.txt", "src/temp.txt", "all_mix", None)
7+
create_registry(None, "src/temp.txt", "src/temp.txt", "all_mix", None, "local")
88
warnings.filterwarnings("ignore", category=DeprecationWarning)
99
py_logging.root.setLevel('INFO')
1010

train.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
help='link to google storage bucket')
1414
parser.add_argument('-tasknames', nargs='+', type=str, default="all-mix",
1515
help='name of the task')
16+
parser.add_argument('-gpus', nargs='+', type=str, default=None,
17+
help='available GPUs')
1618
parser.add_argument('-train', nargs='+', type=str, default="context-train.txt",
1719
help='train file')
1820
parser.add_argument('-val', nargs='+', type=str, default="context-val.txt",
1921
help='val file')
2022
parser.add_argument('-tpu_address', type=str, default=None,
2123
help='TPU ip address')
22-
parser.add_argument('-tpu_topology', type=str, default="v3-8", choices=["v2-8","v3-8"],
24+
parser.add_argument('-tpu_topology', type=str, default=None, choices=["v2-8","v3-8", None],
2325
help='train file')
2426
parser.add_argument('-in_len', type=int, default=2048,
2527
help='train file')
@@ -31,8 +33,14 @@
3133
help='train file')
3234
parser.add_argument('-compression', type=str, default=None, choices=[None, "ZLIB", "GZIP"],
3335
help='compression the dataset is compressed with')
36+
parser.add_argument('-batch_size', type=int, default=None,
37+
help='number of batches')
38+
parser.add_argument('-max_checkpoints', type=int, default=None,
39+
help='maximum number of checkpoints')
3440
parser.add_argument('-storemode', type=str, default="gs", choices=["gs", "local"],
3541
help='storemode')
42+
parser.add_argument('-paralellism', type=int, default=None,
43+
help='model_paralellism')
3644
args = parser.parse_args()
3745

3846
from src.createtask import create_registry
@@ -85,12 +93,17 @@ def tf_verbosity_level(level):
8593
"large": (4, 128, 2),
8694
"3B": (8, 16, 1),
8795
"11B": (8, 4, 1)}[MODEL_SIZE]
96+
# Optional CLI override of the size-derived default. The target name must be
# `model_parallelism` (the variable handed to MtfModel below); the previous
# spelling `model_paralellism` bound an unused name, so the flag was ignored.
if args.paralellism:
    model_parallelism = args.paralellism
97+
if args.batch_size: train_batch_size=args.batch_size
98+
if args.max_checkpoints: keep_checkpoint_max=args.max_checkpoints
8899

89100
tf.io.gfile.makedirs(MODEL_DIR)
90101
# The models from our paper are based on the Mesh Tensorflow Transformer.
91102
model = t5.models.MtfModel(
92103
model_dir=MODEL_DIR,
93104
tpu=args.tpu_address,
105+
mesh_devices=args.gpus,
106+
# -gpus defaults to None, and len(None) raises TypeError. Only build an
# explicit GPU mesh shape when devices were supplied; otherwise pass None
# so MtfModel falls back to its own default mesh shape.
mesh_shape=f'model:1,batch:{len(args.gpus)}' if args.gpus else None,
94107
tpu_topology=args.tpu_topology,
95108
model_parallelism=model_parallelism,
96109
batch_size=train_batch_size,

0 commit comments

Comments
 (0)