GRAPH-BASED RECURRENT RETRIEVER

a new graph-based recurrent retrieval method

查找证据文档作为回答复杂问题的推理路径。

b:偏置项

  • 使用RNN建模问题$Q$的推理路径。
  • 给定问题$q$,在时间步$t$时,模型从候选段落集$C_t$中找出$p_i$ ,与$q$拼接计算$p_i$的概率。
  • 遇到$[EOE]$时结束推理,允许它在给定每个问题的情况下捕获具有任意长度的推理路径。

本文BERT结构:

image-20210108150704041

RNN结构:

$P(p_i|h_t)$:表示在时间步$t$选择段落$p_i$的概率。

image-20210108151355616

最终得到推理路径【$p_1,p_2$】

beam search

  • 通过束搜索得到给定时间步长的有限数量的最可能推理路径,减小输入BERT的数据量,减小计算量。
  • $C_1$是用在输入问题上 TF-IDF 得分最高的段落。
  • $C_t$是在C_1基础上,拓展的连接段落,用输入到BERT。
  • 推理路径$E = [pi, . . . , p_k]$乘段落概率$P(p_i|h_1) . . . P(p_k|h{|E|})$得到beam search 的输出,即得到top B 推理路径 $E = {E_1, . . . , E_B}$作为BERT输入,再将BERT输出作为RNN输入。

BERT相关

Bidirectional Encoder Representations from Transformers

是Google以无监督的方式利用大量无标注文本训练的的语言代表模型,其架构为Transformer中的Encoder。

  • 使用BERT预训练模型bert-base-uncased不区分大小写。

BERT 里5个特殊tokens:

  1. [CLS]:在做分类任务时其最后一层的repr. 会被视为整个输入序列的repr。

repr指的都是一个可以用来代表某词汇(在某个语境下)的多维连续向量(continuous vector)。

  1. [SEP]:有两个句子的文本会被串接成一个输入序列,并在两句之间插入这个token 以做区隔。
  2. [UNK]:没出现在BERT 字典里头的字会被这个token 取代。
  3. [PAD]:zero padding 遮罩,将长度不一的输入序列补齐方便做batch 运算。
  4. [MASK]:未知遮罩,仅在预训练阶段会用到。

代码实现

加载BERT预训练模型

1
model = BertForGraphRetriever.from_pretrained(args.bert_model,cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1),graph_retriever_config=graph_retriever_config)

默认从缓存中加载,下载之后源码中替换自己本地路径即可。

  • any() 函数用于判断给定的可迭代参数 iterable 是否全部为 False,则返回 False,如果有一个为 True,则返回 True。

使用BertAdam自定义Adam优化器

1
2
3
4
5
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=t_total,
max_grad_norm=1.0)

在前10%的steps中,lr从0线性增加到 init_learning_rate,这个阶段又叫 warmup,然后,lr又从 init_learning_rate 线性衰减到0(完成所有steps)。

对问题和段落加上[CLS],[SEP]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def tokenize_question(question, tokenizer):
tokens_q = tokenizer.tokenize(question)
tokens_q = ['[CLS]'] + tokens_q + ['[SEP]']

return tokens_q


def tokenize_paragraph(p, tokens_q, max_seq_length, tokenizer):
tokens_p = tokenizer.tokenize(p)[:max_seq_length - len(tokens_q) - 1]
tokens_p = tokens_p + ['[SEP]']

padding = [0] * (max_seq_length - len(tokens_p) - len(tokens_q))

input_ids_ = tokenizer.convert_tokens_to_ids(tokens_q + tokens_p)
input_masks_ = [1] * len(input_ids_)
segment_ids_ = [0] * len(tokens_q) + [1] * len(tokens_p)

input_ids_ += padding
input_masks_ += padding
segment_ids_ += padding

assert len(input_ids_) == max_seq_length
assert len(input_masks_) == max_seq_length
assert len(segment_ids_) == max_seq_length

return input_ids_, input_masks_, segment_ids_

RNN初始化

1
self.rw = nn.Linear(2 * config.hidden_size, config.hidden_size)

通过beam search 找出top B 推理路径

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
b = 0
while b < beam:
s, p = torch.max(score.view(score.size(0) * score.size(1)), dim=0)
s = s.item()
p = p.item()
row = p // score.size(1)
col = p % score.size(1)

if j == 0:
score[:, col] = 0.0
else:
score[row, col] = 0.0

p = [[index for index in pred_[row][0]] + [col],
output[row].topk(k=2, dim=0)[1].tolist(),
s]
new_pred_.append(p)

p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()]
new_prob_.append(p)

state_tmp[b].copy_(state_[row])
b += 1

实验

下载程序:

1
2
3
!git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git
%cd /content/learning_to_retrieve_reasoning_paths
!pip install -r requirements.txt

下载数据集

1
2
3
4
5
6
7
8
%cd /content/learning_to_retrieve_reasoning_paths
!mkdir data
%cd data
!mkdir hotpot
%cd hotpot
!gdown https://drive.google.com/uc?id=1AIRo66I2Izs80nNLt4MaLu7kqhTuIQ0u
!unzip hotpotqa_new_selector_train_data_db_2017_10_12_fix.zip.zip____
!rm hotpotqa_new_selector_train_data_db_2017_10_12_fix.zip.zip____

训练模型

1
%cd /content/learning_to_retrieve_reasoning_paths/graph_retriever
1
2
3
4
5
6
7
8
9
!python3 run_graph_retriever.py \
--task hotpot_distractor \
--bert_model bert-base-uncased --do_lower_case \
--train_file_path /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpotqa_new_selector_train_data_db_2017_10_12_fix/db=wiki_hotpotqa.db_hotpotqa_new_test_tfidf_k=50.pruning_l=100_tag_me=True.prune_after_agg=False.prune_in_article=False_use_link=True_start=0_end=5000.json \
--output_dir graph_retriever/ \
--max_para_num 10 \
--neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 4 \
--learning_rate 3e-5 --num_train_epochs 3 \
--max_select_num 3
  • —max_para_num:与问题相关的段落数量。如果—max_para_num是n,问题的基础真实段落数量是k(2),那么有n-2个段落作为训练的反例。此时反例数量为8。

  • —neg_chunk:为了控制GPU内存消耗,将负例拆分为小块。

  • —max_select_num:指定模型推理步骤的最大数量,如果问题的基础真实段落数量是k,这个值应该为k+1,1表示结束符号EOE,此时k+1=3。