# createtask.py
# Forked from JEF1056/Jade_T5.
import t5
import json
import functools
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
# Load the run configuration once, at import time.
# Despite its name, `nq_tsv_path` holds the whole parsed config mapping: it is
# indexed by split name to get that split's TSV path, and by "taskname" to get
# the name under which the task is registered.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's
# locale default.
with open("config.json", "r", encoding="utf-8") as f:
    nq_tsv_path = json.load(f)
def nq_dataset_fn(split, shuffle_files=False):
    """Build the dataset of question/answer examples for one split.

    Args:
        split: key into the loaded config (``nq_tsv_path``) whose value is the
            path of the tab-separated text file for that split.
        shuffle_files: accepted for interface compatibility and discarded,
            since each split is backed by a single file.

    Returns:
        A ``tf.data.Dataset`` whose elements are dicts with string entries
        under the keys "question" and "answer".
    """
    # We only have one file for each split, so there is nothing to shuffle.
    del shuffle_files

    # Every line of the text file is one example.
    lines = tf.data.TextLineDataset(nq_tsv_path[split])

    # Parse each "<question>\t<answer>" line into two string fields; quote
    # characters are treated as ordinary text.
    parse_tsv_line = functools.partial(
        tf.io.decode_csv,
        record_defaults=["", ""],
        field_delim="\t",
        use_quote_delim=False,
    )
    pairs = lines.map(
        parse_tsv_line,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

    # Label the two parsed fields so downstream code can address them by name.
    return pairs.map(lambda *fields: dict(zip(["question", "answer"], fields)))
# Import-time sanity check: print the first five parsed validation examples.
print("A few raw validation examples...")
_preview_ds = nq_dataset_fn("validation").take(5)
for ex in tfds.as_numpy(_preview_ds):
    print(ex)
def preprocess(ds):
    """Turn question/answer examples into text-to-text "inputs"/"targets".

    Args:
        ds: dataset whose elements are dicts with "question" and "answer"
            entries.

    Returns:
        The dataset mapped so each element is a dict with "inputs" (the
        question prefixed with "Input: ") and "targets" (the answer).
    """

    def normalize_text(text):
        # Currently a pass-through; kept as the single hook for text cleanup.
        return text

    def to_inputs_and_targets(ex):
        """Map {"question": ..., "answer": ...}->{"inputs": ..., "targets": ...}."""
        question = normalize_text(ex["question"])
        answer = normalize_text(ex["answer"])
        prompt = tf.strings.join(["Input: ", question])
        return {"inputs": prompt, "targets": answer}

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
# Register the task under the name given by the "taskname" config entry.
t5.data.TaskRegistry.add(
    nq_tsv_path["taskname"],
    # Task class to instantiate.
    t5.data.Task,
    # Callable that returns the tf.data.Dataset for a requested split.
    dataset_fn=nq_dataset_fn,
    splits=["train", "validation"],
    # Text preprocessing steps applied to the dataset from dataset_fn.
    text_preprocessor=[preprocess],
    # Targets are lowercased before metrics are computed.
    postprocess_fn=t5.data.postprocessors.lower_text,
    # Accuracy is the evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
)