@@ -93,6 +93,64 @@ def seq_padding(X, padding=0):
9393## bert / Embedding/ + lstm + crt
9494
9595
96+ #%%
97+ # 加载数据
98+ class TextBert ():
99+ def __init__ (self ):
100+ self .path_config = Config .bert .path_config
101+ self .path_checkpoint = Config .bert .path_checkpoint
102+
103+ self .token_dict = {}
104+ with codecs .open (Config .bert .dict_path , 'r' , 'utf8' ) as reader :
105+ for line in reader :
106+ token = line .strip ()
107+ self .token_dict [token ] = len (self .token_dict )
108+
109+
110+ def prepare_data (self ):
111+ neg = pd .read_excel (Config .bert .path_neg , header = None )
112+ pos = pd .read_excel (Config .bert .path_pos , header = None )
113+ data = []
114+ for d in neg [0 ]:
115+ data .append ((d , 0 ))
116+ for d in pos [0 ]:
117+ data .append ((d , 1 ))
118+ # 按照9:1的比例划分训练集和验证集
119+ random_order = list (range (len (data )))
120+ np .random .shuffle (random_order )
121+ train_data = [data [j ] for i , j in enumerate (random_order ) if i % 10 != 0 ]
122+ valid_data = [data [j ] for i , j in enumerate (random_order ) if i % 10 == 0 ]
123+ return train_data , valid_data
124+
125+ def build_model (self , m_type = "bert" ):
126+ if m_type == "bert" :
127+ bert_model = load_trained_model_from_checkpoint (self .path_config , self .path_checkpoint , seq_len = None )
128+ for l in bert_model .layers :
129+ l .trainable = True
130+ x1_in = Input (shape = (None ,))
131+ x2_in = Input (shape = (None ,))
132+ x = bert_model ([x1_in , x2_in ])
133+ x = Lambda (lambda x : x [:, 0 ])(x )
134+ p = Dense (1 , activation = 'sigmoid' )(x )#根据分类种类自行调节,也可以多加一些层数
135+ model = Model ([x1_in , x2_in ], p )
136+ model .compile (
137+ loss = 'binary_crossentropy' ,
138+ optimizer = Adam (1e-5 ), # 用足够小的学习率
139+ metrics = ['accuracy' ]
140+ )
141+ else :
142+ # 否则用 Embedding
143+ model = Sequential ()
144+ model .add (Embedding (len (vocab ), EMBED_DIM , mask_zero = True )) # Random embedding
145+ model .add (Bidirectional (LSTM (BiRNN_UNITS // 2 , return_sequences = True )))
146+ crf = CRF (len (chunk_tags ), sparse_target = True )
147+ model .add (crf )
148+ model .compile ('adam' , loss = crf .loss_function , metrics = [crf .accuracy ])
149+
150+ model .summary ()
151+ return model
152+
153+
96154#%%
97155# 加载数据
98156from keras_bert import Tokenizer
0 commit comments