forked from JEF1056/Jade_T5
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport.py
More file actions
54 lines (47 loc) · 1.91 KB
/
Copy pathexport.py
File metadata and controls
54 lines (47 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import t5
import os
import argparse
import src.createtask
import warnings
import logging as py_logging
# Silence deprecation noise from the TF/t5 stack and show INFO-level logs.
warnings.filterwarnings("ignore", category=DeprecationWarning)
py_logging.root.setLevel('INFO')

# Command-line interface. Single-dash long option names are kept as-is for
# compatibility with existing invocations.
parser = argparse.ArgumentParser(description='Export checkpoints for serving')
parser.add_argument('-dir', type=str, required=True,
                    help='Directory of model checkpoints (can/should be a gs:// link)')
parser.add_argument('-out', type=str, default=None,
                    help='Directory to save the exported model')
parser.add_argument('-name', type=str, default=None,
                    help='Name to give the exported model directory inside the output directory')
# choices must match the keys of the size -> (parallelism, batch size) table
# used when building the model, so a typo fails here with a clear message.
parser.add_argument('-size', type=str, default="small",
                    choices=["small", "base", "large", "3B", "11B"],
                    help='Model size; selects model parallelism and batch size')
parser.add_argument('-temperature', type=float, default=0.9,
                    help='Decoding temperature passed to the exported model')
parser.add_argument('-beams', type=int, default=1,
                    help='Beam size passed to the exported model (1 = no beam search)')
args = parser.parse_args()
# (model_parallelism, batch_size) for each supported -size value, sized to fit
# on a v2-8 TPU where possible. The batch size is only used to construct the
# model; the export step below sets model.batch_size = 1 before exporting.
SIZE_CONFIGS = {
    "small": (1, 256),
    "base": (2, 128),
    "large": (8, 64),
    "3B": (8, 16),
    "11B": (8, 16),
}
if args.size not in SIZE_CONFIGS:
    # Report a clean CLI error instead of an opaque KeyError.
    parser.error("unknown -size %r; expected one of: %s"
                 % (args.size, ", ".join(SIZE_CONFIGS)))
model_parallelism, train_batch_size = SIZE_CONFIGS[args.size]

# Build the model from the checkpoint directory (no TPU needed for export).
model = t5.models.MtfModel(
    tpu=False,
    model_dir=args.dir,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
)
# Export the most recent checkpoint as a SavedModel for serving.
print("~~Exporting~~")
# Fall back to <dir>/export when -out is not given.
export_dir = os.path.join(args.dir, "export") if args.out is None else args.out
model.batch_size = 1  # make one prediction per call
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=args.beams,  # 1 means no beam search
    temperature=args.temperature,  # decoding temperature for the exported model
)
# NOTE(review): depending on the TF version the exporter may return the path as
# bytes (assumption - confirm); normalize to str so os.path/os.rename accept it.
saved_model_path = os.fsdecode(saved_model_path)
if args.name:
    # NOTE(review): os.rename only works on local filesystems; a gs:// export
    # directory would need tf.io.gfile.rename instead.
    final_path = os.path.join(export_dir, args.name)
    os.rename(saved_model_path, final_path)
else:
    # Without -name, keep the directory name chosen by the exporter.
    final_path = saved_model_path
print("Model saved to:", final_path)