READING AND ANSWERING GIVEN REASONING PATHS

reader model

Multi-task reader model

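  • Reading comprehension task

    BERT is used to extract the answer span from a reasoning path.

  • Re-ranking reasoning paths

    The BERT output at the [CLS] token position is used to estimate the probability that a reasoning path contains the answer. The reasoning paths are then re-ranked by this probability.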

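$w_n ∈ R^D$: weight vector

$P(E|q)$: probability of reasoning path E

$E_{best}$: the best path

$S_{read}$: the answer span

$P^{start}_i, P^{end}_j$: the probabilities that the i-th and the j-th token of $E_{best}$ are the start and the end position, respectively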
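The selection described by these symbols can be sketched in plain Python. This is a minimal illustration, not the repository's code: it assumes $P(E|q)$ is a sigmoid over a scalar path score, that $E_{best}$ is the path with the highest probability, and that $S_{read}$ is the pair (i, j) with i ≤ j that maximises $P^{start}_i \cdot P^{end}_j$. The function names, the `max_answer_len` cap, and all numbers are hypothetical.

```python
import math


def sigmoid(x):
    """Logistic function that turns a path score into a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def select_best_path(path_scores):
    """Return (index, probability) of the path with the highest P(E|q)."""
    probs = [sigmoid(s) for s in path_scores]
    best = max(range(len(probs)), key=lambda k: probs[k])
    return best, probs[best]


def select_best_span(start_probs, end_probs, max_answer_len=30):
    """Return (i, j) maximising start_probs[i] * end_probs[j] with i <= j.

    max_answer_len is an assumed cap on the span length, not taken from the text.
    """
    best_span, best_score = (0, 0), -1.0
    for i, p_start in enumerate(start_probs):
        last = min(len(end_probs), i + max_answer_len)
        for j in range(i, last):
            score = p_start * end_probs[j]
            if score > best_score:
                best_span, best_score = (i, j), score
    return best_span


# Hypothetical scores for three candidate reasoning paths.
best_path, best_prob = select_best_path([-1.2, 2.3, 0.4])

# Hypothetical per-token probabilities over the best path.
start_probs = [0.1, 0.6, 0.2, 0.1]
end_probs = [0.05, 0.15, 0.7, 0.1]
best_span = select_best_span(start_probs, end_probs)

print(best_path, best_span)
```

The constraint i ≤ j is enforced by starting the inner loop at i, so an end position can never precede its start position.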

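  • Adding negative examples:

    To train the reader model to distinguish relevant from irrelevant reasoning paths, we augment the original training data with additional negative examples that simulate incomplete evidence.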

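  • Loss function

The objective is the sum of the cross-entropy losses for the span prediction and re-ranking tasks. The loss for a question q and its candidate evidence E:

$L_{read} = L_{span} + L_{no\_answer} = (−\log P^{start}_{y^{start}} − \log P^{end}_{y^{end}}) − \log P^r$

$y^{start}, y^{end}$ are the ground-truth start and end positions.

$L_{no\_answer}$: the loss of the re-ranking model, which discriminates distorted reasoning paths that contain no answer.

$P^r$: if E is the ground-truth evidence, $P^r = P(E|q)$; otherwise $P^r = 1 − P(E|q)$.

The span loss is masked for negative examples to avoid unintended effects on span prediction.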
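As a rough numerical illustration of this objective, the sketch below computes the loss for one (q, E) pair in plain Python. It assumes the span loss is $−\log P^{start}_{y^{start}} − \log P^{end}_{y^{end}}$, the re-ranking loss is $−\log P^r$, and the span term is set to zero for negative examples. The function name `reader_loss` and all probabilities are hypothetical.

```python
import math


def reader_loss(start_probs, end_probs, y_start, y_end, path_prob, is_ground_truth):
    """Span loss plus re-ranking (no-answer) loss for one (q, E) pair."""
    # P^r = P(E|q) for the ground-truth evidence, 1 - P(E|q) otherwise.
    p_r = path_prob if is_ground_truth else 1.0 - path_prob
    no_answer_loss = -math.log(p_r)

    # The span loss is masked out (set to zero) for negative examples.
    if is_ground_truth:
        span_loss = -math.log(start_probs[y_start]) - math.log(end_probs[y_end])
    else:
        span_loss = 0.0
    return span_loss + no_answer_loss


# Hypothetical probabilities: a positive path and a negative (distorted) path.
positive = reader_loss([0.1, 0.8, 0.1], [0.1, 0.1, 0.8], 1, 2, 0.9, True)
negative = reader_loss([0.1, 0.8, 0.1], [0.1, 0.1, 0.8], 0, 0, 0.2, False)
print(positive, negative)
```

For the negative example only the re-ranking term contributes, so whatever start and end labels the example carries have no effect on the result.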

Code implementation

Optimizer:

# BertAdam: Adam with weight-decay fix and learning-rate warmup
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=args.learning_rate,
                     warmup=args.warmup_proportion,
                     t_total=num_train_optimization_steps)

Loading the training data with a DataLoader

# DistributedSampler gives each process its own shard of the data
train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(
    train_data, sampler=train_sampler, batch_size=args.train_batch_size)

Displaying progress

for _ in trange(int(args.num_train_epochs), desc="Epoch"):
    ...  # training steps for one epoch

Epoch: 100%|██████████| 3/3 [00:03<00:00, 1.00s/it]

Creating the model

model = BertForQuestionAnsweringConfidence.from_pretrained(
    args.bert_model,
    cache_dir=os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
    num_labels=4,
    no_masking=args.no_masking,
    lambda_scale=args.lambda_scale)

Loss function

loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
             start_positions=start_positions, end_positions=end_positions, switch_list=switches)

$L_{span} + L_{no\_answer}$

# Keep the span loss only where switch_list == 0; other examples are masked out
span_mask = (switch_list == 0).type(torch.FloatTensor).cuda()
start_losses = loss_fct(
    start_logits, start_positions) * span_mask
end_losses = loss_fct(end_logits, end_positions) * span_mask

collections.OrderedDict()  # a dict that keeps its keys in insertion order
collections.namedtuple()   # used much like a struct
json.dumps()               # encodes a Python object as a JSON string

Only use questions whose answer can be found

# Skip the example if the annotated span does not contain the answer text
actual_text = " ".join(
    doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
    whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
    logger.warning("Could not find answer: '%s' vs. '%s'",
                   actual_text, cleaned_answer_text)
    continue
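The check above can be tried on a toy document. This is a self-contained sketch: `whitespace_tokenize` is re-implemented locally as a stand-in for the helper used in the snippet, and `answer_is_recoverable` and the example sentence are hypothetical.

```python
def whitespace_tokenize(text):
    """Strip the text and split it on whitespace (local stand-in helper)."""
    text = text.strip()
    return text.split() if text else []


def answer_is_recoverable(doc_tokens, start_position, end_position, orig_answer_text):
    """Return True if the annotated token span contains the cleaned answer text."""
    actual_text = " ".join(doc_tokens[start_position:end_position + 1])
    cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
    return actual_text.find(cleaned_answer_text) != -1


doc_tokens = "Jim Cummings voiced Dr. Robotnik".split()

# Extra whitespace in the answer is normalised before the comparison.
ok = answer_is_recoverable(doc_tokens, 3, 4, "Dr.  Robotnik ")
# A span pointing at the wrong tokens does not contain the answer.
bad = answer_is_recoverable(doc_tokens, 0, 1, "Dr. Robotnik")
print(ok, bad)
```

Examples for which the check fails are skipped rather than repaired, which keeps misaligned annotations out of the training set.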

Filtering out overly long answers

if len(orig_answer_text.split()) > max_answer_len:
    logger.info(
        "Omitting a long answer: '%s'", orig_answer_text)
    continue

num_train_optimization_steps // int(save_chunk)  # floor division yields an integer

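The answer span is set to the first occurring string that matches the answer a.

# Collect the indexes of the top n_best_size entries
for i in range(len(index_and_score)):
    if i >= n_best_size:
        break
    best_indexes.append(index_and_score[i][0])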
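The loop above only picks the right candidates if `index_and_score` is already sorted by score in descending order. A minimal sketch of the whole step under that assumption follows; the function name `get_best_indexes` and the logits are hypothetical.

```python
def get_best_indexes(logits, n_best_size):
    """Return the indexes of the n_best_size largest logits, best first."""
    # Pair each logit with its index, then sort by score in descending order.
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
    return [index for index, _ in index_and_score[:n_best_size]]


best_indexes = get_best_indexes([0.1, 2.5, -0.3, 1.7], n_best_size=2)
print(best_indexes)
```

Slicing with `[:n_best_size]` replaces the explicit `break` and also handles the case where fewer than `n_best_size` logits are available.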

BERT input example

tokens: [CLS] this singer of a rather blu ##ster ##y day also voiced what hedge ##hog ? [SEP] " a rather blu ##ster ##y day " is a w ##him ##sic ##al song from the walt disney musical film feature ##tte , " winnie the po ##oh and the blu ##ster ##y day " . it was written by robert & richard sherman and sung by jim cummings as " po ##oh " . james jonah cummings ( born november 3 , 1952 ) is an american voice actor and singer , who has appeared in almost 400 roles . he is known for vo ##icing the title character from " dark ##wing duck " , dr . robot ##nik from " sonic the hedge ##hog " , and pete . his other characters include winnie the po ##oh , ti ##gger , and the tasmanian devil . he has performed in numerous disney and dream ##works animation ##s including " ala ##ddin " , " the lion king " , " bal ##to " , " ant ##z " , " the road to el dora ##do " , " sh ##rek " , and " the princess and the frog " . he has also provided voice - over work for video games , such as " ice ##wind dale " , " fallout " , " " , " bald ##ur ' s gate " , " mass effect 2 " , " " , " " , " " , and " sp ##lat ##ter ##house " . [SEP]

Inputs shorter than the maximum length are padded with zeros

# Pad every feature list up to max_seq_length
while len(input_ids) < max_seq_length:
    input_ids.append(pad_token)
    input_mask.append(0 if mask_padding_with_zero else 1)
    segment_ids.append(pad_token_segment_id)
    p_mask.append(1)

Example

input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
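The same padding can be reproduced on a short sequence to see how `input_mask` ends up as ones followed by zeros. This is a sketch only: the function `pad_features`, its default values, and the token ids are hypothetical, and it builds new lists instead of appending in place.

```python
def pad_features(input_ids, segment_ids, max_seq_length,
                 pad_token=0, pad_token_segment_id=0, mask_padding_with_zero=True):
    """Right-pad ids, mask and segment ids to max_seq_length and return new lists."""
    pad_len = max_seq_length - len(input_ids)
    if pad_len < 0:
        raise ValueError("input is longer than max_seq_length")

    # Real tokens get 1 in the mask (or 0 if the convention is inverted).
    input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

    padded_ids = input_ids + [pad_token] * pad_len
    padded_mask = input_mask + [0 if mask_padding_with_zero else 1] * pad_len
    padded_segments = segment_ids + [pad_token_segment_id] * pad_len
    return padded_ids, padded_mask, padded_segments


# Hypothetical token ids for a four-token input padded to length 8.
ids, mask, segs = pad_features([101, 2023, 3220, 102], [0, 0, 0, 0], max_seq_length=8)
print(mask)
```

With `max_seq_length=384`, as in the training commands, the mask has the same shape as the example above: one 1 per real token, then zeros up to position 384.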

Experiments

Downloading the code

!git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git
%cd /content/learning_to_retrieve_reasoning_paths
!pip install -r requirements.txt

Downloading the training data

%cd /content/learning_to_retrieve_reasoning_paths
!mkdir data
%cd data
!mkdir hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1_a8KliAHKIwrYRrHgHOlzM0Jon3AqZLs
!mv hotpot_reader_train_data.json.json____ hotpot_reader_train_data.json
!gdown https://drive.google.com/uc?id=1R4exuPDaV2yD18xUBsnNyQpXwn0ty5pc
!mv nq_reader_train_data_public.json.json____ nq_reader_train_data_public.json
!gdown https://drive.google.com/uc?id=1FB5gB9aM8rmbpIwYf-1o6lmxMYQsg_rP
!mv squad_reader_train_data.json.json____ squad_reader_train_data.json
!gdown https://drive.google.com/uc?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB
!ls

Training the model

All datasets use the SQuAD v2.0 format.
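For orientation, the sketch below builds a tiny record in the standard SQuAD v2.0 layout and walks through it. It assumes the files used here follow that layout without additional required fields. The title, context, questions, ids, and the `summarize` function are hypothetical.

```python
import json

# A minimal, hypothetical dataset in the SQuAD v2.0 layout.
example = {
    "version": "v2.0",
    "data": [{
        "title": "A Rather Blustery Day",
        "paragraphs": [{
            "context": "It was sung by Jim Cummings as Pooh.",
            "qas": [
                {"id": "q1", "question": "Who sang the song?",
                 "is_impossible": False,
                 "answers": [{"text": "Jim Cummings", "answer_start": 15}]},
                {"id": "q2", "question": "Who directed the film?",
                 "is_impossible": True, "answers": []},
            ],
        }],
    }],
}


def summarize(dataset):
    """Count answerable and unanswerable questions and verify answer offsets."""
    answerable, unanswerable = 0, 0
    for article in dataset["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                if qa.get("is_impossible", False):
                    unanswerable += 1
                    continue
                answerable += 1
                for answer in qa["answers"]:
                    start = answer["answer_start"]
                    end = start + len(answer["text"])
                    # answer_start is a character offset into the context.
                    if context[start:end] != answer["text"]:
                        raise ValueError("bad offset for question " + qa["id"])
    return answerable, unanswerable


# Round-trip through JSON to mimic loading the record from a file.
counts = summarize(json.loads(json.dumps(example)))
print(counts)
```

Questions marked `is_impossible` play the role of negative examples, which is what the `--version_2_with_negative` flag in the commands refers to.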

Training with hotpot_dev_squad_v2.0_format.json

hotpot_reader_train_data.json is too large to train on here, so hotpot_dev_squad_v2.0_format.json is used only to get a feel for the training process.

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
    --bert_model bert-base-uncased \
    --output_dir output_hotpot_bert_base \
    --train_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
    --predict_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
    --max_seq_length 384 \
    --do_train \
    --do_predict \
    --do_lower_case \
    --version_2_with_negative \
    --train_batch_size 16

--train_batch_size: adjust according to the available GPU memory

--version_2_with_negative: train with negative examples

Training only

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
    --bert_model bert-base-uncased \
    --output_dir output_hotpot_bert_base \
    --train_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
    --max_seq_length 384 \
    --do_train \
    --do_lower_case \
    --version_2_with_negative \
    --train_batch_size 16

Prediction only

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!python run_reader_confidence.py \
    --bert_model bert-base-uncased \
    --output_dir output_hotpot_bert_base \
    --predict_file /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
    --max_seq_length 384 \
    --do_predict \
    --do_lower_case \
    --version_2_with_negative \
    --train_batch_size 16

Evaluation

Downloading the evaluation dataset

%cd /content/learning_to_retrieve_reasoning_paths
!mkdir -p data
%cd data
!mkdir -p hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1MysthH2TRYoJcK_eLOueoLeYR42T-JhB
!ls

Evaluating the model

predictions.json is generated automatically after the model runs prediction.

%cd /content/learning_to_retrieve_reasoning_paths/reader/
!wget https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
!mv index.html evaluate-v2.0.py
!python evaluate-v2.0.py \
    /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json \
    /content/learning_to_retrieve_reasoning_paths/reader/output_hotpot_bert_base/predictions.json

Evaluation data: /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpot_dev_squad_v2.0_format.json