BART hub fixes + improvements (#1342)

Summary:
- Make BART hub interface extend from GeneratorHubInterface (fixes #1748)
- Add mask filling interface for BART

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1342

Reviewed By: ngoyal2707

Differential Revision: D24264195

Pulled By: myleott

fbshipit-source-id: 0885f90a54fabe1672b1bfe137dfbccbc5d25d0e
Authored by Myle Ott on 2020-10-22 12:44:02 -07:00, committed by Facebook GitHub Bot
parent: f0fcb55d5b
commit: b8a938e96e

4 changed files with 119 additions and 67 deletions


@@ -131,6 +131,23 @@ bart.cuda()
bart.predict('new_task', tokens)
```
#### Filling masks:
BART can be used to fill multiple `<mask>` tokens in the input.
```python
bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()
bart.fill_mask('The cat <mask> on the <mask>.', topk=3, beam=10)
# [('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]
```
Note that by default we enforce the output length to match the input length.
This can be disabled by setting ``match_source_len=False``:
```python
bart.fill_mask('The cat <mask> on the <mask>.', topk=3, beam=10, match_source_len=False)
# [('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]
```
#### Evaluating the `bart.large.mnli` model:
Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
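The snippet itself is not shown in this hunk; the following is only a rough sketch of such an evaluation loop (the GLUE data path, the TSV column indices, and the label order are assumptions, not taken from this diff):
```python
# Hedged sketch of an MNLI dev_matched accuracy loop; the data path, column
# indices and label order below are assumptions, not from this diff.
import torch

label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.cuda()
bart.eval()

ncorrect, nsamples = 0, 0
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()  # skip the header row
    for line in fin:
        cols = line.strip().split('\t')
        sent1, sent2, target = cols[8], cols[9], cols[-1]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('mnli', tokens).argmax().item()
        ncorrect += int(label_map[prediction] == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect) / float(nsamples))
```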


@@ -5,7 +5,7 @@
import copy
import logging
-from typing import List
+from typing import Dict, List
import numpy as np
import torch
@@ -13,39 +13,22 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
+from fairseq.hub_utils import GeneratorHubInterface
from omegaconf import open_dict
logger = logging.getLogger(__name__)
-class BARTHubInterface(nn.Module):
+class BARTHubInterface(GeneratorHubInterface):
    """A simple PyTorch Hub interface to BART.

    Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
    """

    def __init__(self, cfg, task, model):
-        super().__init__()
-        self.cfg = cfg
-        self.task = task
-        self.model = model
-        self.bpe = encoders.build_bpe(cfg.bpe)
-        self.max_positions = min(
-            utils.resolve_max_positions(
-                self.task.max_positions(),
-                self.model.max_positions(),
-            )
-        )
-        # this is useful for determining the device
-        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
-
-    @property
-    def device(self):
-        return self._float_tensor.device
+        super().__init__(cfg, task, [model])
+        self.model = self.models[0]
    def encode(
        self, sentence: str, *addl_sentences, no_separator=True
@@ -70,8 +53,8 @@ class BARTHubInterface(nn.Module):
        [0, 8331, 2]
        """
        tokens = self.bpe.encode(sentence)
-        if len(tokens.split(" ")) > self.max_positions - 2:
-            tokens = " ".join(tokens.split(" ")[: self.max_positions - 2])
+        if len(tokens.split(" ")) > min(self.max_positions) - 2:
+            tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
@@ -104,50 +87,28 @@ class BARTHubInterface(nn.Module):
        sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
        return sample

-    def sample(
-        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
-    ) -> str:
-        input = [self.encode(sentence) for sentence in sentences]
-        hypos = self.generate(input, beam, verbose, **kwargs)
-        return [self.decode(x["tokens"]) for x in hypos]

    def generate(
        self,
-        tokens: List[torch.LongTensor],
-        beam: int = 5,
-        verbose: bool = False,
+        tokenized_sentences: List[torch.LongTensor],
+        *args,
+        inference_step_args=None,
        **kwargs
-    ) -> torch.LongTensor:
-        sample = self._build_sample(tokens)
-        # build generator using current args as well as any kwargs
-        gen_args = copy.copy(self.cfg)
-        with open_dict(gen_args):
-            gen_args.beam = beam
-            for k, v in kwargs.items():
-                setattr(gen_args, k, v)
-        generator = self.task.build_generator([self.model], gen_args)
-        translations = self.task.inference_step(
-            generator,
-            [self.model],
-            sample,
-            prefix_tokens=sample["net_input"]["src_tokens"]
-            .new_zeros((len(tokens), 1))
-            .fill_(self.task.source_dictionary.bos()),
+    ) -> List[List[Dict[str, torch.Tensor]]]:
+        inference_step_args = inference_step_args or {}
+        if "prefix_tokens" in inference_step_args:
+            raise NotImplementedError("prefix generation not implemented for BART")
+        else:
+            bsz = len(tokenized_sentences)
+            inference_step_args["prefix_tokens"] = tokenized_sentences[0].new_full(
+                (bsz, 1), fill_value=self.task.source_dictionary.bos()
+            ).to(device=self.device)
+        return super().generate(
+            tokenized_sentences,
+            *args,
+            inference_step_args=inference_step_args,
+            **kwargs
        )
-        if verbose:
-            src_str_with_unk = self.string(tokens)
-            logger.info("S\t{}".format(src_str_with_unk))
-        def getarg(name, default):
-            return getattr(gen_args, name, getattr(self.args, name, default))
-        # Process top predictions
-        hypos = [x[0] for x in translations]
-        hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))]
-        return hypos

    def extract_features(
        self, tokens: torch.LongTensor, return_all_hiddens: bool = False
    ) -> torch.Tensor:
@@ -201,3 +162,40 @@ class BARTHubInterface(nn.Module):
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)

    def fill_mask(
        self,
        masked_input: str,
        topk: int = 5,
        match_source_len: bool = True,
        **generate_kwargs
    ):
        masked_token = '<mask>'
        assert masked_token in masked_input, \
            "please add one {} token for the input".format(masked_token)
        text_spans = masked_input.split(masked_token)
        text_spans_bpe = (' {0} '.format(masked_token)).join(
            [self.bpe.encode(text_span.rstrip()) for text_span in text_spans]
        ).strip()
        tokens = self.task.source_dictionary.encode_line(
            '<s> ' + text_spans_bpe + ' </s>',
            append_eos=False,
            add_if_not_exist=False,
        ).long()
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)

        # ensure beam size is at least as big as topk
        generate_kwargs['beam'] = max(
            topk,
            generate_kwargs.get('beam', -1),
        )
        generate_kwargs['match_source_len'] = match_source_len
        hypos = self.generate(tokens, **generate_kwargs)[0]

        return [
            (self.decode(hypo['tokens']), hypo['score'])
            for hypo in hypos[:topk]
        ]
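For orientation, a minimal usage sketch of the reworked interface (not part of this diff): `generate` now takes a list of pre-tokenized sentences plus an optional `inference_step_args` dict and returns one list of hypothesis dicts per input, which is exactly what `fill_mask` slices into `(text, score)` pairs.
```python
# Hedged usage sketch, not part of the diff; assumes the torch.hub entry
# points shown in the README excerpt above.
import torch

bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()

# generate() returns one list of hypothesis dicts per input sentence.
tokens = bart.encode('The cat <mask> on the <mask>.')
hypos = bart.generate([tokens], beam=10, match_source_len=True)[0]
print(bart.decode(hypos[0]['tokens']), hypos[0]['score'])

# fill_mask() wraps the same call and returns (text, score) pairs directly.
print(bart.fill_mask('The cat <mask> on the <mask>.', topk=3, beam=10))
```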


@@ -293,7 +293,6 @@ class SequenceGenerator(nn.Module):
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
-            # print(f'step: {step}')
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
@@ -635,12 +634,11 @@ class SequenceGenerator(nn.Module):
            else:
                cum_unfin.append(prev)

-        # set() is not supported in script export
        # The keys here are of the form "{sent}_{unfin_idx}", where
        # "unfin_idx" is the index in the current (possibly reduced)
        # list of sentences, and "sent" is the index in the original,
        # unreduced batch
+        # set() is not supported in script export
        sents_seen: Dict[str, Optional[Tensor]] = {}

        # For every finished beam item
@@ -651,7 +649,6 @@ class SequenceGenerator(nn.Module):
            unfin_idx = idx // beam_size
            # sentence index in the original (unreduced) batch
            sent = unfin_idx + cum_unfin[unfin_idx]
-            # print(f"{step} FINISHED {idx} {score} {sent}={unfin_idx} {cum_unfin}")
            # Cannot create dict for key type '(int, int)' in torchscript.
            # The workaround is to cast int to string
            seen = str(sent.item()) + "_" + str(unfin_idx.item())


@@ -11,6 +11,10 @@ from fairseq.data import (
    AppendTokenDataset,
    DenoisingDataset,
    Dictionary,
+    IdDataset,
+    NestedDictionaryDataset,
+    NumelDataset,
+    PadDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TokenBlockDataset,
@@ -18,6 +22,7 @@ from fairseq.data import (
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import LegacyFairseqTask, register_task
+import numpy as np

logger = logging.getLogger(__name__)

@@ -195,6 +200,41 @@ class DenoisingTask(LegacyFairseqTask):
            )
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        """
        Generate batches for inference. We assume that the input begins with a
        bos symbol (`<s>`) and ends with an eos symbol (`</s>`).
        """
        pad = self.source_dictionary.pad()
        eos = self.source_dictionary.eos()
        src_dataset = TokenBlockDataset(
            src_tokens,
            src_lengths,
            block_size=self.args.tokens_per_sample - 2,  # for <s> and </s>
            pad=pad,
            eos=eos,
            break_mode=self.args.sample_break_mode,
            document_sep_len=0,
        )
        prev_output_tokens = PrependTokenDataset(
            StripTokenDataset(src_dataset, eos), eos
        )
        src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False)
        return NestedDictionaryDataset(
            {
                "id": IdDataset(),
                "net_input": {
                    "src_tokens": src_dataset,
                    "src_lengths": NumelDataset(src_dataset, reduce=False),
                    "prev_output_tokens": PadDataset(
                        prev_output_tokens, pad_idx=pad, left_pad=False
                    ),
                },
                "target": src_dataset,
            },
            sizes=[np.array(src_lengths)],
        )

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)
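To make the data flow concrete, here is a hedged sketch of how a hub call exercises this new path (the exact plumbing is assumed here, not shown in the diff): `GeneratorHubInterface.generate` builds its inference batches from `task.build_dataset_for_inference`, so BART hub calls now route through the dataset constructed above.
```python
# Hedged sketch of the inference data flow added by this diff; the call
# below mirrors how the hub interface is assumed to use the new method.
import torch

bart = torch.hub.load('pytorch/fairseq', 'bart.base')
bart.eval()

tokens = bart.encode('The cat <mask> on the <mask>.')
# One entry per sentence: the token tensor and its length.
dataset = bart.task.build_dataset_for_inference([tokens], [tokens.numel()])
sample = dataset.collater([dataset[0]])
# Expect src_tokens, src_lengths and prev_output_tokens under net_input.
print(sample['net_input'].keys())
```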