Skip to content

Commit 2feccb3

Browse files
还原删除代码
1 parent 8b94b3e commit 2feccb3

1 file changed

Lines changed: 58 additions & 0 deletions

File tree

src/py3.x/tensorflow2.x/text_bert.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,64 @@ def seq_padding(X, padding=0):
9393
## bert / Embedding/ + lstm + crf
9494

9595

96+
#%%
97+
# 加载数据
98+
class TextBert():
    """Text model helper: loads a vocabulary, prepares data and builds a model.

    Two model variants can be built: a fine-tuned BERT binary classifier, or
    an Embedding + bidirectional LSTM + CRF stack.
    """

    def __init__(self):
        """Store the BERT config/checkpoint paths and read the vocabulary file.

        Each stripped line of ``Config.bert.dict_path`` becomes a key of
        ``self.token_dict``; its value is the dictionary size at insertion
        time, so a repeated token keeps one entry and takes the newer index.
        """
        self.path_config = Config.bert.path_config
        self.path_checkpoint = Config.bert.path_checkpoint

        self.token_dict = {}
        with codecs.open(Config.bert.dict_path, 'r', 'utf8') as vocab_file:
            for raw_line in vocab_file:
                entry = raw_line.strip()
                self.token_dict[entry] = len(self.token_dict)

    def prepare_data(self):
        """Read the negative/positive sheets and split them 9:1.

        Returns:
            A ``(train_data, valid_data)`` tuple. Both are lists of
            ``(value, label)`` pairs, where label 0 marks rows taken from
            ``Config.bert.path_neg`` and label 1 rows from
            ``Config.bert.path_pos``.
        """
        negative = pd.read_excel(Config.bert.path_neg, header=None)
        positive = pd.read_excel(Config.bert.path_pos, header=None)

        # Only column 0 of each sheet is used.
        samples = [(item, 0) for item in negative[0]]
        samples.extend((item, 1) for item in positive[0])

        # Split into training and validation sets at a 9:1 ratio: every tenth
        # position of the shuffled order goes to validation.
        order = list(range(len(samples)))
        np.random.shuffle(order)
        train_data, valid_data = [], []
        for position, sample_idx in enumerate(order):
            target = valid_data if position % 10 == 0 else train_data
            target.append(samples[sample_idx])
        return train_data, valid_data

    def build_model(self, m_type="bert"):
        """Build, compile and summarize a Keras model.

        Args:
            m_type: ``"bert"`` builds the BERT-based binary classifier; any
                other value builds the Embedding + BiLSTM + CRF model.

        Returns:
            The compiled model.
        """
        if m_type != "bert":
            # Otherwise use a plain Embedding layer.
            # NOTE(review): vocab, EMBED_DIM, BiRNN_UNITS and chunk_tags are
            # not defined in this class - presumably module-level names;
            # confirm they exist before using this branch.
            model = Sequential()
            model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
            model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
            crf_layer = CRF(len(chunk_tags), sparse_target=True)
            model.add(crf_layer)
            model.compile(
                'adam',
                loss=crf_layer.loss_function,
                metrics=[crf_layer.accuracy],
            )
        else:
            pretrained = load_trained_model_from_checkpoint(
                self.path_config, self.path_checkpoint, seq_len=None
            )
            # Every pre-trained layer is left trainable (full fine-tuning).
            for layer in pretrained.layers:
                layer.trainable = True

            first_input = Input(shape=(None,))
            second_input = Input(shape=(None,))
            encoded = pretrained([first_input, second_input])
            # Keep index 0 along the second axis - presumably the [CLS]
            # position; TODO confirm against the tokenizer in use.
            pooled = Lambda(lambda t: t[:, 0])(encoded)
            # Adjust the output layer to the number of classes; more layers
            # may also be added here.
            prob = Dense(1, activation='sigmoid')(pooled)

            model = Model([first_input, second_input], prob)
            model.compile(
                loss='binary_crossentropy',
                optimizer=Adam(1e-5),  # use a sufficiently small learning rate
                metrics=['accuracy'],
            )

        model.summary()
        return model
152+
153+
96154
#%%
97155
# 加载数据
98156
from keras_bert import Tokenizer

0 commit comments

Comments
 (0)