Commit f80767af by zhongshen zeng

submit prefix tuning code

parents
__pycache__/
THUCNews/
datasets/
sim_text_data/
cmrc_runs/
cmrc2018
checkpoint_model/
.ipynb_checkpoints/
.vscode/
# ===========
# base images
# ===========
FROM nvcr.io/nvidia/pytorch:21.06-py3
# ============
# pip packages
# ============
# RUN pip install --upgrade pip && \
# pip install --upgrade setuptools
RUN pip install sentencepiece
RUN pip install tqdm
RUN pip install deepspeed==0.5.4
RUN pip install pyemd
RUN pip install jieba
RUN pip install transformers
\ No newline at end of file
# Chinese-Transformer-XL
Under construction
本项目提供了智源研究院"文汇"
预训练模型Chinese-Transformer-XL的预训练和文本生成代码。[[应用主页]](https://gpt-3.aminer.cn/) [[模型下载]](http://dorc-model-team.ks3-cn-beijing.ksyun.com/ren-zhi/my-model/mp_rank_00_model_states.pt)
## 数据
本模型使用了智源研究院发布的中文预训练语料[WuDaoCorpus](https://data.baai.ac.cn/data-set-details/0c8dc71dd06ae75a10ca422fb49b0751)
。具体地,我们使用了WuDaoCorpus中来自百度百科+搜狗百科(133G)、知乎(131G)、百度知道(38G)的语料,一共303GB数据。
## 模型
本模型使用了[GPT-3](https://arxiv.org/abs/2005.14165)
的训练目标,同时使用能够更好地处理长序列建模的[Transformer-XL](https://arxiv.org/abs/1901.02860) 替代了GPT中的Transformer。模型的结构与GPT-3
2.7B(32层,隐表示维度2560,每层32个注意力头)基本相同,因为Transformer-XL的结构改动,模型参数增加到了29亿。
## 结果
为了验证模型的生成能力,我们在中文的开放域长文问答上进行了评测。我们从[知乎](https://www.zhihu.com)
上随机选择了100个不同领域的、不在训练语料中的问题。对每个问题,由人类测试员对一个高赞同数回答、3个模型生成的回答和3个[CPM](https://github.com/TsinghuaAI/CPM-Generate)
生成的回答在流畅度、信息量、相关度、总体四个维度进行打分。测评结果如下:
|模型|流畅度(1-5)|信息量(1-5)|相关度(1-5)|总体(1-10)|
|---|---|---|---|---|
|CPM|2.66|2.47|2.36|4.32|
|文汇|3.44|3.25|3.21|5.97|
|人类答案|3.80|3.61|3.67|6.85|
可以看到相比起CPM,"文汇"更接近人类所写的高赞答案。
## 安装
根据`requirements.txt`安装pytorch等基础依赖
```shell
pip install -r requirements.txt
```
如果要finetune模型参数,还需要安装[DeepSpeed](https://github.com/microsoft/DeepSpeed)
```shell
DS_BUILD_OPS=1 pip install deepspeed
```
## 推理
首先下载模型的[checkpoint](http://dorc-model-team.ks3-cn-beijing.ksyun.com/ren-zhi/my-model/mp_rank_00_model_states.pt) ,目录结构如下
```
.
└─ txl-2.9B
└─ mp_rank_00_model_states.pt
```
然后运行交互式生成脚本
```shell
bash scripts/generate_text.sh ./txl-2.9B
```
## Finetune
模型的finetune基于使用DeepSpeed。首先在`scripts/ds_finetune_gpt_2.9B.sh`中修改`NUM_WORKERS``NUM_GPUS_PER_WORKER`
为使用的节点数目和每个节点的GPU数量。如果使用多机训练的话,还要修改`HOST_FILE_PATH`
为hostfile的路径(DeepSpeed使用[OpenMPI风格的hostfile](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node)
)。
然后运行finetune脚本
```shell
bash scripts/ds_finetune_gpt_2.9B.sh ./txl-2.9B ./data.json
```
其中`./txl-2.9B`为checkpoint目录。`./data.json`为finetune数据,格式为[jsonl文件](https://jsonlines.org/)
,每条数据的格式为`{"prompt": .., "text": ...}`。其中prompt为生成的context,text为生成的内容。
如果你在finetune的遇到了OOM错误(一般是因为GPU数量或者显存不足导致的),可以尝试在[scripts/ds_config_2.9B_finetune.json](scripts/ds_config_2.9B_finetune.json)`zero_optimization`部分添加`"cpu_offload": true`,来开启[ZeRO-Offload](https://www.deepspeed.ai/tutorials/zero-offload/) 以减少显存消耗。
## 模型并行
如果你的显存大小比较有限,可以尝试使用模型并行来减少显存消耗。我们提供的模型checkpoint是在单卡上运行的。首先使用[change_mp.py](change_mp.py)来对hceckpoint进行切分
```shell
python change_mp.py ./txl-2.9B 2
```
其中2表示2路模型并行。在推理和finetune的时候,将脚本中的MP_SIZE改为2,然后使用./txl-2.9B_MP2作为运行脚本时的checkpoint路径。
## 引用
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""argparser configuration"""
import argparse
import os
import torch
import json
def add_model_config_args(parser):
"""Model arguments"""
group = parser.add_argument_group('model', 'model configuration')
group.add_argument('--transformer-xl', action='store_true', help='use transformer-xl for training')
group.add_argument('--pretrained-bert', action='store_true',
help='use a pretrained bert-large-uncased model instead'
'of initializing from scratch. See '
'--tokenizer-model-type to specify which pretrained '
'BERT model to use')
group.add_argument('--attention-dropout', type=float, default=0.1,
help='dropout probability for attention weights')
group.add_argument('--num-attention-heads', type=int, default=16,
help='num of transformer attention heads')
group.add_argument('--hidden-size', type=int, default=1024,
help='tansformer hidden size')
group.add_argument('--intermediate-size', type=int, default=None,
help='transformer embedding dimension for FFN'
'set to 4*`--hidden-size` if it is None')
group.add_argument('--num-layers', type=int, default=24,
help='num decoder layers')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
help='layer norm epsilon')
group.add_argument('--hidden-dropout', type=float, default=0.1,
help='dropout probability for hidden state transformer')
group.add_argument('--max-position-embeddings', type=int, default=512,
help='maximum number of position embeddings to use')
group.add_argument('--vocab-size', type=int, default=30522,
help='vocab size to use for non-character-level '
'tokenization. This value will only be used when '
'creating a tokenizer')
group.add_argument('--deep-init', action='store_true',
help='initialize bert model similar to gpt2 model.'
'scales initialization of projection layers by a '
'factor of 1/sqrt(2N). Necessary to train bert '
'models larger than BERT-Large.')
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficieny reasons.')
group.add_argument('--cpu-optimizer', action='store_true',
help='Run optimizer on CPU')
group.add_argument('--cpu_torch_adam', action='store_true',
help='Use Torch Adam as optimizer on CPU.')
return parser
def add_fp16_config_args(parser):
"""Mixed precision arguments."""
group = parser.add_argument_group('fp16', 'fp16 configurations')
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode')
group.add_argument('--fp32-embedding', action='store_true',
help='embedding in fp32')
group.add_argument('--fp32-layernorm', action='store_true',
help='layer norm in fp32')
group.add_argument('--fp32-tokentypes', action='store_true',
help='embedding token types in fp32')
group.add_argument('--fp32-allreduce', action='store_true',
help='all-reduce in fp32')
group.add_argument('--hysteresis', type=int, default=2,
help='hysteresis for dynamic loss scaling')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
'loss scaling is used.')
group.add_argument('--loss-scale-window', type=float, default=1000,
help='Window over which to raise/lower dynamic scale')
group.add_argument('--min-scale', type=float, default=1,
help='Minimum loss scale for dynamic loss scale')
return parser
def add_training_args(parser):
"""Training arguments."""
group = parser.add_argument_group('train', 'training configurations')
group.add_argument('--experiment-name', type=str, default="gpt-345M",
help="The experiment name for summary and checkpoint")
group.add_argument('--batch-size', type=int, default=4,
help='Data Loader batch size')
group.add_argument('--weight-decay', type=float, default=0.01,
help='weight decay coefficient for L2 regularization')
group.add_argument('--checkpoint-activations', action='store_true',
help='checkpoint activation to allow for training '
'with larger models and sequences')
group.add_argument('--checkpoint-num-layers', type=int, default=1,
help='chunk size (number of layers) for checkpointing')
group.add_argument('--deepspeed-activation-checkpointing', action='store_true',
help='uses activation checkpointing from deepspeed')
group.add_argument('--clip-grad', type=float, default=1.0,
help='gradient clipping')
group.add_argument('--train-iters', type=int, default=1000000,
help='total number of iterations to train over all training runs')
group.add_argument('--log-interval', type=int, default=100,
help='report interval')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after this many new iterations.')
group.add_argument('--summary-dir', type=str, default="", help="The directory to store the summary")
group.add_argument('--seed', type=int, default=1234,
help='random seed')
# Batch prodecuer arguments
group.add_argument('--reset-position-ids', action='store_true',
help='Reset posistion ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention maske after '
'end-of-document token.')
# Learning rate.
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay LR over,'
' If None defaults to `--train-iters`*`--epochs`')
group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine', 'exponential'],
help='learning rate decay function')
group.add_argument('--lr-decay-ratio', type=float, default=0.1)
group.add_argument('--lr', type=float, default=1.0e-4,
help='initial learning rate')
group.add_argument('--warmup', type=float, default=0.01,
help='percentage of data to warmup on (.01 = 1% of all '
'training iters). Default 0.01')
# model checkpointing
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
group.add_argument('--save-interval', type=int, default=5000,
help='number of iterations between saves')
group.add_argument('--no-save-optim', action='store_true',
help='Do not save current optimizer.')
group.add_argument('--no-save-rng', action='store_true',
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Path to a directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true',
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-lr-scheduler', action='store_true',
help='Do not load lr scheduler when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true',
help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
'Assumed when loading a release checkpoint.')
group.add_argument('--resume-dataloader', action='store_true',
help='Resume the dataloader when resuming training. '
'Does not apply to tfrecords dataloader, try resuming'
'with a different seed in this case.')
# distributed training args
group.add_argument('--distributed-backend', default='nccl',
help='which backend to use for distributed '
'training. One of [gloo, nccl]')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher')
return parser
def add_evaluation_args(parser):
"""Evaluation arguments."""
group = parser.add_argument_group('validation', 'validation configurations')
group.add_argument('--eval-batch-size', type=int, default=None,
help='Data Loader batch size for evaluation datasets.'
'Defaults to `--batch-size`')
group.add_argument('--eval-iters', type=int, default=100,
help='number of iterations to run for evaluation'
'validation/test for')
group.add_argument('--eval-interval', type=int, default=1000,
help='interval between running evaluation on validation set')
group.add_argument('--eval-seq-length', type=int, default=None,
help='Maximum sequence length to process for '
'evaluation. Defaults to `--seq-length`')
group.add_argument('--eval-max-preds-per-seq', type=int, default=None,
help='Maximum number of predictions to use for '
'evaluation. Defaults to '
'math.ceil(`--eval-seq-length`*.15/10)*10')
group.add_argument('--overlapping-eval', type=int, default=32,
help='sliding window for overlapping eval ')
group.add_argument('--cloze-eval', action='store_true',
help='Evaluation dataset from `--valid-data` is a cloze task')
group.add_argument('--eval-hf', action='store_true',
help='perform evaluation with huggingface openai model.'
'use `--load` to specify weights path to be loaded')
group.add_argument('--load-openai', action='store_true',
help='load openai weights into our model. Use `--load` '
'to specify weights path to be loaded')
return parser
def add_text_generate_args(parser):
"""Text generate arguments."""
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--top_p", type=float, default=0.85)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=256)
group.add_argument("--hierarchical", action='store_true')
group.add_argument("--test_data_path", type=str, default=None)
group.add_argument("--dist_lambda", type=float, default=5.0)
group.add_argument("--min_dist_lambda", type=float, default=5.0)
group.add_argument("--dist_lambda_decay_rate", type=float, default=5.0)
group.add_argument("--perplexity_lambda", type=float, default=5.0)
group.add_argument("--beam_number", type=int, default=5)
group.add_argument("--min_beam_number", type=int, default=5)
group.add_argument("--beam_token_number", type=int, default=5)
group.add_argument("--beam_candidate_number", type=int, default=5)
group.add_argument("--min_beam_candidate_number", type=int, default=5)
group.add_argument("--max_seq_len", type=int, default=5)
group.add_argument("--sent_generate_num", type=int, default=5)
return parser
def add_data_args(parser):
"""Train/valid/test data arguments."""
group = parser.add_argument_group('data', 'data configurations')
group.add_argument('--model-parallel-size', type=int, default=1,
help='size of the model parallel.')
group.add_argument('--shuffle', action='store_true',
help='Shuffle data. Shuffling is deterministic '
'based on seed and current epoch.')
group.add_argument('--train-data', nargs='+', default=None,
help='Whitespace separated filenames or corpora names '
'for training.')
group.add_argument('--xl-dataset', action='store_true')
group.add_argument('--use-npy-data-loader', action='store_true',
help='Use the numpy data loader. If set, then'
'train-data-path, val-data-path, and test-data-path'
'should also be provided.')
group.add_argument('--train-data-path', type=str, default='',
help='path to the training data')
group.add_argument('--val-data-path', type=str, default='',
help='path to the validation data')
group.add_argument('--test-data-path', type=str, default='',
help='path to the test data')
group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
help='the filename containing all the shards sizes')
group.add_argument('--delim', default=',',
help='delimiter used to parse csv data files')
group.add_argument('--text-key', default='sentence',
help='key to use to extract text from json/csv')
group.add_argument('--eval-text-key', default=None,
help='key to use to extract text from '
'json/csv evaluation datasets')
group.add_argument('--valid-data', nargs='*', default=None,
help="""Filename for validation data.""")
group.add_argument('--split', default='1000,1,1',
help='comma-separated list of proportions for training,'
' validation, and test split')
group.add_argument('--test-data', nargs='*', default=None,
help="""Filename for testing""")
group.add_argument('--lazy-loader', action='store_true',
help='whether to lazy read the data set')
group.add_argument('--loose-json', action='store_true',
help='Use loose json (one json-formatted string per '
'newline), instead of tight json (data file is one '
'json string)')
group.add_argument('--presplit-sentences', action='store_true',
help='Dataset content consists of documents where '
'each document consists of newline separated sentences')
group.add_argument('--num-workers', type=int, default=2,
help="""Number of workers to use for dataloading""")
group.add_argument('--tokenizer-model-type', type=str,
default='bert-large-uncased',
help="Model type to use for sentencepiece tokenization \
(one of ['bpe', 'char', 'unigram', 'word']) or \
bert vocab to use for BertWordPieceTokenizer (one of \
['bert-large-uncased', 'bert-large-cased', etc.])")
group.add_argument('--tokenizer-path', type=str, default='tokenizer.model',
help='path used to save/load sentencepiece tokenization '
'models')
group.add_argument('--tokenizer-type', type=str,
default='BertWordPieceTokenizer',
choices=['CharacterLevelTokenizer',
'SentencePieceTokenizer',
'BertWordPieceTokenizer',
'GPT2BPETokenizer',
'ChineseSPTokenizer'],
help='what type of tokenizer to use')
group.add_argument('--not-pre-tokenize', action='store_true')
group.add_argument("--cache-dir", default=None, type=str,
help="Where to store pre-trained BERT downloads")
group.add_argument('--use-tfrecords', action='store_true',
help='load `--train-data`, `--valid-data`, '
'`--test-data` from BERT tf records instead of '
'normal data pipeline')
group.add_argument('--seq-length', type=int, default=512,
help="Maximum sequence length to process")
group.add_argument('--mem-length', type=int, default=0,
help="The memory length to preserve")
group.add_argument('--max-preds-per-seq', type=int, default=None,
help='Maximum number of predictions to use per sequence.'
'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
group.add_argument('--sample-one-document', action='store_true', help='only sample one document in one sample')
return parser
def add_prefix_tuning_args(parser):
"""prefix-tuning arguments."""
group = parser.add_argument_group('prefix-tuning', 'prefix tuning configurations')
group.add_argument('--save_dir', type=str, default="",
help='directory for model checkpoint saving')
group.add_argument('--preseqlen', type=int, default=100,
help="prefix sequence length")
group.add_argument('--load_sd', type=str, default="",
help="path to load state dict of model")
group.add_argument('--no_load_lr_scheduler', action='store_true', help='do not load lr scheduler when loading checkpoint')
group.add_argument('--no_load_optim', action='store_true', help='do not load optimizer when loading checkpoint')
group.add_argument('--batch_size', type=int, help='batch size')
group.add_argument("--MASTER_PORT", type=str)
return parser
def get_args():
"""Parse all the args."""
parser = argparse.ArgumentParser(description='PyTorch BERT Model')
parser = add_model_config_args(parser)
parser = add_fp16_config_args(parser)
parser = add_training_args(parser)
parser = add_evaluation_args(parser)
parser = add_text_generate_args(parser)
parser = add_data_args(parser)
parser = add_prefix_tuning_args(parser)
# Include DeepSpeed configuration arguments
try:
import deepspeed
parser = deepspeed.add_config_arguments(parser)
except ModuleNotFoundError:
pass
args = parser.parse_args()
if not args.train_data and not args.train_data_path:
print('WARNING: No training data specified')
args.cuda = torch.cuda.is_available()
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
mpi_define_env(args)
elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
# We are using (OpenMPI) mpirun for launching distributed data parallel processes
local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE'))
# Possibly running with Slurm
num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1'))
nodeid = int(os.getenv('SLURM_NODEID', '0'))
args.local_rank = local_rank
args.rank = nodeid * local_size + local_rank
args.world_size = num_nodes * local_size
args.model_parallel_size = min(args.model_parallel_size, args.world_size)
if args.rank == 0:
print('using world size: {} and model-parallel size: {} '.format(
args.world_size, args.model_parallel_size))
args.dynamic_loss_scale = False
if args.loss_scale is None:
args.dynamic_loss_scale = True
if args.rank == 0:
print(' > using dynamic loss scaling')
# The args fp32_* or fp16_* meant to be active when the
# args fp16 is set. So the default behaviour should all
# be false.
if not args.fp16:
args.fp32_embedding = False
args.fp32_tokentypes = False
args.fp32_layernorm = False
if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
with open(args.deepspeed_config) as file:
deepspeed_config = json.load(file)
if "train_micro_batch_size_per_gpu" in deepspeed_config:
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
if "gradient_accumulation_steps" in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
else:
args.gradient_accumulation_steps = None
if "optimizer" in deepspeed_config:
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
args.lr = optimizer_params_config.get("lr", args.lr)
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
return args
def mpi_define_env(args):
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode('utf-8').split()[0]
master_addr = comm.bcast(master_addr, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
os.environ['RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
args.local_rank = local_rank
args.world_size = world_size
args.rank = rank
os.environ['MASTER_ADDR'] = master_addr
print(f"args master port : {args.MASTER_PORT}")
os.environ['MASTER_PORT'] = args.MASTER_PORT # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
print(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.format(os.environ['RANK'],
args.local_rank,
os.environ['WORLD_SIZE'],
os.environ['MASTER_ADDR'],
os.environ['MASTER_PORT']))
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for creating datasets"""
import os
import random
import time
import torch
from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, \
GPT2Dataset, ShuffleDataset, XLDataset
from .lazy_loader import exists_lazy, LazyWriter, LazyLoader
from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, \
GPT2BPETokenizer, make_tokenizer
from . import corpora
TRAIN_DATA = 0
VAL_DATA = 1
TEST_DATA = 2
def should_split(split):
"""
given split proportions checks if should split
Examples:
>>> should_split([10,0,0])
False
>>> should_split([1,.1,.2])
True
"""
return max(split) / sum(split) != 1.
def get_ext(path):
"""gets path extension"""
return os.path.splitext(path)[1]
def get_dataset(name, tokenizer, pre_tokenize, local_rank):
"""gets dataset object based on keyword args and file at `path`"""
if supported_corpus(name):
dataset = corpora.NAMED_CORPORA[name]
path = dataset.PATH
else:
dataset = corpora.get_finetune_dataset(name)
path = dataset.PATH
if issubclass(dataset, corpora.PromptReader):
if not (exists_lazy(path, data_type='prompt') and exists_lazy(path, data_type='text')):
# create cached version of dataset for lazy loading if it doesn't exist
if local_rank == 0:
prompt_writer = LazyWriter(path, data_type='prompt', is_array=pre_tokenize)
text_writer = LazyWriter(path, data_type='text', is_array=pre_tokenize)
writers = {'prompt': prompt_writer, 'text': text_writer}
dataset(writers=writers, tokenizer=tokenizer, tokenize=pre_tokenize)
prompt_writer.close()
text_writer.close()
else:
while not os.path.exists(LazyWriter.get_len_path(path, data_type='prompt')):
time.sleep(1)
map_fn = (lambda x: x.tolist()) if pre_tokenize else None
prompts = LazyLoader(path, data_type='prompt', map_fn=map_fn, mem_map=True,
is_array=pre_tokenize)
texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True,
is_array=pre_tokenize)
text = corpora.PromptDataset(prompt_loader=prompts, text_loader=texts, tokenizer=tokenizer,
to_tokenize=not pre_tokenize)
if torch.distributed.get_rank() == 0:
print(f"Create dataset {name} with {len(text)} documents")
for i in range(10):
rand_id = i if i < 5 else random.randrange(len(text))
sample_tokens = text[rand_id]['tokens'][:1024]
print(sample_tokens)
print(tokenizer.DecodeIds(sample_tokens))
return text
elif issubclass(dataset, corpora.KeyReader):
if not (exists_lazy(path, data_type='text') and exists_lazy(path, data_type='mask')):
# create cached version of dataset for lazy loading if it doesn't exist
if local_rank == 0:
text_writer = LazyWriter(path, data_type='text', is_array=pre_tokenize)
mask_writer = LazyWriter(path, data_type='mask', is_array=True)
writers = {'mask': mask_writer, 'text': text_writer}
dataset(writers=writers, tokenizer=tokenizer, tokenize=pre_tokenize)
mask_writer.close()
text_writer.close()
else:
while not os.path.exists(LazyWriter.get_len_path(path, data_type='mask')):
time.sleep(1)
map_fn = (lambda x: x.tolist()) if pre_tokenize else None
masks = LazyLoader(path, data_type='mask', map_fn=map_fn, mem_map=True, is_array=True)
texts = LazyLoader(path, data_type='text', map_fn=map_fn, mem_map=True, is_array=pre_tokenize)
text = corpora.KeyDataset(mask_loader=masks, text_loader=texts, tokenizer=tokenizer,
to_tokenize=not pre_tokenize)
return text
else:
raise NotImplementedError
def supported_corpus(corpus_name):
"""checks if corpus name is defined in `corpora.py`"""
return corpus_name in corpora.NAMED_CORPORA
def make_dataset(path, seq_length, mem_length, local_rank, lazy=False, xl_style=False,
shuffle=True, split=None, tokenizer=None, tokenizer_type='CharacterLevelTokenizer',
tokenizer_model_path=None, vocab_size=None, model_type='bpe', pad_token=0, character_converage=1.0,
non_binary_cols=None, sample_one_document=False, pre_tokenize=False, **kwargs):
"""function to create datasets+tokenizers for common options"""
if split is None:
split = [1.]
if non_binary_cols is not None:
# multilabel dataset support (only for csvs)
label_key = non_binary_cols
# make tokenizer for dataset
if tokenizer is None:
add_eop = path.endswith('key') if isinstance(path, str) else sum([p.endswith('key') for p in path]) > 0
tokenizer = make_tokenizer(tokenizer_type, None, tokenizer_model_path, vocab_size, model_type,
pad_token, character_converage, add_eop=add_eop, **kwargs)
# get one or multiple datasets and concatenate
if isinstance(path, str):
ds = get_dataset(path, tokenizer=tokenizer, pre_tokenize=pre_tokenize, local_rank=local_rank)
else:
ds = [get_dataset(p, tokenizer=tokenizer, pre_tokenize=pre_tokenize, local_rank=local_rank) for p in path]
ds = ConcatDataset(ds)
ds_type = ''
if 'ds_type' in kwargs:
ds_type = kwargs['ds_type']
# Split dataset into train/val/test (and wrap bert dataset)
if should_split(split):
ds = split_ds(ds, split, shuffle=shuffle)
if ds_type.lower() == 'bert':
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length,
presplit_sentences=presplit_sentences) if d is not None else None for d in
ds]
elif ds_type.lower() == 'gpt2':
if xl_style:
ds = [XLDataset(d, tokenizer, max_seq_len=seq_length, mem_len=mem_length,
sample_across_doc=not sample_one_document) if d is not None else None for d in ds]
else:
ds = [GPT2Dataset(d, tokenizer, max_seq_len=seq_length,
sample_across_doc=not sample_one_document) if d is not None else None for d in ds]
else:
if ds_type.lower() == 'bert':
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
elif ds_type.lower() == 'gpt2':
if xl_style:
ds = XLDataset(ds, tokenizer, max_seq_len=seq_length, mem_len=mem_length,
sample_across_doc=not sample_one_document)
else:
ds = GPT2Dataset(ds, tokenizer, max_seq_len=seq_length, sample_across_doc=not sample_one_document)
return ds, tokenizer
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""several datasets with preset arguments"""
from .datasets import json_dataset, csv_dataset
import os
import re
import json
import csv
import tqdm
from multiprocessing import Queue, Process
from torch.utils import data
from .lazy_loader import LazyLoader
NUM_PROCESSES = 40
class webtext(json_dataset):
"""
dataset for webtext with arguments configured for convenience
command line usage: `--train-data webtext`
"""
PATH = 'data/webtext/data.json'
assert_str = "make sure to set PATH for webtext data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(webtext.PATH), \
webtext.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(webtext, self).__init__(webtext.PATH, **kwargs)
class KeyDataset(data.Dataset):
def __init__(self, text_loader, mask_loader, **kwargs):
self.texts = text_loader
self.masks = mask_loader
self.is_lazy = False
if isinstance(self.texts, LazyLoader) and isinstance(self.masks, LazyLoader):
self.text_lens = self.texts.lens
self.is_lazy = True
def get_text_len(self, idx):
return self.text_lens[idx]
def __getitem__(self, index):
text = self.texts[index]
mask_length = self.masks[index]
mask = []
for i, length in enumerate(mask_length):
if i % 2 == 0:
mask += [0] * length
else:
mask += [1] * length
assert len(text) == len(mask)
return {"tokens": text, "loss_masks": mask}
def __len__(self):
return len(self.texts)
class PromptDataset(data.Dataset):
def __init__(self, prompt_loader, text_loader, tokenizer=None, to_tokenize=False, **kwargs):
self.prompts = prompt_loader
self.texts = text_loader
self.tokenizer = tokenizer
self.to_tokenize = to_tokenize
if isinstance(self.prompts, LazyLoader) and isinstance(self.texts, LazyLoader):
self.prompt_lens = self.prompts.lens
self.text_lens = self.texts.lens
self.is_lazy = True
def get_text_len(self, idx):
return self.prompt_lens[idx] + self.text_lens[idx]
def __getitem__(self, index):
prompt = self.prompts[index]
text = self.texts[index]
if self.to_tokenize:
prompt = self.tokenizer.EncodeAsIds(prompt).tokenization
text = self.tokenizer.EncodeAsIds(text).tokenization
return {"tokens": prompt + text, "loss_masks": [0] * len(prompt) + [1] * len(text)}
def __len__(self):
return len(self.prompts)
class DataReader:
is_json = True
PATH = None
assert_str = None
@classmethod
def tokenize_worker(cls, input, output, reader, tokenizer, tokenize):
raise NotImplementedError
def __init__(self, writers, tokenizer=None, tokenize=False, **kwargs):
assert os.path.exists(self.PATH), self.assert_str
self.tokenizer = tokenizer
self.tokenize = tokenize
self.writers = writers
if os.path.isdir(self.PATH):
paths = [entry.path for entry in os.scandir(self.PATH) if
not entry.is_dir() and not entry.name.endswith("bz2")]
else:
paths = [self.PATH]
task_queue, done_queue = Queue(), Queue()
processes = []
for i in range(NUM_PROCESSES):
process = Process(target=self.tokenize_worker,
args=(task_queue, done_queue, type(self), tokenizer, tokenize))
process.start()
processes.append(process)
csv.field_size_limit(10000000)
def read_input_to_queue():
for path in paths:
print(f"Start reading {path}")
with open(path) as file:
if self.is_json:
for row in tqdm.tqdm(file):
task_queue.put(row)
else:
reader = csv.reader(file)
for item in reader:
task_queue.put(item)
for _ in range(len(processes)):
task_queue.put('STOP')
process = Process(target=read_input_to_queue)
process.start()
count = len(processes)
progress_bar = tqdm.tqdm()
while True:
data = done_queue.get()
if data == 'COMPLETE':
count -= 1
if count == 0:
break
else:
self.write_result(data, self.writers)
progress_bar.update()
progress_bar.close()
@staticmethod
def write_result(data, writers):
raise NotImplementedError
@staticmethod
def get_token_count(contents):
return sum(map(len, contents))
@staticmethod
def process_sample(text, tokenizer, tokenize):
if isinstance(text, str) and tokenize:
text = tokenizer.EncodeAsIds(text).tokenization if text else []
return text
@staticmethod
def trim_field(content, max_length):
if len(content) > max_length:
content = content[:max_length]
content += "......"
return content
@classmethod
def process_line(cls, data, tokenizer, tokenize):
raise NotImplementedError
class PromptReader(DataReader):
@classmethod
def tokenize_worker(cls, input, output, reader, tokenizer, tokenize):
for row in iter(input.get, 'STOP'):
if cls.is_json:
item = json.loads(row)
else:
item = row
prompts, texts = reader.process_line(item, tokenizer, tokenize)
for prompt, text in zip(prompts, texts):
output.put((prompt, text))
output.put("COMPLETE")
@staticmethod
def write_result(data, writers):
prompt, text = data
writers['prompt'].write(prompt)
writers['text'].write(text)
class KeyReader(DataReader):
PATH = '/root/data/wikipedia/wiki-key.txt'
assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
keys, contents = data['key'], data["content"]
assert len(keys) == len(contents)
for i in range(1, len(keys)):
keys[i] = " " + keys[i]
contents = [" " + content for content in contents]
keys = [tokenizer.EncodeAsIds(key).tokenization for key in keys]
contents = [tokenizer.EncodeAsIds(content).tokenization for content in contents]
summary = sum(keys, [])
summary_prefix = cls.process_sample("Summary: ", tokenizer, tokenize)
summary_mask = [len(summary_prefix), len(summary)]
summary = summary_prefix + summary
text, text_mask = [], []
for key, content in zip(keys, contents):
content = content + [tokenizer.get_command('eop').Id]
text += key
text += content
text_mask.append(len(key))
text_mask.append(len(content))
return (summary, summary_mask), (text, text_mask)
@classmethod
def tokenize_worker(cls, input, output, reader, tokenizer, tokenize):
for row in iter(input.get, 'STOP'):
data = json.loads(row)
summary, content = reader.process_line(data, tokenizer, tokenize)
output.put((summary, content))
output.put("COMPLETE")
@staticmethod
def write_result(data, writers):
summary, content = data
writers['text'].write(summary[0])
writers['mask'].write(summary[1])
writers['text'].write(content[0])
writers['mask'].write(content[1])
class zhihu(PromptReader):
PATH = "/root/data/zhihu/zhihu"
# PATH = "data/zhihu/data.json"
assert_str = "make sure to set PATH for zhihu data_utils/corpora.py"
qtitle_prefix = "问题:"
qcontent_prefix = "问题描述:"
user_prefix = "回答用户:"
answer_prefix = " 回答:"
# qtitle_prefix = []
# qcontent_prefix = []
# user_prefix = []
# answer_prefix = []
@classmethod
def process_line(cls, data, tokenizer, tokenize):
prompts, texts = [], []
ans_length = len(data.get("ans-content", ""))
ans_up = data.get("ans-up-num", "")
ans_up = int(ans_up) if ans_up else 0
if ans_length > 100 or ans_up > 1000:
qtitle = data["q_title"]
qcontent = data["q-content"]
if qcontent is None:
qcontent = ""
qcontent = cls.trim_field(qcontent, max_length=100)
user = data.get("user-signature", "")
prompt = cls.qtitle_prefix + qtitle + cls.qcontent_prefix + qcontent + cls.user_prefix + user + cls.answer_prefix
text = data["ans-content"]
prompt, text = cls.process_sample(prompt, tokenizer, tokenize), cls.process_sample(text, tokenizer,
tokenize)
prompts.append(prompt)
texts.append(text)
# prompt = data["q_title"] + data["q-content"] + data["user-signature"]
# text = data["ans-content"]
# prompts.append(prompt)
# texts.append(text)
return prompts, texts
class zhidao(PromptReader):
PATH = "/root/data/zhidao/zhidao"
assert_str = "make sure to set PATH for zhidao data_utils/corpora.py"
qtitle_prefix = "问题:"
qcontent_prefix = "问题描述:"
answer_prefix = "回答:"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
if "title" not in data:
return [], []
prompts, texts = [], []
qtitle = data["title"]
qcontent = data.get("content", "")
qcontent = cls.trim_field(qcontent, max_length=100)
prompt = cls.qtitle_prefix + qtitle + cls.qcontent_prefix + qcontent + cls.answer_prefix
prompt = cls.process_sample(prompt, tokenizer, tokenize)
if "best_answer" in data:
text = data["best_answer"]["content"]
if len(text) > 10:
text = cls.process_sample(text, tokenizer, tokenize)
prompts.append(prompt)
texts.append(text)
for answer in data.get("other_answers", []):
text = answer["content"]
if len(text) > 100:
text = cls.process_sample(text, tokenizer, tokenize)
prompts.append(prompt)
texts.append(text)
return prompts, texts
class baike(PromptReader):
PATH = "/root/data/baike/baike"
assert_str = "make sure to set PATH for baike data_utils/corpora.py"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
prompts, texts = [], []
text = data.get("title", "") + data.get("abstract", "") + data.get("content", "")
if text:
p, t = cls.process_sample("", tokenizer, tokenize), cls.process_sample(text, tokenizer, tokenize)
prompts.append(p)
texts.append(t)
return prompts, texts
class wikipedia(PromptReader):
"""
dataset for wikipedia with arguments configured for convenience
command line usage: `--train-data wikipedia`
"""
# PATH = '/dataset/data/wiki.txt'
PATH = '/root/data/wiki.txt'
assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
text = data['text']
prompt, text = cls.process_sample("", tokenizer, tokenize), cls.process_sample(text, tokenizer, tokenize)
return [prompt], [text]
class XinhuaNews(PromptReader):
is_json = False
PATH = '/root/data/xinhua'
assert_str = "make sure to set PATH for xinhua data_utils/corpora.py"
@classmethod
def process_line(cls, item, tokenizer, tokenize):
title, content = item
def clean_html(raw_html):
pattern = r'<.*?>'
clean_text = re.sub(pattern, '', raw_html)
return clean_text
def clean_text(text):
text = text.replace("\u3000", " ")
return text
prompts, texts = [], []
title = clean_text(clean_html(title))
content = clean_text(clean_html(content))
if len(content) > 100:
prompts.append("标题:" + title + " 内容:")
texts.append(content)
if len(title) > 1:
prompts.append("内容:")
texts.append(content + " 标题:" + title)
prompts = [cls.process_sample(prompt, tokenizer, tokenize) for prompt in prompts]
texts = [cls.process_sample(content, tokenizer, tokenize) for text in texts]
return prompts, texts
def get_finetune_dataset(path):
class FinetuneDataset(PromptReader):
PATH = path
assert_str = f"File not exists at {PATH}"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
prompt, text = data['prompt'], data['text']
prompt, text = cls.process_sample(prompt, tokenizer, tokenize), cls.process_sample(text, tokenizer,
tokenize)
return [prompt], [text]
return FinetuneDataset
class CMRC(PromptReader):
PATH = "/raid/zengzhongshen/sim_text_gen/cmrc2018/squad-style-data/train_processed.json"
assert_str = "必须在data_utils/corpora.py正确初始化CMRC数据集"
qtitle_prefix = "问题:"
qcontent_prefix = "问题描述:"
user_prefix = "回答用户:"
answer_prefix = " 回答:"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
pass
class SimTextDataset(PromptReader):
PATH = '/raid/zengzhongshen/sim_text_gen/transformer_xl_chinese/sim_text_data/processed_pretrain_text.jsonl'
assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
sim_text_prompt="相似句:"
@classmethod
def process_line(cls, data, tokenizer, tokenize):
orig_text, sim_text = data['orig_text'], data['sim_text']
orig_text, sim_text = cls.process_sample(orig_text, tokenizer, tokenize), cls.process_sample(cls.sim_text_prompt+sim_text, tokenizer,
tokenize)
return [orig_text], [sim_text]
NAMED_CORPORA = {
'wikipedia': wikipedia,
'wikipedia-key': KeyReader,
'webtext': webtext,
"zhihu": zhihu,
"zhidao": zhidao,
"baike": baike,
"xinhua": XinhuaNews,
"simtext": SimTextDataset
}
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset objects for jsons, csvs, and BERT datasets"""
import os
import time
from operator import itemgetter
from bisect import bisect_right
import json
import csv
import math
import random
import tqdm
from itertools import accumulate
from torch.utils import data
import pandas as pd
import numpy as np
import nltk
from nltk import tokenize
from .lazy_loader import LazyLoader, exists_lazy
from .tokenization import Tokenization
class ShuffleDataset(data.Dataset):
def __init__(self, ds):
self.ds = ds
self.shuffle_ids = list(range(len(self.ds)))
random.shuffle(self.shuffle_ids)
self.is_lazy = hasattr(ds, 'is_lazy') and ds.is_lazy
if self.is_lazy:
self.prompt_lens = [self.ds.prompt_lens[idx] for idx in self.shuffle_ids]
self.text_lens = [self.ds.text_lens[idx] for idx in self.shuffle_ids]
def __getitem__(self, idx):
return self.ds[self.shuffle_ids[idx]]
def __len__(self):
return len(self.ds)
class ConcatDataset(data.Dataset):
"""
Dataset to concatenate multiple datasets.
Purpose: useful to assemble different existing datasets, possibly
large-scale datasets as the concatenation operation is done in an
on-the-fly manner.
Arguments:
datasets (sequence): List of datasets to be concatenated.
"""
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets, **kwargs):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
self.is_lazy = sum([isinstance(ds, LazyLoader) or (hasattr(ds, 'is_lazy') and ds.is_lazy) for ds in
self.datasets]) == len(self.datasets)
self.cumulative_sizes = self.cumsum(self.datasets)
self._X = None
self._Y = None
self._lens = None
def get_text_len(self, idx):
dataset_idx = bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx].get_text_len(sample_idx)
def SetTokenizer(self, tokenizer):
for ds in self.datasets:
ds.SetTokenizer(tokenizer)
def GetTokenizer(self):
return self.datasets[0].GetTokenizer()
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx = bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
@property
def lens(self):
if self._lens is None:
self._lens = []
if self.is_lazy:
for data in self.datasets:
self._lens.extend(data.lens)
else:
for data in self.datasets:
self._lens.extend([len(d['text']) if isinstance(d, dict) else len(d) for d in data])
return self._lens
@property
def X(self):
if self._X is None:
self._X = []
for data in self.datasets:
self._X.extend(data.X)
return self._X
@property
def Y(self):
if self._Y is None:
self._Y = []
for data in self.datasets:
self._Y.extend(list(data.Y))
self._Y = np.array(self._Y)
return self._Y
class SplitDataset(data.Dataset):
"""
Dataset wrapper to access a subset of another dataset.
Purpose: useful to index into existing datasets, possibly
large-scale datasets as the subindexing operation is done in an
on-the-fly manner.
Arguments:
ds (Dataset or array-like): List of datasets to be subindexed
split_inds (1D array-like): List of indices part of subset
"""
def __init__(self, ds, split_inds, **kwargs):
self.split_inds = list(split_inds)
self.wrapped_data = ds
self.is_lazy = isinstance(ds, LazyLoader) or (hasattr(ds, 'is_lazy') and ds.is_lazy)
self._X = None
self._Y = None
def __len__(self):
return len(self.split_inds)
def get_text_len(self, idx):
return self.wrapped_data.get_text_len(self.split_inds[idx])
def __getitem__(self, index):
return self.wrapped_data[self.split_inds[index]]
def SetTokenizer(self, tokenizer):
self.wrapped_data.SetTokenizer(tokenizer)
def GetTokenizer(self):
return self.wrapped_data.GetTokenizer()
@property
def X(self):
if self._X is None:
self._X = itemgetter(*self.split_inds)(self.wrapped_data.X)
return self._X
@property
def Y(self):
if self._Y is None:
self._Y = np.array(itemgetter(*self.split_inds)(self.wrapped_data.Y))
return self._Y
def __iter__(self):
for idx in self.split_inds:
yield self.wrapped_data[idx]
def split_ds(ds, split=[.8,.2,.0], shuffle=True):
"""
Split a dataset into subsets given proportions of how
much to allocate per split. If a split is 0% returns None for that split.
Purpose: Useful for creating train/val/test splits
Arguments:
ds (Dataset or array-like): Data to be split.
split (1D array-like): proportions to split `ds`. `sum(splits) != 0`
shuffle (boolean): Randomly split dataset. Default: True
"""
split_sum = sum(split)
if split_sum == 0:
raise Exception('Split cannot sum to 0.')
split = np.array(split)
split /= split_sum
ds_len = len(ds)
inds = np.arange(ds_len)
if shuffle:
np.random.shuffle(inds)
start_idx = 0
residual_idx = 0
rtn_ds = [None]*len(split)
for i, f in enumerate(split):
if f != 0:
proportion = ds_len*split[i]
residual_idx += proportion % 1
split_ = int(int(proportion) + residual_idx)
split_inds = inds[start_idx:start_idx+max(split_, 1)]
rtn_ds[i] = SplitDataset(ds, split_inds)
start_idx += split_
residual_idx %= 1
return rtn_ds
class csv_dataset(data.Dataset):
"""
Class for loading datasets from csv files.
Purpose: Useful for loading data for unsupervised modeling or transfer tasks
Arguments:
path (str): Path to csv file with dataset.
tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None
preprocess_fn (callable): Callable that process a string into desired format.
delim (str): delimiter for csv. Default: ','
binarize_sent (bool): binarize label values to 0 or 1 if they\'re on a different scale. Default: False
drop_unlabeled (bool): drop rows with unlabelled values. Always fills remaining empty
columns with -1 (regardless if rows are dropped based on value) Default: False
text_key (str): key to get text from csv. Default: 'sentence'
label_key (str): key to get label from json dictionary. Default: 'label'
Attributes:
X (list): all strings from the csv file
Y (np.ndarray): labels to train with
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',',
binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
**kwargs):
self.is_lazy = False
self.preprocess_fn = preprocess_fn
self.SetTokenizer(tokenizer)
self.path = path
self.delim = delim
self.text_key = text_key
self.label_key = label_key
self.drop_unlabeled = drop_unlabeled
if '.tsv' in self.path:
self.delim = '\t'
self.X = []
self.Y = []
try:
cols = [text_key]
if isinstance(label_key, list):
cols += label_key
else:
cols += [label_key]
data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1')
except:
data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1')
data = data.dropna(axis=0)
self.X = data[text_key].values.tolist()
try:
self.Y = data[label_key].values
except Exception as e:
self.Y = np.ones(len(self.X))*-1
if binarize_sent:
self.Y = binarize_labels(self.Y, hard=binarize_sent)
def SetTokenizer(self, tokenizer):
if tokenizer is None:
self.using_tokenizer = False
if not hasattr(self, '_tokenizer'):
self._tokenizer = tokenizer
else:
self.using_tokenizer = True
self._tokenizer = tokenizer
def GetTokenizer(self):
return self._tokenizer
@property
def tokenizer(self):
if self.using_tokenizer:
return self._tokenizer
return None
def __len__(self):
return len(self.X)
def __getitem__(self, index):
"""process+tokenize string and return string,label,and stringlen"""
x = self.X[index]
if self.tokenizer is not None:
x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn)
elif self.preprocess_fn is not None:
x = self.preprocess_fn(x)
y = self.Y[index]
if isinstance(y, str):
if self.tokenizer is not None:
y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn)
elif self.preprocess_fn is not None:
y = self.preprocess_fn(y)
return {'text': x, 'length': len(x), 'label': y}
def write(self, writer_gen=None, path=None, skip_header=False):
"""
given a generator of metrics for each of the data points X_i,
write the metrics, text, and labels to a csv file
"""
if path is None:
path = self.path+'.results'
print('generating csv at ' + path)
with open(path, 'w') as csvfile:
c = csv.writer(csvfile, delimiter=self.delim)
if writer_gen is not None:
#if first item of generator is a header of what the metrics mean then write header to csv file
if not skip_header:
header = (self.label_key,)+tuple(next(writer_gen))+(self.text_key,)
c.writerow(header)
for i, row in enumerate(writer_gen):
row = (self.Y[i],)+tuple(row)+(self.X[i],)
c.writerow(row)
else:
c.writerow([self.label_key, self.text_key])
for row in zip(self.Y, self.X):
c.writerow(row)
class json_dataset(data.Dataset):
"""
Class for loading datasets from a json dump.
Purpose: Useful for loading data for unsupervised modeling or transfer tasks
Arguments:
path (str): path to json file with dataset.
tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None
preprocess_fn (callable): callable function that process a string into desired format.
Takes string, maxlen=None, encode=None as arguments. Default: process_str
text_key (str): key to get text from json dictionary. Default: 'sentence'
label_key (str): key to get label from json dictionary. Default: 'label'
Attributes:
all_strs (list): list of all strings from the dataset
all_labels (list): list of all labels from the dataset (if they have it)
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
text_key='sentence', label_key='label', loose_json=False, **kwargs):
self.is_lazy = False
self.preprocess_fn = preprocess_fn
self.path = path
self.SetTokenizer(tokenizer)
self.X = []
self.Y = []
self.text_key = text_key
self.label_key = label_key
self.loose_json = loose_json
for j in self.load_json_stream(self.path):
s = j[text_key]
self.X.append(s)
self.Y.append(j[label_key])
if binarize_sent:
self.Y = binarize_labels(self.Y, hard=binarize_sent)
def SetTokenizer(self, tokenizer):
if tokenizer is None:
self.using_tokenizer = False
if not hasattr(self, '_tokenizer'):
self._tokenizer = tokenizer
else:
self.using_tokenizer = True
self._tokenizer = tokenizer
def GetTokenizer(self):
return self._tokenizer
@property
def tokenizer(self):
if self.using_tokenizer:
return self._tokenizer
return None
def __getitem__(self, index):
"""gets the index'th string from the dataset"""
x = self.X[index]
if self.tokenizer is not None:
x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn)
elif self.preprocess_fn is not None:
x = self.preprocess_fn(x)
y = self.Y[index]
if isinstance(y, str):
if self.tokenizer is not None:
y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn)
elif self.preprocess_fn is not None:
y = self.preprocess_fn(y)
return {'text': x, 'length': len(x), 'label': y}
def __len__(self):
return len(self.X)
def write(self, writer_gen=None, path=None, skip_header=False):
"""
given a generator of metrics for each of the data points X_i,
write the metrics, text, and labels to a json file
"""
if path is None:
path = self.path+'.results'
jsons = []
if writer_gen is not None:
#if first item of generator is a header of what the metrics mean then write header to csv file
def gen_helper():
keys = {}
keys[0] = self.label_key
if not skip_header:
for idx, k in enumerate(tuple(next(writer_gen))):
keys[idx+1] = k
for i, row in enumerate(writer_gen):
if i == 0 and skip_header:
for idx, _ in enumerate(row):
keys[idx+1] = 'metric_%d'%(idx,)
j = {}
for idx, v in enumerate((self.Y[i],)+tuple(row)):
k = keys[idx]
j[k] = v
yield j
else:
def gen_helper():
for y in self.Y:
j = {}
j[self.label_key] = y
yield j
def out_stream():
for i, j in enumerate(gen_helper()):
j[self.text_key] = self.X[i]
yield j
self.save_json_stream(path, out_stream())
def save_json_stream(self, save_path, json_stream):
if self.loose_json:
with open(save_path, 'w') as f:
for i, j in enumerate(json_stream):
write_string = ''
if i != 0:
write_string = '\n'
write_string += json.dumps(j)
f.write(write_string)
else:
jsons = [j for j in json_stream]
json.dump(jsons, open(save_path, 'w'), separators=(',', ':'))
def load_json_stream(self, load_path):
if not self.loose_json:
jsons = json.load(open(load_path, 'r'))
generator = iter(jsons)
else:
def gen_helper():
with open(load_path, 'r') as f:
for row in f:
yield json.loads(row)
generator = gen_helper()
for j in generator:
if self.label_key not in j:
j[self.label_key] = -1
yield j
class XLDataset(data.Dataset):
def __init__(self, ds, tokenizer, max_seq_len=1024, mem_len=None, sample_across_doc=True, **kwargs):
self.ds = ds
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
if mem_len is None:
mem_len = max_seq_len
self.mem_len = mem_len
self.sample_across_doc = sample_across_doc
self.indices, self.num_samples = None, None
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
self.is_lazy = True
self.init_indices()
def init_indices(self):
if self.is_lazy:
lens = np.array([self.ds.get_text_len(idx) for idx in range(len(self.ds))])
else:
lens = np.array([len(d['prompt']) + len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
self.indices = list(accumulate(lens))
self.num_samples = self.indices[-1] // self.max_seq_len + 1
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
tokens, targets, loss_mask, attention_mask = self.getidx(idx)
tokens = self.pad_seq(tokens)
targets = self.pad_seq(targets)
loss_mask = self.pad_seq(loss_mask, pad_id=0)
return {'text': np.array(tokens), "target": np.array(targets), "loss_mask": np.array(loss_mask),
"attention_mask": np.array(attention_mask)}
def getidx(self, idx):
tokens, targets, loss_masks = [], [], []
attention_mask = np.concatenate((np.zeros((self.max_seq_len, self.mem_len), dtype=np.long),
np.ones((self.max_seq_len, self.max_seq_len), dtype=np.long)), axis=1)
sample_idx = bisect_right(self.indices, idx * self.max_seq_len)
last_end = 0 if sample_idx == 0 else self.indices[sample_idx - 1]
token_offset = idx * self.max_seq_len - last_end
if token_offset != 0:
history = min(self.mem_len, token_offset)
attention_mask[:, -self.max_seq_len-history:-self.max_seq_len] = 1
count = 0
while len(tokens) < self.max_seq_len and sample_idx < len(self.ds):
item = self.ds[sample_idx]
text, masks = item['tokens'], item['loss_masks']
text = text + [self.tokenizer.get_command('eos').Id]
end = min(len(text) - 1, token_offset + self.max_seq_len - len(tokens))
masks = masks + [1]
if count > 0:
current = len(tokens)
attention_mask[current:, :current + self.mem_len] = 0
tokens += text[token_offset: end]
targets += text[token_offset + 1: end + 1]
loss_masks += masks[token_offset + 1: end + 1]
count += 1
sample_idx += 1
token_offset = 0
return tokens, targets, loss_masks, attention_mask
def pad_seq(self, seq, pad_id=None):
total_tokens = self.max_seq_len
num_pad_tokens = max(0, total_tokens - len(seq))
seq += [self.tokenizer.get_command('pad').Id if pad_id is None else pad_id]*(num_pad_tokens)
return seq
class GPT2Dataset(data.Dataset):
def __init__(self, ds, tokenizer,
max_seq_len=1024,
num_samples=None,
weighted=True,
sample_across_doc=True,
random_across_doc_sampling=True,
sentence_start=False, **kwargs):
"""
sentence_start: the stripped article must start with a complete sentence
"""
self.ds = ds
self.ds_len = len(self.ds)
self.num_samples = num_samples
if num_samples is None:
self.num_samples = 1000 * self.ds_len
self.max_seq_len = max_seq_len
self.tokenizer = tokenizer
self.weighted = weighted
self.sample_across_doc = sample_across_doc
self.random_across_doc_sampling = random_across_doc_sampling
self.sentence_start = sentence_start
self.weighting, self.total_len = None, None
self.is_lazy = False
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
self.is_lazy = True
self.init_weighting()
def init_weighting(self):
if self.weighted:
if self.is_lazy:
lens = np.array([self.ds.get_text_len(idx) for idx in range(len(self.ds))])
else:
lens = np.array([len(d['text']) if isinstance(d, dict)
else len(d) for d in self.ds])
self.total_len = np.sum(lens)
self.weighting = list(accumulate(lens))
else:
self.weighting = None
def get_weighted_samples(self, np_rng):
if self.weighting is not None:
idx = np_rng.randint(self.total_len)
return bisect_right(self.weighting, idx)
else:
return np_rng.randint(self.ds_len)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# init rng
rng = random.Random(idx)
rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
# get possibly weighted random index from dataset
data_idx = self.get_weighted_samples(rng)
# data_idx = rng.choice(self.ds_len, p=self.weighting)
tokens, loss_mask = self.getidx(data_idx)
# truncate or pad tokens
num_tokens = len(tokens)
tokens_to_strip = num_tokens - self.max_seq_len - 1
# randomly choose a position for start
if tokens_to_strip > 0:
strip_left_tokens = rng.randint(tokens_to_strip + 1)
tokens = tokens[strip_left_tokens:]
loss_mask = loss_mask[strip_left_tokens:]
# if self.sentence_start:
# token_copy = list(tokens)
# not_done = True
# while (len(token_copy) > 0) and not_done:
# tok = token_copy.pop(0)
# if self.contains_sentence_end(tok):
# tokens = token_copy
# not_done = False
strip_right_rokens = len(tokens) - self.max_seq_len - 1
if strip_right_rokens > 0:
tokens = tokens[:-strip_right_rokens]
loss_mask = loss_mask[:-strip_right_rokens]
# Sample multiple documents
if self.sample_across_doc:
while (len(tokens) < (self.max_seq_len + 1)):
if self.random_across_doc_sampling:
data_idx = self.get_weighted_samples(rng)
else:
data_idx = (data_idx + 1) % self.ds_len
new_tokens, new_loss_mask = self.getidx(data_idx)
tokens += new_tokens
loss_mask += new_loss_mask
tokens = tokens[:(self.max_seq_len+1)]
loss_mask = loss_mask[:(self.max_seq_len + 1)]
tokens = self.pad_seq(tokens)
loss_mask = self.pad_seq(loss_mask, pad_id=0)
return {'text': np.array(tokens), "loss_mask": np.array(loss_mask)}
def getidx(self, data_idx):
data = self.ds[data_idx]
tokens, loss_masks = data['tokens'], data['loss_masks']
tokens = tokens + [self.tokenizer.get_command('eos').Id]
loss_masks = loss_masks + [1]
return tokens, loss_masks
def pad_seq(self, seq, pad_id=None):
total_tokens = self.max_seq_len + 1
num_pad_tokens = max(0, total_tokens - len(seq))
seq += [self.tokenizer.get_command('pad').Id if pad_id is None else pad_id]*(num_pad_tokens)
return seq
# TODO: rewrite this function for chinese
def contains_sentence_end(self, tok):
tok = self.tokenizer.IdToToken(tok)
if '.' in tok:
return True
if '?' in tok:
return True
if '!' in tok:
return True
return False
class bert_sentencepair_dataset(data.Dataset):
"""
Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair.
Arguments:
ds (Dataset or array-like): data corpus to use for training
max_seq_len (int): maximum sequence length to use for a sentence pair
mask_lm_prob (float): proportion of tokens to mask for masked LM
max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10
short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True,**kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
self.vocab_words = list(self.tokenizer.text_token_vocab.values())
self.ds.SetTokenizer(None)
self.max_seq_len = max_seq_len
self.mask_lm_prob = mask_lm_prob
if max_preds_per_seq is None:
max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10
self.max_preds_per_seq = max_preds_per_seq
self.short_seq_prob = short_seq_prob
self.dataset_size = dataset_size
if self.dataset_size is None:
self.dataset_size = self.ds_len * (self.ds_len-1)
self.presplit_sentences = presplit_sentences
if not self.presplit_sentences:
nltk.download('punkt', download_dir="./nltk")
self.weighted = weighted
self.get_weighting()
def get_weighting(self):
if self.weighted:
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
lens = np.array(self.ds.lens)
else:
lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
self.total_len = np.sum(lens)
self.weighting = list(accumulate(lens))
else:
self.weighting = None
def get_weighted_samples(self, np_rng):
if self.weighting is not None:
idx = np_rng.randint(self.total_len)
return bisect_right(self.weighting, idx)
else:
return np_rng.randint(self.ds_len)
def __len__(self):
return self.dataset_size
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
# get seq length
target_seq_length = self.max_seq_len
short_seq = False
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(2, target_seq_length)
short_seq = True
# get sentence pair and label
is_random_next = None
lena = 0
lenb = 0
while (is_random_next is None) or (lena < 1) or (lenb < 1):
tokensa, tokensb, is_random_next = self.create_random_sentencepair(target_seq_length, rng, np_rng)
lena = len(tokensa[0])
lenb = len(tokensb[0])
# truncate sentence pair to max_seq_len
tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng)
# join sentence pair, mask, and pad
tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
sample = {'text': np.array(tokens[0]), 'types': np.array(tokens[1]), 'is_random': int(is_random_next), 'mask': np.array(mask), 'mask_labels': np.array(mask_labels), 'pad_mask': np.array(pad_mask)}
return sample
def sentence_split(self, document):
"""split document into sentences"""
lines = document.split('\n')
if self.presplit_sentences:
return [line for line in lines if line]
rtn = []
for line in lines:
if line != '':
rtn.extend(tokenize.sent_tokenize(line))
return rtn
def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
"""tokenize sentence and get token types"""
tokens = self.tokenizer.EncodeAsIds(sent).tokenization
str_type = 'str' + str(sentence_num)
token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens)
return tokens, token_types
def get_doc(self, idx):
"""gets text of document corresponding to idx"""
rtn = self.ds[idx]
if isinstance(rtn, dict):
rtn = rtn['text']
return rtn
def create_random_sentencepair(self, target_seq_length, rng, np_rng):
"""
fetches a random sentencepair corresponding to rng state similar to
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L248-L294
"""
is_random_next = None
curr_strs = []
curr_str_types = []
curr_len = 0
while curr_len < 1:
curr_len = 0
doc_a = None
while doc_a is None:
if self.weighted:
# doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting)
doc_a_idx = self.get_weighted_samples(np_rng)
else:
doc_a_idx = rng.randint(0, self.ds_len-1)
doc_a = self.sentence_split(self.get_doc(doc_a_idx))
if not doc_a:
doc_a = None
random_start_a = rng.randint(0, len(doc_a)-1)
while random_start_a < len(doc_a):
sentence = doc_a[random_start_a]
sentence, sentence_types = self.sentence_tokenize(sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
curr_strs.append(sentence)
curr_str_types.append(sentence_types)
curr_len += len(sentence)
if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length:
break
random_start_a = (random_start_a+1)
if curr_strs:
num_a = 1
if len(curr_strs) >= 2:
num_a = rng.randint(0, len(curr_strs))
tokens_a = []
token_types_a = []
for j in range(num_a):
tokens_a.extend(curr_strs[j])
token_types_a.extend(curr_str_types[j])
tokens_b = []
token_types_b = []
is_random_next = False
if len(curr_strs) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
b_len = 0
while b_len < 1:
doc_b = None
while doc_b is None:
doc_b_idx = rng.randint(0, self.ds_len - 2)
doc_b_idx += int(doc_b_idx >= doc_a_idx)
doc_b = self.sentence_split(self.get_doc(doc_b_idx))
if not doc_b:
doc_b = None
random_start_b = rng.randint(0, len(doc_b)-1)
while random_start_b < len(doc_b):
sentence_b = doc_b[random_start_b]
new_b_tokens, new_b_types = self.sentence_tokenize(sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
b_len += len(new_b_tokens)
tokens_b.extend(new_b_tokens)
token_types_b.extend(new_b_types)
if len(tokens_b) >= target_b_length:
break
random_start_b = (random_start_b+1)
else:
is_random_next = False
for j in range(num_a, len(curr_strs)):
tokens_b.extend(curr_strs[j])
token_types_b.extend(curr_str_types[j])
return (tokens_a, token_types_a), (tokens_b, token_types_b), is_random_next
def truncate_seq_pair(self, a, b, max_seq_len, rng):
"""
Truncate sequence pair according to original BERT implementation:
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
max_num_tokens = max_seq_len - 3
while True:
len_a = len(tokens_a)
len_b = len(tokens_b)
total_length = len_a + len_b
if total_length <= max_num_tokens:
break
if len(tokens_a) > len(tokens_b):
trunc_tokens = tokens_a
trunc_types = token_types_a
else:
trunc_tokens = tokens_b
trunc_types = token_types_b
assert len(trunc_tokens) >= 1
if rng.random() < 0.5:
trunc_tokens.pop(0)
trunc_types.pop(0)
else:
trunc_tokens.pop()
trunc_types.pop()
return (tokens_a, token_types_a), (tokens_b, token_types_b)
def mask_token(self, idx, tokens, types, vocab_words, rng):
"""
helper function to mask `idx` token from `tokens` according to
section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
"""
label = tokens[idx]
if rng.random() < 0.8:
new_label = self.tokenizer.get_command('MASK').Id
else:
if rng.random() < 0.5:
new_label = label
else:
new_label = rng.choice(vocab_words)
tokens[idx] = new_label
return label
def pad_seq(self, seq):
"""helper function to pad sequence pair"""
num_pad = max(0, self.max_seq_len - len(seq))
pad_mask = [0] * len(seq) + [1] * num_pad
seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask
def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
"""
Mask sequence pair for BERT training according to:
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L338
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
len_a = len(tokens_a)
len_b = len(tokens_b)
cand_indices = [idx+1 for idx in range(len_a)] + [idx+2+len_a for idx in range(len_b)]
rng.shuffle(cand_indices)
output_tokens, pad_mask = self.pad_seq(list(tokens))
output_types, _ = self.pad_seq(list(token_types))
num_to_predict = min(max_preds_per_seq, max(1, int(round(len(tokens) * mask_lm_prob))))
mask = [0] * len(output_tokens)
mask_labels = [-1] * len(output_tokens)
for idx in sorted(cand_indices[:num_to_predict]):
mask[idx] = 1
label = self.mask_token(idx, output_tokens, output_types, vocab_words, rng)
mask_labels[idx] = label
return (output_tokens, output_types), mask, mask_labels, pad_mask
import nltk
import glob
import json
import os
nltk.download('punkt')
class NLTKSegmenter:
def __init(self):
pass
@staticmethod
def segment_string(article):
return nltk.tokenize.sent_tokenize(article)
wiki_path = "data/extracted"
output_path = "formatted/wiki-key.txt"
segmenter = NLTKSegmenter()
with open(output_path, "w") as output:
for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False):
for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=True):
print(filename)
article_lines = []
article_open = False
with open(filename, mode='r', newline='\n') as file:
for line in file:
line = line.rstrip()
if '<doc id=' in line:
article_open = True
elif '</doc>' in line:
key_sentences, contents = [], []
key, content = None, []
for sentences in article_lines[1:]:
if len(sentences) > 1:
if key:
if len(content) > 0 or len(contents) == 0:
key_sentences.append(key)
contents.append(content)
else:
contents[-1].append(key)
key, content = None, []
key_sentences.append(sentences[0])
contents.append(sentences[1:])
elif len(sentences) > 0:
if key:
content.append(sentences[0])
else:
key = sentences[0]
if key:
if len(content) > 0 or len(contents) == 0:
key_sentences.append(key)
contents.append(content)
else:
contents[-1].append(key)
contents = [" ".join(content) for content in contents]
article = {"key": key_sentences, "content": contents}
output.write(json.dumps(article))
output.write("\n")
article_open = False
article_lines = []
else:
if article_open and line:
sentences = segmenter.segment_string(line)
article_lines.append(sentences)
# This file is provided as is from:
# https://github.com/huggingface/pytorch-pretrained-BERT
# Please refer to their repository for copyright.
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open
import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename):
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r', encoding='utf-8') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for loading text from disk"""
import os
import mmap
import pickle as pkl
import time
import numpy as np
from itertools import accumulate
import torch
from torch.multiprocessing import Lock
def get_lazy_path(path):
"""
Gets directory path where lazy files are stored.
"""
return os.path.splitext(path)[0]+'.lazy'
def exists_lazy(path, data_type='data'):
"""
Check if we've already made a lazy version of this file for the `data_type` field.
"""
if not os.path.exists(get_lazy_path(path)):
return False
contents = os.listdir(get_lazy_path(path))
if data_type not in contents:
return False
if data_type+'.len.pkl' not in contents:
return False
return True
class LazyWriter:
def __init__(self, path, data_type, is_array=False, array_data_type=np.int32):
lazypath = get_lazy_path(path)
if not os.path.exists(lazypath):
os.makedirs(lazypath)
self.datapath = os.path.join(lazypath, data_type)
self.lenpath = os.path.join(lazypath, data_type + '.len.pkl')
self.array_data_type = array_data_type
self.output = open(self.datapath, 'wb')
self.lengths = []
self.is_array = is_array
@staticmethod
def get_len_path(path, data_type):
lazypath = get_lazy_path(path)
return os.path.join(lazypath, data_type + '.len.pkl')
def write(self, s):
if isinstance(s, dict):
s = s['text']
if self.is_array:
encoded = np.array(s, dtype=self.array_data_type).tobytes(order='C')
self.output.write(encoded)
self.lengths.append(len(s))
else:
encoded = s.encode('utf-8')
self.output.write(encoded)
self.lengths.append(len(encoded))
def close(self):
self.output.close()
with open(self.lenpath, 'wb') as f:
pkl.dump(self.lengths, f)
def split_strings(strings, start, chr_lens):
"""
Split strings based on string lengths and given start.
"""
return [strings[i-start:j-start] for i, j in zip([start]+chr_lens[:-1], chr_lens)]
class ProcessorTokenizer:
"""
callable class that runs a preprocessing, as well as tokenization step,
on input text.
"""
def __init__(self, tokenizer, process_fn=None):
self.tokenizer = tokenizer
self.process_fn = process_fn
def __call__(self, string):
if self.tokenizer is not None:
string = self.tokenizer(string, process_fn=self.process_fn)
elif self.process_fn is not None:
string = self.process_fn(string)
return string
class LazyLoader(object):
"""
Arguments:
path: path to directory where array entries are concatenated into one big string file
and the .len file are located
data_type (str): Some datsets have multiple fields that are stored in different paths.
`data_type` specifies which of these fields to load in this class
mem_map (boolean): Specifies whether to memory map file `path`
map_fn (callable): Fetched strings are passed through map_fn before being returned.
Example of lazy loader directory structure:
file.json
file.lazy/
data_type1
data_type1.len.pkl
data_type2
data_type2.len.pkl
"""
def __init__(self, path, data_type='data', mem_map=False, map_fn=None, is_array=False, array_data_type=np.int32):
lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb')
self.file = self._file
self.is_array = is_array
self.array_data_type = array_data_type
#memory map file if necessary
lenpath = os.path.join(lazypath, data_type+'.len.pkl')
self.lens = pkl.load(open(lenpath, 'rb'))
self.ends = list(accumulate(self.lens))
self.dumb_ends = list(self.ends)
self.mem_map = mem_map
if self.mem_map:
if is_array:
if self.ends[-1] == 0:
self.file = np.array([], dtype=array_data_type)
else:
self.file = np.memmap(self.file, dtype=array_data_type, mode='r', order='C')
else:
if self.ends[-1] == 0:
self.file = bytearray()
else:
self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
self.read_lock = Lock()
self.process_fn = map_fn
self.map_fn = map_fn
self._tokenizer = None
self.is_lazy = True
def SetTokenizer(self, tokenizer):
"""
logic to set and remove (set to None) tokenizer.
combines preprocessing/tokenization into one callable.
"""
if tokenizer is None:
if not hasattr(self, '_tokenizer'):
self._tokenizer = tokenizer
else:
self._tokenizer = tokenizer
self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn)
def GetTokenizer(self):
return self._tokenizer
def __getitem__(self, index):
"""
read file and splice strings based on string ending array `self.ends`
"""
if not isinstance(index, slice):
if index == 0:
start = 0
else:
start = self.ends[index-1]
end = self.ends[index]
rtn = self.file_read(start, end)
if self.map_fn is not None:
return self.map_fn(rtn)
else:
# if slice, fetch strings with 1 diskread and then splice in memory
chr_lens = self.ends[index]
if index.start == 0 or index.start is None:
start = 0
else:
start = self.ends[index.start-1]
stop = chr_lens[-1]
strings = self.file_read(start, stop)
rtn = split_strings(strings, start, chr_lens)
if self.map_fn is not None:
return self.map_fn([s for s in rtn])
return rtn
def __len__(self):
return len(self.ends)
def file_read(self, start=0, end=None):
"""read specified portion of file"""
data_type_size = np.dtype(self.array_data_type).itemsize
# atomic reads to avoid race conditions with multiprocess dataloader
self.read_lock.acquire()
if not self.mem_map:
# seek to start of file read
if self.is_array:
start = start * data_type_size
end = end * data_type_size if end is not None else None
self.file.seek(start)
# read to end of file if no end point provided
if end is None:
rtn = self.file.read()
#else read amount needed to reach end point
else:
rtn = self.file.read(end-start)
if self.is_array:
rtn = np.ndarray(shape=(len(rtn) / data_type_size,), dtype=self.array_data_type, buffer=rtn, order='C')
else:
rtn = rtn.decode('utf-8', 'ignore')
else:
rtn = self.file[start:end]
if self.is_array:
rtn = rtn.copy()
else:
rtn = rtn.decode('utf-8', 'strict')
self.read_lock.release()
#TODO: @raulp figure out mem map byte string bug
#if mem map'd need to decode byte string to string
# # rtn = str(rtn)
# if self.mem_map:
# rtn = rtn.decode('unicode_escape')
return rtn
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""batch samplers that work with either random or sequential data samplers"""
import math
import os
import sys
import torch
from torch.utils import data
import numpy as np
class RandomSampler(data.sampler.Sampler):
r"""
Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
but this class lets the user set an epoch like DistributedSampler
Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``, default=False
"""
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1
if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class DistributedSequentialSampler(data.sampler.Sampler):
def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2):
super().__init__(num_samples)
if rank == -1:
rank = 0
world_size = 1
self.num_samples = num_samples
self.rank = rank
self.world_size = world_size
self.start_iter = 0
self.train_iters = train_iters
self.batch_size = batch_size
self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)]
def __iter__(self):
for idx in range(self.start_iter, self.train_iters * 10):
batch = [(idx + bias) % self.num_samples for bias in self.batch_bias]
tbatch = self._batch(batch)
yield tbatch
def __len__(self):
return self.train_iters
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank*self.batch_size//self.world_size
end = (self.rank+1)*self.batch_size//self.world_size
return batch[start:end]
class DistributedBatchSampler(data.sampler.BatchSampler):
"""
similar to normal implementation of distributed sampler, except implementation is at the
batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
"""
def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
if rank == -1:
assert False, 'should not be here'
self.rank = rank
self.world_size = world_size
self.sampler.wrap_around = 0
self.wrap_around = 0
self.wrap_last = wrap_last
self.start_iter = 0
self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps
def __iter__(self):
batch = []
i = 0
for idx in self.data_iterator(self.sampler, wrap_around=False):
batch.append(idx)
if len(batch) == self.batch_size:
tbatch = self._batch(batch)
if i >= self.start_iter * self.effective_batch_size:
yield tbatch
self.start_iter = 0
i += len(batch)
batch = []
batch_len = len(batch)
if batch_len > 0 and not self.drop_last:
if self.wrap_last:
self.sampler.wrap_around -= (self.batch_size)
self.wrap_around += (len(batch))
self.wrap_around %= self.batch_size
if isinstance(self.sampler, TransposedSampler):
for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
if i == 0:
continue
batch.append(idx)
new_batch_len = len(batch)
if len(batch) == self.batch_size:
break
yield self._batch(batch)
if self.wrap_last:
self.sampler.wrap_around += self.batch_size
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around%self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
self.wrap_around %= self.batch_size
yield idx
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank*self.batch_size//self.world_size
end = (self.rank+1)*self.batch_size//self.world_size
return batch[start:end]
"""
from https://github.com/openai/gpt-2/, changed for chinese
"""
import json
import os
import sentencepiece as spm
"""
SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation
systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements
subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the
extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end
system that does not depend on language-specific pre/postprocessing.
https://github.com/google/sentencepiece
pip install sentencepiece
or git clone https://github.com/google/sentencepiece.git
python setup.py install
"""
PRETRAINED_MODEL_FILE = "/home/zengzhongshen/desktop/transformer_xl_chinese/chinese_sentencepiece/cog-pretrain.model"
def get_pairs(word):
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.max_len = 0
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
return [self.encoder.get(token, 1) for token in self.tokenize(text)]
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
return text
def tokenize(self, text):
bpe_tokens = []
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
return [self.encoder.get(token, 1) for token in tokens]
class Encoder_SP:
def __init__(self, model_path):
self.sp = spm.SentencePieceProcessor()
self.sp.Load(model_path)
def encode(self, text):
"""
text="...."
"""
return self.sp.EncodeAsIds(text)
def decode(self, tokens):
"""
tokens=[x1,x2,...]
"""
text = [int(token) for token in tokens]
# print(text)
return self.sp.DecodeIds(text)
def tokenize(self, text):
return self.sp.EncodeAsPieces(text)
def convert_tokens_to_ids(self, tokens):
return [self.sp.PieceToId(token) for token in tokens]
def convert_token_to_id(self, token):
return self.sp.PieceToId(token)
def convert_id_to_token(self, idx):
return self.sp.IdToPiece(idx)
def get_encoder(encoder_file, bpe_file):
# 以下是为了同一个函数入兼容sentencepiece
filepath, filename = os.path.split(encoder_file)
shotname, extension = os.path.splitext(filename)
if (".model" == extension) and (bpe_file == ""):
return Encoder_SP(encoder_file)
else:
with open(encoder_file, 'r', encoding="utf-8") as f:
encoder = json.load(f)
with open(bpe_file, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)
def from_pretrained():
return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import queue
import threading
import tensorflow as tf
import utils
tf.enable_eager_execution()
import torch
import numpy as np
class TFRecordDataLoader(object):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False):
assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
utils.set_random_seed(seed)
if isinstance(records, str):
records = [records]
self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
"input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
"segment_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
"masked_lm_positions": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
"masked_lm_ids": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
"masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
"next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})
#Instantiate dataset according to original BERT implementation
if train:
self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
self.dataset = self.dataset.repeat()
self.dataset = self.dataset.shuffle(buffer_size=len(records))
# use sloppy tfrecord dataset
self.dataset = self.dataset.apply(
tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=train,
cycle_length=min(num_workers, len(records))))
self.dataset = self.dataset.shuffle(buffer_size=100)
else:
self.dataset = tf.data.TFRecordDataset(records)
self.dataset = self.dataset.repeat()
# Instantiate dataloader (do not drop remainder for eval)
loader_args = {'batch_size': batch_size,
'num_parallel_batches': num_workers,
'drop_remainder': train}
self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args))
self.threaded_dl = threaded_dl
self.num_workers = num_workers
def __iter__(self):
if self.threaded_dl:
data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
for item in data_iter:
yield item
else:
data_iter = iter(self.dataloader)
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)
class Record2Example(object):
def __init__(self, feature_map):
self.feature_map = feature_map
def __call__(self, record):
"""Decodes a BERT TF record to a TF example."""
example = tf.parse_single_example(record, self.feature_map)
for k, v in list(example.items()):
if v.dtype == tf.int64:
example[k] = tf.to_int32(v)
return example
def convert_tf_example_to_torch_tensors(example):
item = {k: (v.numpy()) for k,v in example.items()}
mask = np.zeros_like(item['input_ids'])
mask_labels = np.ones_like(item['input_ids'])*-1
for b, row in enumerate(item['masked_lm_positions'].astype(int)):
for i, idx in enumerate(row):
if item['masked_lm_weights'][b, i] != 0:
mask[b, idx] = 1
mask_labels[b, idx] = item['masked_lm_ids'][b, i]
output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k,v in output.items()}
class MultiprocessLoader(object):
def __init__(self, dataloader, num_workers=2):
self.dl = dataloader
self.queue_size = 2*num_workers
def __iter__(self):
output_queue = queue.Queue(self.queue_size)
output_thread = threading.Thread(target=_multiproc_iter,
args=(self.dl, output_queue))
output_thread.daemon = True
output_thread.start()
while output_thread.is_alive():
yield output_queue.get(block=True)
else:
print(RuntimeError('TF record data loader thread exited unexpectedly'))
def _multiproc_iter(dl, output_queue):
data_iter = iter(dl)
for item in data_iter:
tensors = convert_tf_example_to_torch_tensors(item)
output_queue.put(tensors, block=True)
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)"""
from collections import namedtuple
import random
import os
import csv
import torch
import nltk
from nltk import tokenize as nltk_tokenize
import sentencepiece as spm
from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_gpt2 import GPT2Tokenizer
from . import sp_tokenizer
import regex as re
def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe', pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
"""
Helper function to instantiate a tokenizer given common combinations of options.
"""
tokenizer_class = tokenizer_type
if isinstance(tokenizer_class, str):
tokenizer_class = eval(tokenizer_class)
if tokenizer_class is BertWordPieceTokenizer:
return BertWordPieceTokenizer(model_type, **kwargs)
elif tokenizer_class is GPT2BPETokenizer:
return GPT2BPETokenizer(**kwargs)
elif tokenizer_class is ChineseSPTokenizer:
return ChineseSPTokenizer(**kwargs)
text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type,
pad_token=pad_token, character_coverage=character_coverage)
return Tokenizer(text_tokenizer, command_tokens, type_tokens)
class Tokenization(object):
"""
Tokenization object to hold tokenization, (processed text),and original
text. Can hold tokenization as Ids or tokens.
It also holds command tokens (pad, unk, etc.) for the tokenization.
This allows functions to pad/operate on tokenizations without having
access to the full tokenizer, just the tokenization.
Several standard array operations are implemented (insert, append, extend).
"""
def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True):
self.tokenization = tokenization
self.text = text
if self.text is None:
self.text = self.tokenization
self.original_text = original_text
if self.original_text is None:
self.original_text = self.text
self.command_tokens = command_tokens
self.asIds = asIds
self.parse_command_tokens()
def set_command_tokens(self, command_tokens):
self.command_tokens = command_tokens
return self.parse_command_tokens()
def parse_command_tokens(self):
if self.command_tokens is None:
return
for command_token in self.command_tokens:
if self.asIds:
setattr(self, command_token.name, command_token.Id)
else:
setattr(self, command_token.name, command_token.token)
def __getitem__(self, index):
return self.tokenization[index]
def __len__(self):
return len(self.tokenization)
def insert(self, idx, other):
if isinstance(other, (CommandToken, TypeToken)):
self.tokenization.insert(idx, other.Id)
if idx == 0:
self.text = other.token + self.text
self.original_text = other.token + self.original_text
elif idx == len(self.tokenization)-1:
self.text += other.token
self.original_text += other.token
elif isinstance(other, Tokenization):
self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
else:
self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
def append(self, other):
if isinstance(other, (CommandToken, TypeToken)):
self.tokenization.append(other.Id)
self.text += other.token
self.original_text += other.token
elif isinstance(other, Tokenization):
self.tokenization.extend(other.tokenization)
self.text += other.text
self.original_text += other.original_text
else:
self.tokenization.append(other)
return self
def extend(self, other):
if isinstance(other, (CommandToken, TypeToken)):
self.tokenization.append(other.Id)
self.text += other.token
self.original_text += other.token
elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)):
self.tokenization.extend([o.Id for o in other])
self.text += [o.token for o in other]
self.original_text += [o.token for o in other]
elif isinstance(other, Tokenization):
self.tokenization.extend(other.tokenization)
self.text += other.text
self.original_text += other.original_text
else:
self.tokenization.extend(other)
return self
"""define some default command tokens for the tokenizer to use"""
token_format = "<{0}>"
COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
def prep_command_tokens(tokenlist, token_format=token_format):
return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class CommandToken(object):
def __init__(self, name, token, Id):
self.name = name
self.token = token
self.Id = Id
def __str__(self):
return str(COMMAND_TUPLE(self.name, self.token, self.Id))
DEFAULT_COMMAND_TOKENS = [
('pad', 0),
('eos', 1),
('bos', 2),
('unk', 3),
('sep', 4),
('L2R', 5),
('ENC', 6),
('MASK', 7),
]
DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
"""define some default type tokens for bert training"""
TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
def prep_type_tokens(tokenlist, token_format=token_format):
return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class TypeToken(object):
def __init__(self, name, token, Id):
self.name = name
self.token = token
self.Id = Id
def __str__(self):
return str(TYPE_TUPLE(self.name, self.token, self.Id))
DEFAULT_TYPE_TOKENS = [
('function', 0),
('command', 1),
('str0', 2),
('str1', 3),
('str2', 4),
('embedding0', 5),
('embedding1', 6),
('embedding2', 7),
('arg0', 8),
('arg1', 9),
('arg2', 10),
]
DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
class Tokenizer(object):
"""
Tokenizer object that handles text tokenization, command tokens, and type tokens.
Command tokens and text tokens are stored together in one mapping of size
`len(text_tokenizer)+len(command_tokens)`. Command tokens are stored as first
`len(command_tokens)` tokens. Token idx is stored at `idx+len(command_tokens)`.
Token types are stored in a separate mapping of size `len(type_tokens)`.
"""
def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
# set text tokenizer
self.text_tokenizer = text_tokenizer
if not hasattr(self, 'num_text_tokens'):
self.num_text_tokens = len(self.text_tokenizer)
# set command tokens
if command_tokens is None:
command_tokens = DEFAULT_COMMAND_TOKENS
self._command_tokens = command_tokens
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
if not hasattr(self, 'num_command_tokens'):
self.num_command_tokens = len(self._command_tokens)
if not hasattr(self, 'num_tokens'):
self.num_tokens = self.num_command_tokens + self.num_text_tokens
# set type tokens
if type_tokens is None:
type_tokens = DEFAULT_TYPE_TOKENS
self.type_tokens = type_tokens
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
if not hasattr(self, 'num_type_tokens'):
self.num_type_tokens = len(self.type_tokens)
# parse tokens and vocabs from tokenizer
self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
self._vocab = {t:Id for Id,t in self.command_id_map.items()}
self._vocab.update({t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()})
self._text_tokens = list(self.text_tokenizer.tokens)
self._text_token_vocab = {t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
def __call__(self, text, process_fn=None):
"""run preprocessing and encode text as Ids"""
return self.EncodeAsIds(text, process_fn=process_fn)
def __len__(self):
"""total number of tokens"""
return self.num_tokens
def get_command(self, name):
"""get command token corresponding to `name`"""
return self.command_name_map[name]
def get_type(self, name):
"""get type token corresponding to `name`"""
return self.type_name_map[name]
@property
def tokens(self):
"""list (or iterable) of all tokens for tokenizer"""
return self._tokens
@property
def vocab(self):
"""dictionary mapping tokens to ids for tokenizer"""
return self._vocab
@property
def token_types(self):
"""list (or iterable) of all token types for tokenizer"""
return self._token_types
@property
def token_type_vocab(self):
"""dictionary mapping token types to ids for tokenizer"""
return self._token_type_vocab
@property
def command_tokens(self):
"""list (or iterable) of all command tokens for tokenizer"""
return self._command_token_tokens
@property
def command_token_vocab(self):
"""dictionary mapping command tokens to ids for tokenizer"""
return self._command_token_vocab
@property
def text_tokens(self):
"""list (or iterable) of text tokens for text tokenizer"""
return self._text_tokens
@property
def text_token_vocab(self):
"""dictionary mapping text tokens to ids for text tokenizer"""
return self._text_token_vocab
def EncodeAsIds(self, text, process_fn=None):
"""
encode text using text tokenizer and shift Id values for command tokens
"""
tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn)
tokenization.tokenization = [t+self.num_command_tokens for t in tokenization.tokenization]
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
"""
encode text as tokens using text tokenizer
"""
tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def IdToToken(self, Id, type_token=False):
"""convert Id to token accounting for command and type tokens"""
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
if Id < self.num_command_tokens:
return self.command_id_map[Id].token
return self.text_tokenizer.IdToToken(Id-self.num_command_tokens)
def TokenToId(self, token, type_token=False):
"""convert token to Id accounting for command and type tokens"""
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
if token in self.command_token_map:
return self.command_token_map[token].Id
return self.text_tokenizer.TokenToId(token)+self.num_command_tokens
def DecodeIds(self, Ids, type_token=False):
"""
convert Ids to tokens accounting for command and type tokens, tokens
are joined and returned as a string.
"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
rtn_strs = []
current_str = []
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
for Id in Ids:
if isinstance(Id, CommandToken):
rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
current_str = []
rtn_strs.append(Id.token)
elif Id < self.num_command_tokens:
rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
current_str = []
rtn_strs.append(self.command_id_map[Id].token)
else:
current_str.append(Id - self.num_command_tokens)
if current_str != []:
rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
return ' '.join(rtn_strs)
def DecodeTokens(self, Tokens, type_token=False):
"""
convert tokens to a string accounting for command and type tokens.
"""
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
rtn_strs = []
current_str = []
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
for t in Tokens:
if isinstance(t, CommandToken):
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
current_str = []
rtn_strs.append(t.token)
elif t in self.command_token_map:
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
current_str = []
rtn_strs.append(t)
else:
current_str.append(t)
if current_str != []:
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
return ' '.join(rtn_strs)
class TextTokenizer(object):
"""
Interface for text tokenizer
"""
def __init__(self):
if not hasattr(self, 'num_text_tokens'):
self.num_text_tokens = 0
if not hasattr(self, 'num_tokens'):
self.num_tokens = self.num_text_tokens
def __call__(self, text, process_fn=None):
return self.EncodeAsIds(text, process_fn)
def __len__(self):
return self.num_text_tokens
@property
def tokens(self):
"""list (or iterable) of text tokens for text tokenizer"""
raise NotImplementedError('TextTokenizer tokens property not implemented')
@property
def vocab(self):
"""dictionary mapping tokens to ids"""
raise NotImplementedError('TextTokenizer vocab property not implemented')
@staticmethod
def exists(model_path):
"""check if the filepath for a text tokenizer exists"""
raise NotImplementedError('TextTokenizer exists method not implemented')
def Train(self, corpus):
"""train a tokenizer on a data corpus and save model for future use"""
raise NotImplementedError('TextTokenizer Train not implemented')
def EncodeAsIds(self, text, process_fn=None):
"""
Preprocess text and encode as ids. Return a tokenization object with
original text, processed text, and id tokenization.
"""
raise NotImplementedError('TextTokenizer EncodeAsIds not implemented')
def EncodeAsTokens(self, text, process_fn=None):
"""
Preprocess text and encode as tokens. Return a tokenization object with
original text, processed text, and token tokenization.
"""
raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented')
def IdToToken(self, Id):
"""Convert an Id to Token. Reverse lookup of self.vocab"""
raise NotImplementedError('TextTokenizer IdToToken not implemented')
def TokenToId(self, token):
"""Convert a Token to Id. Lookup of self.vocab"""
raise NotImplementedError('TextTokenizer TokenToId not implemented')
def DecodeIds(self, Ids):
"""Convert a list or tokenization object of Ids to a text string"""
raise NotImplementedError('TextTokenizer DecodeIds not implemented')
def DecodeTokens(self, Tokens):
"""Convert a list or tokenization object of tokens to a text string"""
raise NotImplementedError('TextTokenizer DecodeTokens not implemented')
class CharacterLevelTokenizer(TextTokenizer):
"""
Text tokenizer for ASCII-256 Character Level Tokenization.
"""
def __init__(self, **kwargs):
self.num_text_tokens = 256
super(CharacterLevelTokenizer, self).__init__()
self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
self._vocab = {t: i for i,t in enumerate(self._tokens)}
def __len__(self):
return 256
@staticmethod
def exists(model_path):
return True
def Train(self, corpus):
pass
@property
def tokens(self):
return self._tokens
@property
def vocab(self):
return self._vocab
def EncodeAsIds(self, text, process_fn=None):
"""convert text to ascii 256 Ids"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
processed_text = str(processed_text)
tokens = [self.TokenToId(c) for c in processed_text]
return Tokenization(tokens, processed_text, text)
def EncodeAsTokens(self, text, process_fn=None):
"""convert text to ascii 256 characters"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
processed_text = str(processed_text)
tokens = [c for c in processed_text]
return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id):
"""ascii index to character"""
return chr(Id)
def TokenToId(self, token):
"""ascii character to index"""
return ord(token)
def DecodeIds(self, Ids):
"""converts ascii ids to tokens before joining them into text"""
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return ''.join([self.IdToToken(tok) for tok in Ids])
def DecodeTokens(self, Tokens):
"""just concatenates ascii tokens into text"""
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return ''.join(Tokens)
MAX_SENTENCEPIECE_SENTENCES = 100000000
def get_corpus_freq(dataset, filepath, filetype='tsv'):
"""
Take corpus, split it into sentences, and extract word frequencies.
Write frequencies to `filepath` as a tsv. Only write the first
MAX_SENTENCEPIECE_SENTENCES most common words to the file.
"""
nltk.download('punkt', download_dir="./nltk")
if filetype == 'tsv':
delimiter = '\t'
else:
delimiter = ','
print("compute corpus frequency\n", flush=True)
total_sentence_count = 0
maxlen = 0
freqs = {}
for entry in dataset:
if isinstance(entry, dict):
entry = entry['text']
lines = entry.strip().split('\n')
for line in lines:
sentences = nltk_tokenize.sent_tokenize(line)
total_sentence_count += len(sentences)
for sentence in sentences:
maxlen = max(len(line), maxlen)
for word in sentence.split():
if word not in freqs:
freqs[word] = 0
freqs[word] += 1
print("length of freqs before truncating " + str(len(freqs)), flush=True)
print("file path for freq " + str(filepath), flush=True)
freqs_sorted = {}
counter=0
for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
if counter >= MAX_SENTENCEPIECE_SENTENCES:
break
counter+=1
freqs_sorted[word] = count
print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)
with open(filepath, 'w') as f:
writer = csv.writer(f, delimiter=delimiter)
for k, v in freqs_sorted.items():
writer.writerow([str(k), str(v)])
return total_sentence_count, maxlen
class SentencePieceTokenizer(TextTokenizer):
"""Trains and uses sentencepiece for text tokenization"""
def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0, **kwargs):
self.character_coverage = character_coverage
self.model_type = model_type.lower()
self.spm_model = model_path
self.num_text_tokens = vocab_size
make_train = not SentencePieceTokenizer.exists(self.spm_model)
if make_train:
assert corpus is not None and self.num_text_tokens is not None
self.Train(corpus, self.num_text_tokens)
self._tokens = []
self._vocab = {}
self.load_spm_model()
super(SentencePieceTokenizer, self).__init__()
def __len__(self):
return self.num_text_tokens
@property
def tokens(self):
return self._tokens
@property
def vocab(self):
return self._vocab
@staticmethod
def exists(model_path):
if model_path is None:
return False
# check if path exists
dne = not os.path.exists(model_path)
# check if path.model exists
if dne and not model_path.endswith('.model'):
dne = not os.path.exists(model_path+'.model')
return not dne
def load_spm_model(self):
"""load sentencepiece model and parse vocab"""
if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
self.spm_model = self.spm_model+'.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(self.spm_model)
self.vocab_size = self.num_text_tokens = len(self.sp)
self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
self._vocab = {t: i for i,t in enumerate(self._tokens)}
def Train(self, corpus, num_text_tokens):
"""train sentencepiece model on corpus using word frequencies"""
self.num_text_tokens = num_text_tokens
use_model_path = self.spm_model
random_hash = str(random.randint(0, 2147483647))
if use_model_path is None:
use_model_path = random_hash
if use_model_path.endswith('.model'):
use_model_path = use_model_path[:use_model_path.rfind('.model')]
input_path = use_model_path+'.tsv.'+random_hash
line_count, maxlenline = get_corpus_freq(corpus, input_path)
line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
print('line count used as input_sentence_size ', line_count, flush=True)
print('training sentencepiece model', flush=True)
train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \
+ ' --model_type={model_type} --character_coverage={character_coverage} ' \
+ '--input_sentence_size={input_sentence_size} ' \
+ '--input_format=tsv'
train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens,
model_type=self.model_type, character_coverage=self.character_coverage,
input_sentence_size=int(line_count)) #, #)#,
print("calling spm.SentencePieceTrainer.Train(%s)"%(train_string), flush=True)
spm.SentencePieceTrainer.Train(train_string)
os.remove(input_path)
self.spm_model = use_model_path+'.model'
print('sentencepiece model written to '+self.spm_model, flush=True)
def EncodeAsIds(self, text, process_fn=None):
"""convert text to sentencepiece Ids"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.sp.EncodeAsIds(processed_text)
return Tokenization(tokens, processed_text, text)
def EncodeAsTokens(self, text, process_fn=None):
"""convert text to sentencepiece tokens"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.sp.EncodeAsTokens(processed_text)
return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id):
"""convert Id to sentencpiece token"""
return self.sp.IdToPiece(Id)
def TokenToId(self, token):
"""convert sentencpiece token to Id"""
return self.sp.PieceToId(token)
def DecodeIds(self, Ids):
"""converts ids to a text string"""
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return self.sp.DecodeIds(Ids)
def DecodeTokens(self, Tokens):
"""converts sentencepiece tokens to a text string"""
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.sp.DecodeTokens(Tokens)
class BertWordPieceTokenizer(Tokenizer):
"""
Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
in BERT training. Default to bert-large-uncased tokenizer.
"""
def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs):
# default to bert-large-uncased tokenizer
if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
tokenizer_model_type = 'bert-large-uncased'
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir)
do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print('loaded', tokenizer_model_type)
# disable max len warnings by increasing max len
self.text_tokenizer.max_len = int(1e12)
# set command tokens from wordpiece tokenizer values
self.num_command_tokens = 5
self.num_tokens = len(self.text_tokenizer.vocab)
self.num_text_tokens = self.num_tokens-5
self.num_type_tokens = 2
self._command_tokens = [
CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']),
CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']),
CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']),
CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']),
]
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
# set type tokens
self.type_tokens = [
TypeToken('str0', '<str0>', 0),
TypeToken('str1', '<str1>', 1),
]
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
# parse tokens and vocabs from tokenizer
self._tokens = list(self.text_tokenizer.vocab.keys())
self._vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
"""convert text to wordpiece Ids"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.text_tokenizer.tokenize(processed_text)
Ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
return Tokenization(Ids, processed_text, text)
def EncodeAsTokens(self, text, process_fn=None):
"""convert wordpiece token to Id"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.text_tokenizer.tokenize(processed_text)
return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
"""convert Id to sentencpiece token"""
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
return self.text_tokenizer.ids_to_tokens[Id]
def TokenToId(self, token, type_token=False):
"""convert sentencpiece token to Id"""
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
return self.text_tokenizer.vocab[token]
def DecodeIds(self, Ids, type_token=False):
"""converts ids to wordpiece tokens and joins them as a text string"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
Tokens = []
for Id in Ids:
Tokens.append(self.text_tokenizer.ids_to_tokens[Id] if Id != -1 else '-1')
Tokens = self.text_tokenizer.convert_ids_to_tokens(Ids)
return ' '.join(Tokens)
def DecodeTokens(self, Tokens, type_token=False):
"""converts wordpiece tokens to a text string"""
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, add_eop=False, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
cache_dir=cache_dir)
#disable max len warnings by increasing max len
self.text_tokenizer.max_len = int(1e12)
self.num_command_tokens = 2
self.num_tokens = len(self.text_tokenizer.encoder)
self.num_text_tokens = self.num_tokens - 1
self.num_type_tokens = 2
self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>'])
]
if add_eop:
self._command_tokens.append(CommandToken('eop', '<|endofpiece|>', self.num_tokens))
self.num_tokens += 1
self.num_command_tokens += 1
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
self.type_tokens = [
TypeToken('str0', '<str0>', 0),
TypeToken('str1', '<str1>', 1),
]
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
self._tokens = list(self.text_tokenizer.encoder.keys())
self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
Ids = self.text_tokenizer.encode(processed_text)
#return Tokenization(Ids, processed_text, text)
tokenization = Tokenization(Ids, processed_text, text)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = []
for token in re.findall(self.text_tokenizer.pat, processed_text):
token = ''.join(self.text_tokenizer.bye_encoder[b] for b in token.encode('utf-8'))
tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
tokenization=Tokenization(tokens, processed_text, text, asIds=False)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
#return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
return self.text_tokenizer.decoder[Id]
def TokenToId(self, token, type_token=False):
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
return self.text_tokenizer.encoder[token]
def DecodeIds(self, Ids, type_token=False):
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return self.text_tokenizer.decode(Ids)
def DecodeTokens(self, Tokens, type_token=False):
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
class ChineseSPTokenizer(Tokenizer):
def __init__(self, add_eop=False, **kwargs):
self.text_tokenizer = sp_tokenizer.from_pretrained()
self.num_command_tokens = 2
self.num_text_tokens = self.text_tokenizer.sp.vocab_size()
self.num_tokens = self.num_text_tokens + 1
self.num_type_tokens = 2
self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
]
if add_eop:
self._command_tokens.append(CommandToken('eop', '<|endofpiece|>', self.num_text_tokens + 1))
self.num_tokens += 1
self.num_command_tokens += 1
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
self.type_tokens = [
TypeToken('str0', '<str0>', 0),
TypeToken('str1', '<str1>', 1),
]
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
# self._tokens = list(self.text_tokenizer.encoder.keys())
# self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
#
# self._text_tokens = list(self._tokens)
# self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
Ids = self.text_tokenizer.encode(processed_text)
#return Tokenization(Ids, processed_text, text)
tokenization = Tokenization(Ids, processed_text, text)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.text_tokenizer.tokenize(processed_text)
tokenization = Tokenization(tokens, processed_text, text, asIds=False)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
#return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
if Id in self.command_id_map:
return self.command_id_map[Id].token
elif Id in self.type_id_map:
return self.type_id_map[Id].token
else:
return self.text_tokenizer.convert_id_to_token(Id)
def TokenToId(self, token, type_token=False):
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
return self.text_tokenizer.convert_token_to_id(token)
def DecodeIds(self, Ids, type_token=False):
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
try:
first_eos = Ids.index(self.get_command('eos').Id)
eos_count = len(Ids) - first_eos
Ids = Ids[:first_eos]
except ValueError:
eos_count = 0
return " ".join((self.text_tokenizer.decode(Ids), *(['<|endoftext|>']*eos_count)))
def DecodeTokens(self, Tokens, type_token=False):
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
\ No newline at end of file
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
# 'gpt2': "/workspace/.pytorch_pretrained_bert/gpt2-vocab.json",
'gpt2': ".pytorch_pretrained_bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
# 'gpt2': "/workspace/.pytorch_pretrained_bert/gpt2-merges.txt",
'gpt2': ".pytorch_pretrained_bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
# try:
# resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
# resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
# except EnvironmentError:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find files {} and {} "
# "at this path or url.".format(
# pretrained_model_name_or_path,
# ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
# pretrained_model_name_or_path,
# vocab_file, merges_file))
# return None
# if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
# logger.info("loading vocabulary file {}".format(vocab_file))
# logger.info("loading merges file {}".format(merges_file))
# else:
# logger.info("loading vocabulary file {} from cache at {}".format(
# vocab_file, resolved_vocab_file))
# logger.info("loading merges file {} from cache at {}".format(
# merges_file, resolved_merges_file))
resolved_vocab_file = vocab_file
resolved_merges_file = merges_file
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py"""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import logging
import os
import unicodedata
from io import open
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': ".pytorch_pretrained_bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': ".pytorch_pretrained_bert/bert-large-uncased-vocab.txt",
'bert-base-cased': ".pytorch_pretrained_bert/bert-base-cased-vocab.txt",
'bert-large-cased': ".pytorch_pretrained_bert/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
'bert-large-cased': 512,
'bert-base-multilingual-uncased': 512,
'bert-base-multilingual-cased': 512,
'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BertTokenizer.
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
"""
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
if self.do_basic_tokenize:
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self,
do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fp16util import (
BN_convert_float,
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
master_params_to_model_params,
tofp16,
to_python_float,
clip_grad_norm,
convert_module,
convert_network,
FP16Model,
)
from .fp16 import *
from .loss_scaler import *
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stable version of apex FP16 Optimizer"""
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, FLOAT_TYPES):
val = val.half()
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, HALF_TYPES):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class FP16_Module(nn.Module):
def __init__(self, module):
super(FP16_Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs):
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
and changing the call to ``backward``.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Name the FP16_Optimizer instance to replace the existing optimizer
# (recommended but not required):
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...
Example with dynamic loss scaling::
...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# optional arg to control dynamic loss scaling behavior
# dynamic_loss_args={'scale_window' : 500})
# Usually, dynamic_loss_args is not necessary.
Args:
init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.
``init_optimizer`` is expected to have been constructed in the ordinary way.
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
named to replace ``init_optimizer``, for two reasons:
First, it means that references to the same name
later in the file will not have to change.
Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
modify ``init_optimizer``. If you do choose a unique name for the new
:class:`FP16_Optimizer` instance, you should only work with this new instance,
because the preexisting optimizer might no longer behave as expected.
``init_optimizer`` may be any Pytorch optimizer.
It may contain a mixture of fp16 and fp32 parameters organized into any number of
``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
ingest these ``param_groups`` and remember them.
Calls to ::
loss.backward()
must be replaced with ::
optimizer.backward(loss)
because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
loss scaling and copies to master gradients.
.. note::
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
are downscaled before being applied. This means that adjusting the loss scale, or using
dynamic loss scaling, should not require retuning the learning rate or any other
hyperparameters.
**Advanced options**
**Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
See docstring for :attr:`step`.
**Gradient clipping**: Use :attr:`clip_master_grads`.
**Multiple losses**: If your model accumulates gradients from multiple losses,
this can be made more efficient by supplying ``update_master_grads=False``
to :attr:`backward`. See docstring for :attr:`backward`.
**Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::
print(optimizer.loss_scale)
optimizer.loss_scale = new_loss_scale
For static loss scaling, manually adjusting the loss scale over time is a reasonable
thing to do. During later epochs, gradients may become smaller, and a
higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
the loss scale is not recommended.
**Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
should still work as intended.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=False):
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
self.verbose = verbose
self.optimizer = init_optimizer
# init_state_dict sets up an alternative way to cast per-param state tensors.
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
# init_state_dict = init_optimizer.state_dict()
self.fp16_groups = []
self.fp32_from_fp16_groups = []
self.fp32_from_fp32_groups = []
for i, param_group in enumerate(self.optimizer.param_groups):
self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
.format(param.size()))
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
# Copythe model parallel flag.
master_param.model_parallel = param.model_parallel
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
# We still need to recast per-param state tensors, if any, to FP32.
if param in self.optimizer.state:
self.optimizer.state[master_param] = self.optimizer.state.pop(param)
elif param.type() == 'torch.cuda.FloatTensor':
self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
.format(param.size()))
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Wrapped parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
self.fp16_groups.append(fp16_params_this_group)
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# alternative way to cast per-param state tensors:
# self.optimizer.load_state_dict(init_state_dict)
if dynamic_loss_scale:
self.dynamic_loss_scale = True
if dynamic_loss_args is not None:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
else:
self.loss_scaler = DynamicLossScaler()
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(static_loss_scale)
self.overflow = False
self.first_closure_call_this_step = True
self.clip_grad_norm = clip_grad_norm
def maybe_print(self, msg):
if self.verbose:
print(msg)
def __getstate__(self):
raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
def __setstate__(self, state):
raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
def zero_grad(self, set_grads_to_None=False):
"""
Zero fp32 and fp16 parameter grads.
"""
# In principle, only the .grad attributes of the model params need to be zeroed,
# because gradients are copied into the FP32 master params. However, we zero
# all gradients owned by the optimizer, just to be safe:
for group in self.optimizer.param_groups:
for p in group['params']:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
# Zero fp16 gradients owned by the model:
for fp16_group in self.fp16_groups:
for param in fp16_group:
if set_grads_to_None:
param.grad = None
else:
if param.grad is not None:
param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
param.grad.zero_()
def _check_overflow(self):
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
for group in self.fp32_from_fp32_groups:
for param in group:
params.append(param)
self.overflow = self.loss_scaler.has_overflow(params)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
def _master_params_to_model_params(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
def _model_params_to_master_params(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
master_params_to_model_params(fp32_from_fp16_group, fp16_group)
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
def _model_grads_to_master_grads(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
def _downscale_master(self):
if self.loss_scale != 1.0:
for group in self.optimizer.param_groups:
for param in group['params']:
if param.grad is not None:
param.grad.data.mul_(1./self.loss_scale)
def clip_master_grads(self, max_norm, norm_type=2):
"""
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
Args:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the current fp32 gradients (viewed as a single vector).
.. warning::
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
"""
if not self.overflow:
fp32_params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
fp32_params.append(param)
return self.clip_grad_norm(fp32_params, max_norm, norm_type)
else:
return -1
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)
def step(self, closure=None): # could add clip option.
"""
If no closure is supplied, :attr:`step` should be called after
``fp16_optimizer_obj.backward(loss)``.
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
another forward pass using their model.
If a closure is supplied, :attr:`step` may be called without a prior call to
:attr:`backward(loss)`.
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
However, the user should take care that any ``loss.backward()`` call within the closure
has been replaced by ``fp16_optimizer_obj.backward(loss)``.
Args:
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.
Example with closure::
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
# existing pytorch optimizer.
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# loss.backward() becomes:
optimizer.backward(loss)
return loss
optimizer.step(closure)
.. warning::
Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.
.. _`ordinary Pytorch optimizer use`:
http://pytorch.org/docs/master/optim.html#optimizer-step-closure
"""
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}"
.format(scale, self.loss_scale))
return
if closure is not None:
retval = self._step_with_closure(closure)
else:
retval = self.optimizer.step()
self._master_params_to_model_params()
return retval
def _step_with_closure(self, closure):
def wrapped_closure():
# helpful for debugging
# print("Calling wrapped_closure, first_closure_call_this_step = {}"
# .format(self.first_closure_call_this_step))
if self.first_closure_call_this_step:
# We expect that the fp16 params are initially fresh on entering self.step(),
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
# is called within self.optimizer.step().
self.first_closure_call_this_step = False
else:
# If self.optimizer.step() internally calls wrapped_closure more than once,
# it may update the fp32 params after each call. However, self.optimizer
# doesn't know about the fp16 params at all. If the fp32 params get updated,
# we can't rely on self.optimizer to refresh the fp16 params. We need
# to handle that manually:
self._master_params_to_model_params()
# Our API expects the user to give us ownership of the backward() call by
# replacing all calls to loss.backward() with optimizer.backward(loss).
# This requirement holds whether or not the call to backward() is made within a closure.
# If the user is properly calling optimizer.backward(loss) within "closure,"
# calling closure() here will give the fp32 master params fresh gradients
# for the optimizer to play with, so all wrapped_closure needs to do is call
# closure() and return the loss.
temp_loss = closure()
while(self.overflow):
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(scale, self.loss_scale))
temp_loss = closure()
return temp_loss
retval = self.optimizer.step(wrapped_closure)
self.first_closure_call_this_step = True
return retval
def backward(self, loss, update_master_grads=True, retain_graph=False):
"""
:attr:`backward` performs the following conceptual steps:
1. fp32_loss = loss.float() (see first Note below)
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
5. Finally, master grads are divided by loss_scale.
In this way, after :attr:`backward`, the master params have fresh gradients,
and :attr:`step` may be called.
.. note::
:attr:`backward` internally converts the loss to fp32 before applying the loss scale.
This provides some additional safety against overflow if the user has supplied an
fp16 loss value.
However, for maximum overflow safety, the user should
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
:attr:`backward`.
.. warning::
The gradients found in a model's leaves after the call to
:attr:`backward` should not be regarded as valid in general,
because it's possible
they have been scaled (and in the case of dynamic loss scaling,
the scale factor may change over time).
If the user wants to inspect gradients after a call to :attr:`backward`,
only the master gradients should be regarded as valid. These can be retrieved via
:attr:`inspect_master_grad_data()`.
Args:
loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
Example::
# Ordinary operation:
optimizer.backward(loss)
# Naive operation with multiple losses (technically valid, but less efficient):
# fp32 grads will be correct after the second call, but
# the first call incurs an unnecessary fp16->fp32 grad copy.
optimizer.backward(loss1)
optimizer.backward(loss2)
# More efficient way to handle multiple losses:
# The fp16->fp32 grad copy is delayed until fp16 grads from all
# losses have been accumulated.
optimizer.backward(loss1, update_master_grads=False)
optimizer.backward(loss2, update_master_grads=False)
optimizer.update_master_grads()
"""
# To consider: try multiple backward passes using retain_grad=True to find
# a loss scale that works. After you find a loss scale that works, do a final dummy
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid
# discarding the iteration, but probably wouldn't improve overall efficiency.
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
if update_master_grads:
self.update_master_grads()
def update_master_grads(self):
"""
Copy the ``.grad`` attribute from stored references to fp16 parameters to
the ``.grad`` attribute of the fp32 master parameters that are directly
updated by the optimizer. :attr:`update_master_grads` only needs to be called if
``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow: return
self._model_grads_to_master_grads()
self._downscale_master()
def inspect_master_grad_data(self):
"""
When running with :class:`FP16_Optimizer`,
``.grad`` attributes of a model's fp16 leaves should not be
regarded as truthful, because they might be scaled.
After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
the fp32 master params' ``.grad``
attributes will contain valid gradients properly divided by the loss scale. However,
because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
nonintuitive. :attr:`inspect_master_grad_data`
allows those gradients to be viewed with shapes corresponding to their associated model leaves.
Returns:
List of lists (one list for each parameter group). The list for each parameter group
is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
"""
if self.overflow:
print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
"Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
return None
else:
# The optimizer owns only references to master params.
master_grads_data = []
for param_group in self.optimizer.param_groups:
master_grads_this_group = []
for param in param_group['params']:
if param.grad is not None:
master_grads_this_group.append(param.grad.data)
else:
master_grads_this_group.append(None)
master_grads_data.append(master_grads_this_group)
return master_grads_data
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import mpu
class tofp16(nn.Module):
"""
Utility module that implements::
def forward(self, input):
return input.half()
"""
def __init__(self):
super(tofp16, self).__init__()
def forward(self, input):
return input.half()
def BN_convert_float(module):
"""
Utility function for network_to_half().
Retained for legacy purposes.
"""
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
module.float()
for child in module.children():
BN_convert_float(child)
return module
def network_to_half(network):
"""
Convert model to half precision in a batchnorm-safe way.
Retained for legacy purposes. It is recommended to use FP16Model.
"""
return nn.Sequential(tofp16(), BN_convert_float(network.half()))
def convert_module(module, dtype):
"""
Converts a module's immediate parameters and buffers to dtype.
"""
for param in module.parameters(recurse=False):
if param is not None:
if param.data.dtype.is_floating_point:
param.data = param.data.to(dtype=dtype)
if param._grad is not None and param._grad.data.dtype.is_floating_point:
param._grad.data = param._grad.data.to(dtype=dtype)
for buf in module.buffers(recurse=False):
if buf is not None and buf.data.dtype.is_floating_point:
buf.data = buf.data.to(dtype=dtype)
def convert_network(network, dtype):
"""
Converts a network's parameters and buffers to dtype.
"""
for module in network.modules():
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
continue
convert_module(module, dtype)
return network
class FP16Model(nn.Module):
"""
Convert model to half precision in a batchnorm-safe way.
"""
def __init__(self, network):
super(FP16Model, self).__init__()
self.network = convert_network(network, dtype=torch.half)
def forward(self, *inputs):
inputs = tuple(t.half() for t in inputs)
return self.network(*inputs)
def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
`Training Neural Networks with Mixed Precision: Real Examples`_.
Args:
model (torch.nn.Module): Existing Pytorch model
flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
Returns:
A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element.
Example::
model_params, master_params = prep_param_lists(model)
.. warning::
Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
.. _`Training Neural Networks with Mixed Precision: Real Examples`:
http://on-demand.gputechconf.com/gtc/2018/video/S81012/
"""
model_params = [param for param in model.parameters() if param.requires_grad]
if flat_master:
# Give the user some more useful error messages
try:
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except:
print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.")
raise
master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True
# master_params.register_hook(backwards_debug_hook)
if master_params.grad is None:
master_params.grad = master_params.new(*master_params.size())
return model_params, [master_params]
else:
master_params = [param.clone().float().detach() for param in model_params]
for param in master_params:
param.requires_grad = True
return model_params, master_params
def model_grads_to_master_grads(model_params, master_params, flat_master=False):
"""
Copy model gradients to master gradients.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
"""
if flat_master:
# The flattening may incur one more deep copy than is necessary.
master_params[0].grad.data.copy_(
_flatten_dense_tensors([p.grad.data for p in model_params]))
else:
for model, master in zip(model_params, master_params):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else:
master.grad = None
def master_params_to_model_params(model_params, master_params, flat_master=False):
"""
Copy master parameters to model parameters.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
"""
if flat_master:
for model, master in zip(model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master)
else:
for model, master in zip(model_params, master_params):
model.data.copy_(master.data)
# Backward compatibility fixes
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm
#else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.
Args:
scale (float, optional, default=1.0): The loss scale.
"""
def __init__(self, scale=1):
self.cur_scale = scale
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
return False
def update_scale(self, overflow):
pass
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
"""
def __init__(self,
init_scale=2**32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
self.cur_scale = init_scale
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
return True
return False
def has_overflow(self, params):
overflow = self.has_overflow_serial(params)
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
overflow_gpu = torch.cuda.ByteTensor([overflow])
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=mpu.get_model_parallel_group())
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
return True
return False
# `overflow` is boolean indicating whether the gradient overflowed
def update_scale(self, overflow):
if not hasattr(self, 'min_scale'):
self.min_scale = 1
if not hasattr(self, 'delayed_shift'):
self.delayed_shift = 1
if not hasattr(self, 'cur_hysteresis'):
self.cur_hysteresis = 1
if not hasattr(self, 'consecutive_hysteresis'):
self.consecutive_hysteresis = True
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
else:
if self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
if not self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
self.cur_scale *= self.scale_factor
self.cur_iter += 1
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
import torch
from torch.autograd import Variable
from dynamic_loss_scaler import DynamicLossScaler
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in), requires_grad=False)
y = Variable(torch.randn(N, D_out), requires_grad=False)
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)
parameters = [w1, w2]
learning_rate = 1e-6
optimizer = torch.optim.SGD(parameters, lr=learning_rate)
loss_scaler = DynamicLossScaler()
for t in range(500):
y_pred = x.mm(w1).clamp(min=0).mm(w2)
loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
param.grad.data.mul_(1. / loss_scaler.loss_scale)
optimizer.step()
# Otherwise, don't do anything -- ie, skip iteration
else:
print('OVERFLOW!')
# Update loss scale for next iteration
loss_scaler.update_scale(has_overflow)
"""
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Generate GPT2"""
import os
import nltk
import random
import numpy as np
import torch
import torch.nn.functional as F
import argparse
import time
from datetime import datetime
from arguments import get_args
from utils import Timers, set_random_seed
from utils import load_checkpoint, get_checkpoint_iteration
from data_utils import make_tokenizer
from configure_data import configure_data
import mpu
from pretrain_gpt2 import set_deepspeed_activation_checkpointing
from fp16 import FP16_Module
from model import GPT2Model
from utils import print_rank_0
import deepspeed
USE_TORCH_DDP = True
import pickle
def get_model(args):
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
max_memory_length=args.mem_length,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
parallel_output=False,
relative_encoding=args.transformer_xl)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# To prevent OOM for model sizes that cannot fit in GPU memory in full precision
if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
model.half()
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
# Wrap model for distributed training.
if USE_TORCH_DDP:
from model import PyTorchDistributedDataParallel as DDP
i = torch.cuda.current_device()
model = DDP(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
else:
from model import DistributedDataParallel as DDP
model = DDP(model)
return model
def get_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask,
loss_mask=None,
attention_mask=None,
transformer_xl=False,
mem_length=None):
# Extract batch size and sequence length.
batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if transformer_xl:
if attention_mask is None:
attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
else:
if reset_attention_mask:
att_mask_batch = batch_size
else:
att_mask_batch = 1
if attention_mask is None:
attention_mask = torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
attention_mask = torch.tril(attention_mask)
attention_mask = attention_mask.unsqueeze(1)
# Loss mask.
if loss_mask is None:
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
if not transformer_xl:
loss_mask[data == eod_token] = 0.0
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
return attention_mask, loss_mask, position_ids
def initialize_distributed(args):
"""Initialize torch.distributed."""
# Manually set the device ids.
print(f"args.rank: {args.rank}")
device = args.rank % torch.cuda.device_count()
if args.local_rank is not None:
device = args.local_rank
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = args.MASTER_PORT
init_method += master_ip + ':' + master_port
print(f"master port : {master_port}, init_method: {init_method}")
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
# Optional DeepSpeed Activation Checkpointing Features
#
if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
set_deepspeed_activation_checkpointing(args)
def setup_model(args):
"""Setup model and optimizer."""
model = get_model(args)
if args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
model, _, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
args=args,
mpu=mpu,
dist_init_required=False
)
if args.load is not None:
_ = load_checkpoint(
model, None, None, args, load_optimizer_states=False)
# if args.deepspeed:
# model = model.module
return model
def get_batch(context_tokens, device, args):
tokens = context_tokens
tokens = tokens.view(args.batch_size, -1).contiguous()
tokens = tokens.to(device)
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
tokens,
args.eod_token,
reset_position_ids=False,
reset_attention_mask=False,
transformer_xl=args.transformer_xl,
mem_length=args.mem_length)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# convert to 1D
logits = logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
# going back to 2D
logits = logits.view(1, -1).contiguous()
return logits
def sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device, mems=None, end_token=None):
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
counter = 0
if mems is None:
mems = []
if end_token is None:
end_token = args.eod_token
org_context_length = context_length
while counter < (args.out_seq_length - org_context_length):
if counter == 0:
logits, *mems = model(tokens, position_ids, attention_mask, *mems)
else:
index = org_context_length + counter
logits, *mems = model(tokens[:, index - 1: index], tokens.new_ones((1, 1)) * (index - 1),
tokens.new_ones(1, 1, 1, args.mem_length + 1, device=tokens.device,
dtype=torch.float), *mems)
logits = logits[:, -1]
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)[0]
is_end = prev == end_token
if is_end:
break
tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
context_length += 1
counter += 1
if mpu.get_model_parallel_rank() == 0 and counter % 16 == 0:
output_tokens_list = tokens.view(-1).contiguous()
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0 or is_end):
os.system('clear')
trim_decode_tokens = decode_tokens
print(trim_decode_tokens, flush=True)
output_tokens_list = tokens.view(-1).contiguous()
return output_tokens_list, mems
def read_context(tokenizer, args, output, raw_text=None):
terminate_runs, skip_run = 0, 0
if mpu.get_model_parallel_rank() == 0:
while True:
if not raw_text:
raw_text = input("\nContext prompt (stop to exit) >>> ")
if not raw_text:
print('Prompt should not be empty!')
continue
if raw_text == "stop":
terminate_runs = 1
break
# output.write(raw_text)
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
context_length = len(context_tokens)
if context_length >= args.seq_length:
print("\nContext length", context_length,
"\nPlease give smaller context than the window length!")
continue
break
else:
context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return terminate_runs, None, None, None
context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
context_length = context_length_tensor[0].item()
if mpu.get_model_parallel_rank() == 0:
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
else:
context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
if mpu.get_model_parallel_rank() != 0:
raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
return terminate_runs, raw_text, context_tokens_tensor, context_length
def generate_samples(model, tokenizer, args, device):
model.eval()
output_path = "./samples"
if not os.path.exists(output_path):
os.makedirs(output_path)
output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
with torch.no_grad(), open(output_path, "w") as output:
input_file = None
if args.test_data_path is not None:
print_rank_0(args.test_data_path)
input_file = open(args.test_data_path)
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
raw_text = None
if input_file is not None:
raw_text = input_file.readline().strip()
if not raw_text:
return
terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output,
raw_text=raw_text)
if terminate_runs == 1:
return
start_time = time.time()
output_tokens_list, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args,
device)
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
trim_decode_tokens = decode_tokens[len(raw_text):]
print("\nGPT2:", trim_decode_tokens, flush=True)
output.write(trim_decode_tokens + "\n")
torch.distributed.barrier(group=mpu.get_model_parallel_group())
def prepare_tokenizer(args):
tokenizer_args = {
'tokenizer_type': args.tokenizer_type,
'corpus': None,
'model_path': args.tokenizer_path,
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir,
'add_eop': args.hierarchical}
tokenizer = make_tokenizer(**tokenizer_args)
num_tokens = tokenizer.num_tokens
before = num_tokens
after = before
multiple = args.make_vocab_size_divisible_by
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
before, after - before, after))
args.tokenizer_num_tokens = after
args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
args.eod_token = tokenizer.get_command('eos').Id
# after = tokenizer.num_tokens
# while after % mpu.get_model_parallel_world_size() != 0:
# after += 1
args.vocab_size = after
print("prepare tokenizer done", flush=True)
return tokenizer
def main():
"""Main training program."""
print('Generate Samples')
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Arguments.
args = get_args()
args.deepspeed = False
args.mem_length = args.seq_length + args.mem_length - 1
# with open("./checkpoint_model/args.pkl", "wb") as file:
# pickle.dump(args, file=file)
# Pytorch distributed.
initialize_distributed(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
# get the tokenizer
tokenizer = prepare_tokenizer(args)
# Model, optimizer, and learning rate.
model = setup_model(args)
# setting default batch size to 1
args.batch_size = 1
# generate samples
generate_samples(model, tokenizer, args, torch.cuda.current_device())
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .distributed import *
from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization
from .gpt2_modeling import GPT2Model
from .model import BertModel
from .model import get_params_for_weight_decay_optimization
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
import mpu
class PyTorchDistributedDataParallel(DDP):
def state_dict(self, destination=None, prefix='', keep_vars=False):
sd = self.module.state_dict(destination, prefix, keep_vars)
return sd
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
class DistributedDataParallel(Module):
def __init__(self, module):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
src_rank = mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction):
self.needs_reduction = False
buckets = {}
for name, param in self.module.named_parameters():
if param.requires_grad and param.grad is not None:
tp = (param.data.type())
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
self.hook_handles = []
self.hooks = []
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs)
def state_dict(self, destination=None, prefix='', keep_vars=False):
#[h.remove() for h in self.hook_handles]
sd = self.module.state_dict(destination, prefix, keep_vars)
# for handle, hook in zip(self.hook_handles, self.hooks):
# d = handle.hooks_dict_ref()
# d[handle.id] = hook
return sd
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
flat_buffers = _flatten_dense_tensors(buffers)
dist.broadcast(flat_buffers, 0)
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
buf.copy_(synced)
def train(self, mode=True):
# Clear NCCL communicator and CUDA event cache of the default group ID,
# These cache will be recreated at the later call. This is currently a
# work-around for a potential NCCL deadlock.
if dist._backend == dist.dist_backend.NCCL:
dist._clear_group_cache()
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import torch
import torch.nn.functional as F
import mpu
def init_method_normal(std=0.02):
"""Init method based on normal distribution.
This is only used for embeddings. The transformer has its
own initializer.
"""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
class GPT2Model(torch.nn.Module):
"""GPT-2 Language model.
The output of the forward method are the logits (parallel or
serial depending on the `parallel_output` flag.
"""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
max_memory_length,
checkpoint_activations,
checkpoint_num_layers=1,
parallel_output=True,
relative_encoding=False):
super(GPT2Model, self).__init__()
self.parallel_output = parallel_output
init_method = init_method_normal(std=0.02)
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init_method)
# Transformer
self.transformer = mpu.GPT2ParallelTransformer(num_layers,
hidden_size,
num_attention_heads,
max_sequence_length,
max_memory_length,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers,
relative_encoding=relative_encoding)
def forward(self, input_ids, position_ids, attention_mask, *mems):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
# Transformer.
transformer_output = self.transformer(embeddings, position_ids, attention_mask, *mems)
logits, *hidden_layers = transformer_output
# Parallel logits.
logits_parallel = mpu.copy_to_model_parallel_region(
logits)
logits_parallel = F.linear(logits_parallel,
self.word_embeddings.weight)
if self.parallel_output:
return (logits_parallel, *hidden_layers)
return (mpu.gather_from_model_parallel_region(logits_parallel), *hidden_layers)
def gpt2_get_params_for_weight_decay_optimization(module):
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for wrapping BertModel."""
import torch
from .modeling import BertConfig
from .modeling import BertForPreTraining, BertForMaskedLM
from .modeling import BertLayerNorm
def get_params_for_weight_decay_optimization(module):
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
class BertModel(torch.nn.Module):
def __init__(self, args):
super(BertModel, self).__init__()
if args.pretrained_bert:
self.model = BertForPreTraining.from_pretrained(
args.tokenizer_model_type,
cache_dir=args.cache_dir,
fp32_layernorm=args.fp32_layernorm,
fp32_embedding=args.fp32_embedding,
layernorm_epsilon=args.layernorm_epsilon)
else:
if args.intermediate_size is None:
intermediate_size = 4 * args.hidden_size
else:
intermediate_size = args.intermediate_size
self.config = BertConfig(
args.tokenizer_num_tokens,
hidden_size=args.hidden_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_attention_heads,
intermediate_size=intermediate_size,
hidden_dropout_prob=args.hidden_dropout,
attention_probs_dropout_prob=args.attention_dropout,
max_position_embeddings=args.max_position_embeddings,
type_vocab_size=args.tokenizer_num_type_tokens,
fp32_layernorm=args.fp32_layernorm,
fp32_embedding=args.fp32_embedding,
fp32_tokentypes=args.fp32_tokentypes,
layernorm_epsilon=args.layernorm_epsilon,
deep_init=args.deep_init)
self.model = BertForPreTraining(self.config)
def forward(self, input_tokens, token_type_ids=None,
attention_mask=None, checkpoint_activations=False):
return self.model(
input_tokens, token_type_ids, attention_mask,
checkpoint_activations=checkpoint_activations)
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.model.state_dict(destination=destination, prefix=prefix,
keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
#from torch.utils.checkpoint import checkpoint
from data_utils.file_utils import cached_path
import mpu
def normal_init_method(mean, std):
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
def scaled_init_method(mean, std, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = std / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
deep_init=False,
fp32_layernorm=False,
fp32_embedding=False,
fp32_tokentypes=False,
layernorm_epsilon=1e-12):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.deep_init = deep_init
self.fp32_layernorm = fp32_layernorm
self.fp32_embedding = fp32_embedding
self.layernorm_epsilon = layernorm_epsilon
self.fp32_tokentypes = fp32_tokentypes
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
#self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.word_embeddings = mpu.VocabParallelEmbedding(
config.vocab_size, config.hidden_size,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.fp32_layernorm = config.fp32_layernorm
self.fp32_embedding = config.fp32_embedding
self.fp32_tokentypes = config.fp32_tokentypes
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
if not self.fp32_tokentypes:
embeddings = words_embeddings + position_embeddings + token_type_embeddings
if self.fp32_embedding and not self.fp32_layernorm:
embeddings = embeddings.half()
previous_type = embeddings.type()
if self.fp32_layernorm:
embeddings = embeddings.float()
embeddings = self.LayerNorm(embeddings)
if self.fp32_layernorm:
if self.fp32_embedding:
embeddings = embeddings.half()
else:
embeddings = embeddings.type(previous_type)
else:
embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float()
if self.fp32_tokentypes and not self.fp32_layernorm:
embeddings = embeddings.half()
previous_type = embeddings.type()
if self.fp32_layernorm:
embeddings = embeddings.float()
embeddings = self.LayerNorm(embeddings)
if self.fp32_layernorm:
if self.fp32_tokentypes:
embeddings = embeddings.half()
else:
embeddings = embeddings.type(previous_type)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
previous_type = attention_probs.type()
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
if hasattr(config, 'deep_init') and config.deep_init:
init_method = scaled_init_method(mean=0.0,
std=config.initializer_range,
num_layers=config.num_hidden_layers)
else:
init_method = normal_init_method(mean=0.0,
std=config.initializer_range)
self.dense = mpu.RowParallelLinear(
input_size=config.hidden_size,
output_size=config.hidden_size,
bias=True,
input_is_parallel=True,
stride=1,
init_method=init_method)
self.fp32_layernorm = config.fp32_layernorm
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
ln_input = hidden_states + input_tensor
previous_type = ln_input.type()
if self.fp32_layernorm:
ln_input = ln_input.float()
hidden_states = self.LayerNorm(ln_input)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = mpu.BertParallelSelfAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
dropout_prob=config.attention_probs_dropout_prob,
output_parallel=True,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = mpu.ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.intermediate_size,
bias=True,
gather_output=False,
stride=1,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
if hasattr(config, 'deep_init') and config.deep_init:
init_method = scaled_init_method(mean=0.0,
std=config.initializer_range,
num_layers=config.num_hidden_layers)
else:
init_method = normal_init_method(mean=0.0,
std=config.initializer_range)
self.dense = mpu.RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
input_is_parallel=True,
stride=1,
init_method=init_method)
self.fp32_layernorm = config.fp32_layernorm
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
ln_input = hidden_states + input_tensor
previous_type = ln_input.type()
if self.fp32_layernorm:
ln_input = ln_input.float()
hidden_states = self.LayerNorm(ln_input)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
#layer = BertLayer(config)
#self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
# all_encoder_layers = []
# for layer_module in self.layer:
# hidden_states = layer_module(hidden_states, attention_mask)
# if output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# if not output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# return all_encoder_layers
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
all_encoder_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
x_ = inputs[0]
for layer in layers:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = 1 #math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
l += chunk_length
# decoder layers
else:
for i,layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers or checkpoint_activations:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.fp32_layernorm = config.fp32_layernorm
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
previous_type = hidden_states.type()
if self.fp32_layernorm:
hidden_states = hidden_states.float()
hidden_states = self.LayerNorm(hidden_states)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
#self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
# bert_model_embedding_weights.size(0),
# bias=False)
self.decoder_weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
self.bias.model_parallel = True
self.fp32_embedding = config.fp32_embedding
self.fp32_layernorm = config.fp32_layernorm
def convert_to_type(tensor):
if self.fp32_embedding:
return tensor.half()
else:
return tensor
self.type_converter = convert_to_type
self.converted = False
def forward(self, hidden_states):
if not self.converted:
self.converted = True
if self.fp32_embedding:
self.transform.half()
if self.fp32_layernorm:
self.transform.LayerNorm.float()
hidden_states = self.transform(self.type_converter(hidden_states))
# hidden_states = self.decoder(hidden_states) + self.bias
hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
hidden_states = F.linear(self.type_converter(hidden_states),
self.type_converter(self.decoder_weight),
self.type_converter(self.bias))
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(BertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertPreTrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
for p in self.seq_relationship.parameters():
if p is None:
continue
pooled_output = pooled_output.type_as(p)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12,
fp32_tokentypes=False, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
config.fp32_layernorm = fp32_layernorm
config.fp32_embedding = fp32_embedding
config.layernorm_epsilon = layernorm_epsilon
config.fp32_tokentypes = fp32_tokentypes
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
class BertModel(PreTrainedBertModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers,
checkpoint_activations=checkpoint_activations)
sequence_output = encoded_layers[-1]
for p in self.pooler.parameters():
if p is None:
continue
sequence_output = sequence_output.type_as(p)
break
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers or checkpoint_activations:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
class BertForPreTraining(PreTrainedBertModel):
"""BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
- the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `masked_lm_labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForPreTraining(config)
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForPreTraining, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).float(), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2).float(), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
else:
return prediction_scores, seq_relationship_score
class BertForMaskedLM(PreTrainedBertModel):
"""BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
Outputs:
if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMaskedLM(config)
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForMaskedLM, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
prediction_scores = self.cls(sequence_output)
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
return masked_lm_loss
else:
return prediction_scores
class BertForNextSentencePrediction(PreTrainedBertModel):
"""BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `next_sentence_label` is not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `next_sentence_label` is `None`:
Outputs the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForNextSentencePrediction(config)
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForNextSentencePrediction, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
seq_relationship_score = self.cls( pooled_output)
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
return next_sentence_loss
else:
return seq_relationship_score
class BertForSequenceClassification(PreTrainedBertModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForSequenceClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
class BertForMultipleChoice(PreTrainedBertModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
return loss
else:
return reshaped_logits
class BertForTokenClassification(PreTrainedBertModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForTokenClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
#self.classifier = nn.Linear(config.hidden_size, num_labels)
self.classifier = mpu.RowParallelLinear(
input_size=config.hidden_size,
output_size=num_labels,
bias=True,
input_is_parallel=True,
stride=1,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
with mpu.get_cuda_rng_tracker().fork():
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
class BertForQuestionAnswering(PreTrainedBertModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__(config)
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
#self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.qa_outputs = mpu.RowParallelLinear(
input_size=config.hidden_size,
output_size=2,
bias=True,
input_is_parallel=True,
stride=1,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
return start_logits, end_logits
class PrefixTuning(nn.Module):
def __init__(self, model, num_layer, embd_dim, device, mid_dim=None, preseqlen=100, prefix_dropout=0.0):
super(PrefixTuning, self).__init__()
self.device = device
self.match_n_layer = num_layer
self.n_embd = embd_dim
self.mid_dim = mid_dim if mid_dim is not None else embd_dim * 4
self.preseqlen = preseqlen
self.prefix_dropout = prefix_dropout
self.model = model
print('PrefixTuning')
print('preseqlen is {}, optimizing the prefix directly'.format(self.preseqlen))
self.input_tokens = torch.arange(self.preseqlen).long()
self.wte = nn.Embedding(self.preseqlen, self.n_embd)
# multiply by 2 because we need to split to kv pair in the end
self.control_trans = nn.Sequential(
nn.Linear(self.n_embd, self.mid_dim),
nn.Tanh(),
nn.Linear(self.mid_dim, self.match_n_layer * self.n_embd))
self.dropout = nn.Dropout(self.prefix_dropout)
# freeze all model params so that only prefixes are trained!!
for param in model.parameters():
param.requires_grad = False
def get_prompt(self, bsz):
prefix_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
prefix_embeddings = self.wte(prefix_tokens)
prefixes = self.control_trans(prefix_embeddings) #bsz, seqlen, layer*emb
bsz, seqlen, _ = prefixes.shape
prefixes = prefixes.view(bsz, seqlen, self.match_n_layer , self.n_embd)
prefixes = self.dropout(prefixes)
# reorder and split to kv pairs
prefixes = prefixes.permute([2, 0, 1, 3])
new_mems = [prefix_embeddings]
for layer_mem in prefixes:
new_mems.append(layer_mem)
return new_mems
def forward(self, input_ids, position_ids, attention_mask, *mems):
prefixes = self.get_prompt(bsz=input_ids.shape[0]) # num_layer * bsz * seq_len * hidden_dim
if mems:
new_mems = torch.cat([prefixes, mems], dim=2)
else:
new_mems = prefixes
return self.model(input_ids, position_ids, attention_mask, *new_mems)
\ No newline at end of file
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model parallel utility interface."""
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .grads import clip_grad_norm
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_src_rank
from .initialize import get_model_parallel_world_size
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear
from .layers import ParallelEmbedding
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .random import checkpoint
from .random import partition_activations_in_checkpoint
from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .transformer import BertParallelSelfAttention
from .transformer import BertParallelTransformerLayer
from .transformer import GPT2ParallelTransformer
from .transformer import LayerNorm
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Copy so the input remains unchanged.
logits = vocab_parallel_logits.clone()
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
# Subtract the maximum value.
logits.sub_(logits_max.unsqueeze(dim=-1))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = logits.exp()
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
# Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as thier gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
"""Helper function for the cross entropy."""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_src_rank
_MAX_DATA_DIM = 4
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, '{} has data type {} which '\
'is different than {}'.format(key, data[key].dtype, target_dtype)
def _build_key_size_numel_dictionaries(keys, data):
"""Build the size on rank 0 and broadcast."""
max_dim = _MAX_DATA_DIM
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
size = data[key].size()
for i, s in enumerate(size):
sizes[i + offset] = s
offset += max_dim
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(),
group=get_model_parallel_group())
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
key_size = {}
key_numel = {}
total_numel = 0
offset = 0
for key in keys:
i = 0
size = []
numel = 1
while sizes_cpu[offset + i] > 0:
this_size = sizes_cpu[offset + i]
size.append(this_size)
numel *= this_size
i += 1
key_size[key] = size
key_numel[key] = numel
total_numel += numel
offset += max_dim
return key_size, key_numel, total_numel
def broadcast_data(keys, data, datatype):
"""Broadcast data from rank zero of each model parallel group to the
members of the same model parallel group.
Arguments:
keys: list of keys in the data disctionary to be broadcasted
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
with keys.
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
data)
# Pack on rank zero.
if get_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat(
[data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
flatten_data = torch.empty(total_numel,
device=torch.cuda.current_device(),
dtype=datatype)
# Boradcast
torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(),
group=get_model_parallel_group())
# Unpack
output = {}
offset = 0
for key in keys:
size = key_size[key]
numel = key_numel[key]
output[key] = flatten_data.narrow(0, offset, numel).view(size)
offset += numel
return output
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import torch
from torch._six import inf
from .initialize import get_model_parallel_group
from .initialize import get_model_parallel_rank
def clip_grad_norm(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0
for p in parameters:
if p.model_parallel or (get_model_parallel_rank() == 0):
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
return total_norm
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model and data parallel groups."""
import torch
from .utils import ensure_divisibility
# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
def initialize_model_parallel(model_parallel_size_):
"""
Initialize model data parallel groups.
Arguments:
model_parallel_size: number of GPUs used to parallelize model.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 model parallel groups and 2 data parallel grous as:
4 model parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 data parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print('> initializing model parallel with size {}'.format(
model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size)
ensure_divisibility(world_size, model_parallel_size)
rank = torch.distributed.get_rank()
# Build the data parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank % model_parallel_size):
_DATA_PARALLEL_GROUP = group
# Build the model parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group is already initialized'
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size,
(i + 1) * model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
_MODEL_PARALLEL_GROUP = group
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, \
'model parallel group is not initialized'
return _MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_model_parallel_world_size():
"""Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=get_model_parallel_group())
def get_model_parallel_rank():
"""Return my rank for the model parallel group."""
return torch.distributed.get_rank(group=get_model_parallel_group())
def get_model_parallel_src_rank():
"""Calculate the global rank corresponding to a local rank zeor
in the model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import math
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.nn import LayerNorm
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
def _initialize_affine_weight(weight, output_size, input_size,
per_partition_size, partition_dim, init_method,
stride=1, return_master_weight=False):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
# If we only use 1 process for model parallelism, bypass scatter.
world_size = get_model_parallel_world_size()
if world_size == 1:
init_method(weight)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=weight.dtype,
requires_grad=False)
init_method(master_weight)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim,
init_method=init.xavier_normal_):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Set the detauls for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(),
get_model_parallel_world_size())
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
self.embedding_dim))
self.weight.model_parallel = True
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
def forward(self, input_):
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
return output
class ParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the embedding dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim,
init_method=init.xavier_normal_,
keep_master_weight_for_test=False):
super(ParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Set some detauls for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Divide the weight matrix along the embedding dimension.
world_size = get_model_parallel_world_size()
self.embedding_dim_per_partition = divide(self.embedding_dim,
world_size)
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings,
self.embedding_dim_per_partition))
self.weight.model_parallel = True
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
self.embedding_dim_per_partition, 1, init_method,
stride=1, return_master_weight=False)
def forward(self, input_):
input_parallel = copy_to_model_parallel_region(input_)
output_parallel = F.embedding(input_parallel, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
output = gather_from_model_parallel_region(output_parallel)
return output
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
gather_output: If true, call all-gether on output and make Y avaiable
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
"""
def __init__(self, input_size, output_size, bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
self.input_size))
self.weight.model_parallel = True
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
self.bias.model_parallel = True
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
return output
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
"""
def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size,
self.input_size_per_partition))
self.weight.model_parallel = True
if bias:
self.bias = Parameter(torch.Tensor(self.output_size))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .initialize import get_model_parallel_group
from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
group = get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=group)
return input_
def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
group = get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_
# Split along last dimension.
world_size = torch.distributed.get_world_size(group=group)
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = torch.distributed.get_rank(group=group)
output = input_list[rank].contiguous()
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
group = get_model_parallel_group()
# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_):
return _split(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
# coding=utf-8
#Modified by Samyam Rajbhandari
#Used to partition the activations stored for backward propagation
#Therefore reduces the memory consumption
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch.distributed as dist
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
#from torch.utils.checkpoint import detach_variable
import torch.distributed as dist
PARTITION_ACTIVATIONS = False
PA_CORRECTNESS_TEST= False
def see_memory_usage(message, force=False):
if not force:
return
dist.barrier()
if dist.get_rank() == 0:
print(message)
print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
print(" ")
#input("Press Any Key To Continue ..")
from .initialize import get_data_parallel_rank
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .initialize import get_model_parallel_group
mp_rank = None #get_model_parallel_rank()
mp_size = None #get_model_parallel_world_size()
mp_group = None #get_model_parallel_group()
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
transport_stream = None
cuda_device=None
def detach_variable(inputs, device=None):
if isinstance(inputs, tuple):
out = []
for inp in inputs:
if not isinstance(inp, torch.Tensor):
out.append(inp)
continue
requires_grad = inp.requires_grad
if device is not None:
x = inp.to(device=device)
else:
x = inp
x = x.detach()
x.requires_grad = requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(
"Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is replacement for that
function.
Two set of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model paralle groups. This is used for
example for dropout in the non-model-parallel regions.
model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + get_model_parallel_rank()
# Data parallel gets the original sedd.
data_parallel_seed = seed
if torch.distributed.get_rank() == 0:
print('> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(), get_model_parallel_rank(),
get_data_parallel_rank(), model_parallel_seed,
data_parallel_seed), flush=True)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
model_parallel_seed)
def get_partition_start(item):
global mp_rank, mp_size, mp_group
partition_size = get_partition_size(item)
start = partition_size * mp_rank
return int(start)
def get_partition_size(item):
global mp_rank, mp_size, mp_group
size = item.numel()
partition_size = size/mp_size
return int(partition_size)
def get_full_inputs(tensors):
inputs=[]
for i in range(int(len(tensors)/2)-1):
item = tensors[2 * i]
size = tensors[2* i + 1]
partition_size = item.numel()
tensor_size = partition_size * mp_size
flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device)
partitions=[]
for i in range(mp_size):
part_i = flat_tensor.narrow(0, partition_size * i , partition_size)
if i == mp_rank:
part_i.copy_(item)
partitions.append(part_i)
dist.all_gather(partitions,partitions[mp_rank], group=mp_group)
input_tensor = flat_tensor.view(list(size.numpy()))
item.data=input_tensor.data
inputs.append(item)
inputs.append(tensors[-2])
return tuple(inputs)
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, *args):
ctx.run_function = run_function
global mp_rank, mp_size, mp_group
if mp_rank is None:
mp_rank = get_model_parallel_rank()
mp_size = get_model_parallel_world_size()
mp_group = get_model_parallel_group()
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if cuda_device is None:
if dist.get_rank() == 0:
print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}")
cuda_device = torch.cuda.current_device()
#The transport stream is used to overlap the allgather communication for the activations
#with the computation in the backward pass
transport_stream = torch.cuda.Stream(device=cuda_device)
if PARTITION_ACTIVATIONS:
inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
inputs.append(args[-1])
#just in case something funky is happening such as reuse of inputs
inputs_cuda = [item.to(cuda_device) for item in args]
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
#ctx.save_for_backward(*args)
with torch.no_grad():
outputs = run_function(*inputs_cuda)
del inputs_cuda
if PARTITION_ACTIVATIONS:
new_args = []
for arg, inp in zip(args,inputs):
size= torch.tensor(arg.size())
arg.data = inp.data
new_args.append(arg)
new_args.append(size)
ctx.save_for_backward(*new_args)
else:
ctx.save_for_backward(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), "
"please use .backward() if possible")
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
if PARTITION_ACTIVATIONS:
with torch.cuda.stream(transport_stream):
inputs = get_full_inputs(ctx.saved_tensors)
detached_inputs = detach_variable(inputs)
else:
inputs = ctx.saved_tensors
detached_inputs = detach_variable(inputs)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
if PARTITION_ACTIVATIONS:
current_stream=torch.cuda.current_stream()
current_stream.wait_stream(transport_stream)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
return (None,) + tuple(inp.grad for inp in detached_inputs)
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
def partition_activations_in_checkpoint(partition_activation):
global PARTITION_ACTIVATIONS
PARTITION_ACTIVATIONS=partition_activation
if dist.get_rank() == 0:
print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
import numpy
import torch
import mpu
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
self.weight = torch.nn.Parameter(scale * torch.randn(size))
def forward(self):
return self.weight
def set_random_seed(seed):
"""Set random seed for reproducability."""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'):
"""Initialize torch.distributed."""
# Get local rank in case it is provided.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher')
args = parser.parse_args()
local_rank = args.local_rank
# Get rank and world size.
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
print('> initializing torch.distributed with local rank: {}, '
'rank: {}, world size: {}'.format(local_rank, rank, world_size))
# Set the device id.
device = rank % torch.cuda.device_count()
if local_rank is not None:
device = local_rank
torch.cuda.set_device(device)
# Call the init process.
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
init_method=init_method)
def print_separator(message):
torch.distributed.barrier()
filler_len = (78 - len(message)) // 2
filler = '-' * filler_len
string = '\n' + filler + ' {} '.format(message) + filler
if torch.distributed.get_rank() == 0:
print(string, flush=True)
torch.distributed.barrier()
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
sys.path.append("../..")
import torch
import torch.nn.functional as F
import mpu
from mpu.cross_entropy import vocab_parallel_cross_entropy
from commons import initialize_distributed
from commons import print_separator
from commons import IdentityLayer
from commons import set_random_seed
def torch_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
target.view(-1),
reduction='none').view_as(target).mean()
loss.backward()
return loss, identity.weight.grad
def mpu_cross_entropy(batch_size, seq_length, vocab_size,
logits_scale, seed):
set_random_seed(seed)
identity = IdentityLayer((batch_size, seq_length, vocab_size),
scale=logits_scale).cuda()
logits = identity()
logits_parallel = mpu.scatter_to_model_parallel_region(logits)
target = torch.cuda.LongTensor(
size=(batch_size, seq_length)).random_(0, vocab_size)
loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
loss.backward()
return loss, identity.weight.grad
def test_cross_entropy(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cross entropy with model parallel size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
batch_size = 13
seq_length = 17
vocab_size_per_partition = 11
logits_scale = 1000.0
vocab_size = vocab_size_per_partition * model_parallel_size
seed = 1234
loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
vocab_size, logits_scale,
seed)
loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
vocab_size, logits_scale,
seed)
error = loss_torch.sub_(loss_mpu).abs().max()
print(' max error in loss on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = grad_torch.sub_(grad_mpu).abs().max()
print(' max error in grad on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test cross entropy')
test_cross_entropy(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import operator
import sys
sys.path.append("../..")
import torch
import mpu
from mpu import data as data_utils
from commons import initialize_distributed
from commons import print_separator
def test_boradcast_data(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing boradcast_data with model parallel size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
torch.manual_seed(1234 + mpu.get_data_parallel_rank())
model_parallel_size = mpu.get_model_parallel_world_size()
key_size_t = {'key1': [7, 11],
'key2': [8, 2, 1],
'key3': [13],
'key4': [5, 1, 2],
'key5': [5, 12]}
keys = list(key_size_t.keys())
data = {}
data_t = {}
for key in key_size_t:
data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
data_t[key] = data[key].clone()
data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
data_t['keyX'] = data['keyX'].clone()
if mpu.get_model_parallel_rank() != 0:
data = None
data_utils._check_data_types(keys, data_t, torch.int64)
key_size, key_numel, \
total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
for key in keys:
assert key_size[key] == key_size_t[key]
total_numel_t = 0
for key in keys:
target_size = functools.reduce(operator.mul, key_size_t[key], 1)
assert key_numel[key] == target_size
total_numel_t += target_size
assert total_numel == total_numel_t
data_b = data_utils.broadcast_data(keys, data, torch.int64)
for key in keys:
tensor = data_t[key].cuda()
assert data_b[key].sub(tensor).abs().max() == 0
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test test boradcast data')
test_boradcast_data(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_initialize_model_parallel(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing initialize_model_parallel with size {} ...'.format(
model_parallel_size))
model_parallel_size_ = min(model_parallel_size,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size_)
assert mpu.model_parallel_is_initialized()
# Checks.
def check(group, world_size, rank):
assert world_size == torch.distributed.get_world_size(group=group)
assert rank == torch.distributed.get_rank(group=group)
# Model parallel.
world_size = model_parallel_size_
rank = torch.distributed.get_rank() % model_parallel_size_
assert world_size == mpu.get_model_parallel_world_size()
assert rank == mpu.get_model_parallel_rank()
check(mpu.get_model_parallel_group(), world_size, rank)
# Data parallel.
world_size = torch.distributed.get_world_size() // model_parallel_size_
rank = torch.distributed.get_rank() // model_parallel_size
assert world_size == mpu.get_data_parallel_world_size()
assert rank == mpu.get_data_parallel_rank()
check(mpu.get_data_parallel_group(), world_size, rank)
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_get_model_parallel_src_rank(model_parallel_size_):
if torch.distributed.get_rank() == 0:
print('> testing get_model_parallel_src_rank with size {} ...'.format(
model_parallel_size_))
model_parallel_size = min(model_parallel_size_,
torch.distributed.get_world_size())
assert not mpu.model_parallel_is_initialized()
mpu.initialize_model_parallel(model_parallel_size)
assert mpu.model_parallel_is_initialized()
# Checks
src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
assert mpu.get_model_parallel_src_rank() == src_rank
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test initialize model parallel')
test_initialize_model_parallel(model_parallel_size)
print_separator('test model parallel source rank')
test_get_model_parallel_src_rank(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
sys.path.append("../..")
import torch
import torch.nn.init as init
from torch.nn.parameter import Parameter
import mpu
from commons import initialize_distributed
from commons import print_separator
from commons import set_random_seed
from mpu import layers
def test_parallel_embedding(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing parallel embedding with model parallel size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
batch_size = 17
seq_length = 23
vocab_size = 48
hidden_size = 16
seed = 1236
set_random_seed(123)
input_data = torch.LongTensor(
size=(batch_size,seq_length)).random_(0, vocab_size).cuda()
loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()
set_random_seed(seed)
embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()
output = embedding_original(input_data)
loss_original = torch.mul(output, loss_weight).sum()
loss_original.backward()
set_random_seed(seed)
embedding_parallel = layers.ParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_parallel(input_data)
loss_parallel = torch.mul(output, loss_weight).sum()
loss_parallel.backward()
set_random_seed(seed)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init.normal_).cuda()
output = embedding_vocab_parallel(input_data)
loss_vocab_parallel = torch.mul(output, loss_weight).sum()
loss_vocab_parallel.backward()
torch.distributed.barrier()
error = loss_parallel.sub(loss_original).abs()
print(' error in loss (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
torch.distributed.barrier()
error = loss_vocab_parallel.sub(loss_original).abs()
print(' error in loss (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
hidden_size // model_parallel_size,
1)[mpu.get_model_parallel_rank()]
error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
print(' error in grad (parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
weight_grad_orig = torch.split(embedding_original.weight.grad,
vocab_size // model_parallel_size,
0)[mpu.get_model_parallel_rank()]
error = embedding_vocab_parallel.weight.grad.sub(
weight_grad_orig).abs().max()
print(' error in grad (vocab parallel) on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-12, 'error: {}'.format(error)
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_initialize_affine_weight(model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing initialize_affine_weight with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
# ---------------
# Column parallel
# ---------------
weight = torch.empty(output_size_coeff, input_size)
set_random_seed(seed)
layers._initialize_affine_weight(weight, output_size, input_size,
output_size_coeff, 0,
torch.nn.init.normal_)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
my_weight = torch.split(master_weight, output_size_coeff,
dim=0)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' column parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# ------------
# Row parallel
# ------------
weight = torch.empty(output_size, input_size_coeff)
set_random_seed(seed)
mpu.layers._initialize_affine_weight(weight, output_size, input_size,
input_size_coeff, 1,
torch.nn.init.normal_)
# Target.
set_random_seed(seed)
master_weight = torch.empty(output_size, input_size)
torch.nn.init.normal_(master_weight)
rank = mpu.get_model_parallel_rank()
my_weight = torch.split(master_weight, input_size_coeff,
dim=1)[rank].contiguous().clone()
# Compare.
error = weight.sub(my_weight).abs().max()
torch.distributed.barrier()
print(' row parallel max error (should be zero) on global rank '
'{}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer2D(torch.nn.Module):
def __init__(self, m , n):
super(IdentityLayer2D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def test_column_parallel_linear(model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing ColumnParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = mpu.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
my_dLdb = torch.split(dLdb, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def test_row_parallel_linear(model_parallel_size):
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
print('> testing RowParallelLinear with model parallel '
'size: {}'.format(model_parallel_size))
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size = output_size_coeff * model_parallel_size
batch_size = 7
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
linear_layer = mpu.RowParallelLinear(
input_size, output_size, keep_master_weight_for_test=True).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output = linear_layer(input_)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = mpu.get_model_parallel_rank()
my_dLdA = torch.split(dLdA, input_size_coeff,
dim=1)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
class IdentityLayer3D(torch.nn.Module):
def __init__(self, m , n, k):
super(IdentityLayer3D, self).__init__()
self.weight = Parameter(torch.Tensor(m, n, k))
torch.nn.init.xavier_normal_(self.weight)
def forward(self):
return self.weight
def parallel_self_attention(model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size,
sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
dropout_prob).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = attention_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
attention_layer, identity_layer
def test_parallel_self_attention(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelSelfAttention with model parallel '
'size: {}'.format(model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
dropout_prob = 0.0 # has to be zero
batch_size = 5
sequence_length = 13
rank_1, hideen_size_1, model_parallel_size_1, loss_1, \
attention_layer_1, identity_layer_1 =parallel_self_attention(
1, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
attention_layer, identity_layer =parallel_self_attention(
model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
assert hideen_size_1 == hidden_size
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
my_lin_grad_list = torch.split(
attention_layer_1.query_key_value.weight.grad,
hidden_size // model_parallel_size, 0)[rank::model_parallel_size]
my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
error = my_lin_grad.sub(
attention_layer.query_key_value.weight.grad).abs().max()
torch.distributed.barrier()
print(' weight gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-6
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
def parallel_transformer(model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length):
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
seed = 12345
set_random_seed(seed)
num_att_heads = num_att_heads_per_partition * \
torch.distributed.get_world_size()
hidden_size = hidden_size_per_att_head * num_att_heads
intermediate_size = 4 * hidden_size
# Network
identity_layer = IdentityLayer3D(batch_size, sequence_length,
hidden_size).cuda()
transformer_layer = mpu.BertParallelTransformerLayer(
hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
torch.nn.functional.relu, 1.0e-5).cuda()
loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
# Forward
input_ = identity_layer()
output = transformer_layer(input_, attention_mask)
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
rank = mpu.get_model_parallel_rank()
mpu.destroy_model_parallel()
return rank, hidden_size, model_parallel_size, loss, \
transformer_layer, identity_layer
def test_parallel_transformer_layer(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing ParallelTransformerLayer with model parallel '
'size: {}'.format(model_parallel_size))
num_att_heads_per_partition = 3
hidden_size_per_att_head = 7
batch_size = 5
sequence_length = 13
rank_1, hidden_size_1, model_parallel_size_1, loss_1, \
transformer_layer_1, identity_layer_1 = parallel_transformer(
1, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
rank, hidden_size, model_parallel_size, loss, \
transformer_layer, identity_layer = parallel_transformer(
model_parallel_size, num_att_heads_per_partition,
hidden_size_per_att_head, batch_size, sequence_length)
error = loss_1.sub(loss).abs().max()
torch.distributed.barrier()
print(' loss error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
error = identity_layer_1.weight.grad.sub(
identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' input gradient error on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 5.0e-5, 'error: {}'.format(error)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' >> passed the test :-)')
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
initialize_distributed()
world_size = torch.distributed.get_world_size()
print_separator('test initialize affine weight')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_initialize_affine_weight(model_parallel_size)
model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test parallel embedding')
test_parallel_embedding(model_parallel_size)
model_parallel_size *= 2
print_separator('test column-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_column_parallel_linear(model_parallel_size)
model_parallel_size *= 2
print_separator('test row-parallel linear')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_row_parallel_linear(model_parallel_size)
model_parallel_size *= 2
print_separator('test parallel self-attention')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_self_attention(model_parallel_size)
model_parallel_size *= 2
print_separator('test parallel transformer')
model_parallel_size = 1
while model_parallel_size <= world_size:
test_parallel_transformer_layer(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../..")
import torch
import mpu
from commons import initialize_distributed
from commons import print_separator
def test_set_cuda_rng_state(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing set_rng_state with size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
size = 123
seed = 1234
torch.cuda.manual_seed(1234)
tensor = torch.cuda.FloatTensor(size)
# Get the state
rng_state = torch.cuda.get_rng_state()
rng_state_copy = rng_state.clone()
# Do some stuff.
for _ in range(5):
torch.randn(size, out=tensor)
result_1 = tensor.clone()
assert rng_state.sub(rng_state_copy).max() == 0
assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0
# State should be different.
new_rng_state = torch.cuda.get_rng_state()
max_diff = new_rng_state.sub(rng_state).max()
print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), max_diff))
assert max_diff > 0
# Reset the rng state and do the same stuff.
mpu.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
mpu.random._set_cuda_rng_state(rng_state)
for _ in range(5):
torch.randn(size, out=tensor)
result_2 = tensor.clone()
# Results should be the same
error = result_2.sub(result_1).abs().max()
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Input state should have remained intact.
error = rng_state.sub(rng_state_copy).max()
print(' max error in rng state (should be zero) on global rank {}: {}'.
format(torch.distributed.get_rank(), error))
assert error == 0
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_cuda_rng_tracker(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing cuda rng tracker with size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
seed_1 = 1234
seed_2 = 4321
size = [12, 21]
tensor = torch.cuda.FloatTensor(size)
# Set to seed_1 and generate two tensors.
torch.cuda.manual_seed(seed_1)
torch.randn(size, out=tensor)
target_11 = tensor.clone()
torch.randn(size, out=tensor)
target_12 = tensor.clone()
# Set to seed_2 and generate two tensors.
torch.cuda.manual_seed(seed_2)
torch.randn(size, out=tensor)
target_21 = tensor.clone()
torch.randn(size, out=tensor)
target_22 = tensor.clone()
# Now if we interleave seed_1 and seed_2,
# we should still get the same tensors
torch.cuda.manual_seed(seed_1)
mpu.get_cuda_rng_tracker().add('test', seed_2)
torch.randn(size, out=tensor)
result_11 = tensor.clone()
with mpu.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_21 = tensor.clone()
torch.randn(size, out=tensor)
result_12 = tensor.clone()
with mpu.get_cuda_rng_tracker().fork('test'):
torch.randn(size, out=tensor)
result_22 = tensor.clone()
diff = result_11.sub(result_21).abs().max()
diff = min(diff, result_12.sub(result_22).abs().max())
print(' max diff in generated tensors (should be non-zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
assert diff > 1.0e-6
error = max(result_11.sub(target_11).abs().max(),
result_12.sub(target_12).abs().max())
error = max(error, result_21.sub(target_21).abs().max())
error = max(error, result_22.sub(target_22).abs().max())
print(' max error in generated tensors (should be zero) on '
'global rank {}: {}'.format(torch.distributed.get_rank(), error))
assert error < 1.0e-6
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(model_parallel_size):
if torch.distributed.get_rank() == 0:
print('> testing model parallel cuda manual seed with size {} ...'.
format(model_parallel_size))
mpu.initialize_model_parallel(model_parallel_size)
model_parallel_size = mpu.get_model_parallel_world_size()
mpu.model_parallel_cuda_manual_seed(12345)
assert torch.cuda.initial_seed() == 12345
with mpu.get_cuda_rng_tracker().fork():
assert torch.cuda.initial_seed() == (12345 + 2718 +
mpu.get_model_parallel_rank())
# Reset the tracker
mpu.get_cuda_rng_tracker().reset()
# Reset groups
mpu.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>> passed the test :-)')
if __name__ == '__main__':
initialize_distributed()
world_size = torch.distributed.get_world_size()
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test set rng state')
test_set_cuda_rng_state(model_parallel_size)
model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test cuda rng tracker')
test_cuda_rng_tracker(model_parallel_size)
model_parallel_size *= 2
model_parallel_size = 1
while model_parallel_size <= world_size:
print_separator('test model parallel cuda manual seed')
test_model_parallel_cuda_manual_seed(model_parallel_size)
model_parallel_size *= 2
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.init as init
from torch.nn import LayerNorm
from .initialize import get_model_parallel_world_size
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .mappings import gather_from_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
class PositionalEmbedding(torch.nn.Module):
def __init__(self, hidden_size):
super(PositionalEmbedding, self).__init__()
self.hidden_size = hidden_size
inv_freq = 1 / (10000 ** (torch.arange(0.0, hidden_size, 2.0) / hidden_size))
self.register_buffer('inv_freq', inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None:
return pos_emb[None, :, :].expand(bsz, -1, -1)
else:
return pos_emb[None, :, :]
class GPT2ParallelSelfAttention(torch.nn.Module):
"""Parallel self-attention layer for GPT2.
Self-attention layer takes input with size [b, s, h] where b is
the batch size, s is the sequence lenght, and h is the hidden size
and creates output of the same size.
Arguments:
hidden_size: total hidden size of the layer (h).
num_attention_heads: number of attention heads (n). Note that we
require n to be divisible by number of GPUs
used to parallelize the model. Also, we
require hidden size to be divisible by n.
dropout_prob: dropout probability for the attention scores.
init_method: weight initialization.
output_layer_init_method: output layer initialization. If None, use
`init_method`.
We use the following notation:
h: hidden_size
n: num_attention_heads
p: number of partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
"""
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, output_layer_init_method=None, relative_encoding=False):
super(GPT2ParallelSelfAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
self.relative_encoding = relative_encoding
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3 * hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
if relative_encoding:
self.relative = ColumnParallelLinear(hidden_size, hidden_size, gather_output=False,
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
try:
import deepspeed
if deepspeed.checkpointing.is_configured():
global get_cuda_rng_tracker, checkpoint
get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
checkpoint = deepspeed.checkpointing.checkpoint
except ImportError:
pass
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
@staticmethod
def _rel_shift(x, zero_triu=False):
# ql x kl x bsz x h
# bsz x h x ql x kl
zero_pad = torch.zeros((*x.size()[:-2], x.size(-2), 1),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:-2], x.size(-1) + 1, x.size(-2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]
return x
@staticmethod
def _rel_shift_latest(x: torch.Tensor):
ndims = x.dim()
x_shape = x.size()
row_dim = 2
col_dim = row_dim + 1
assert col_dim < ndims
tgt_shape_1, tgt_shape_2 = [], []
for i in range(ndims):
if i == row_dim:
tgt_shape_1.append(x_shape[col_dim])
tgt_shape_2.append(x_shape[row_dim])
elif i == col_dim:
tgt_shape_1.append(x_shape[row_dim])
tgt_shape_2.append(x_shape[col_dim] - 1)
else:
tgt_shape_1.append(x_shape[i])
tgt_shape_2.append(x_shape[i])
x = x.view(*tgt_shape_1)
x = x[:, :, 1:, :]
x = x.view(*tgt_shape_2)
return x
def forward(self, hidden_states, ltor_mask, position_embeddings=None, r_w_bias=None, r_r_bias=None, mem=None):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
query_length = hidden_states.size(1)
if mem is None:
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
else:
cat = torch.cat((mem, hidden_states), 1)
mixed_x_layer = self.query_key_value(cat)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
mixed_query_layer = mixed_query_layer[:, -query_length:]
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
if self.relative_encoding:
relative_layer = self.relative(position_embeddings)
relative_layer = self._transpose_for_scores(relative_layer) # 1 (bsz) x n_head x klen x d_head
# Raw attention scores. [b, np, qs, ks]
rw_head_q = query_layer + r_w_bias.unsqueeze(1)
ac_score = torch.matmul(rw_head_q, key_layer.transpose(-1, -2))
rr_head_q = query_layer + r_r_bias.unsqueeze(1)
bd_score = torch.matmul(rr_head_q, relative_layer.transpose(-1, -2))
bd_score = self._rel_shift(bd_score) # qlen x klen x bsz x n_head
# bd_score = bd_score.permute(2, 3, 0, 1) # bsz n_head qlen klen
attention_scores = ac_score + bd_score
else:
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.hidden_size_per_attention_head)
# Apply the left to right attention mask.
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)
return output
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
class GPT2ParallelMLP(torch.nn.Module):
"""MLP for GPT2.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform gelu transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
Arguments:
hidden_size: The hidden size of the self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weight are initialized to one.
output_layer_init_method: output layer initialization. If None,
use `init_method`.
"""
def __init__(self, hidden_size, output_dropout_prob, init_method,
output_layer_init_method=None):
super(GPT2ParallelMLP, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size,
gather_output=False,
init_method=init_method)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(output_dropout_prob)
def forward(self, hidden_states):
# [b, s, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = gelu(intermediate_parallel)
# [b, s, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output
class GPT2ParallelTransformerLayer(torch.nn.Module):
"""A single layer transformer for GPT2.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weight are initialized to one.
output_layer_init_method: output layers (attention output and
mlp output) initialization. If None,
use `init_method`.
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
output_layer_init_method=None,
relative_encoding=False):
super(GPT2ParallelTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
# Self attention.
self.attention = GPT2ParallelSelfAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
output_layer_init_method=output_layer_init_method,
relative_encoding=relative_encoding)
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(hidden_size,
eps=layernorm_epsilon)
# MLP
self.mlp = GPT2ParallelMLP(
hidden_size,
output_dropout_prob,
init_method,
output_layer_init_method=output_layer_init_method)
def forward(self, hidden_states, ltor_mask, position_embeddings=None, r_w_bias=None, r_r_bias=None, mem=None):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Layer norm at the begining of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
mem = self.input_layernorm(mem) if mem is not None else None
# Self attention.
attention_output = self.attention(layernorm_output, ltor_mask, position_embeddings, r_w_bias, r_r_bias, mem)
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
output = layernorm_input + mlp_output
return output
def unscaled_init_method(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
class GPT2ParallelTransformer(torch.nn.Module):
"""GPT-2 transformer.
This module takes input from embedding layer and it's output can
be used directly by a logit layer. It consists of L (num-layers)
blocks of:
layer norm
self attention
residual connection
layer norm
mlp
residual connection
followed by a final layer norm.
Arguments:
num_layers: Number of transformer layers.
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
checkpoint_activations: if True, checkpoint activations.
checkpoint_num_layers: number of layers to checkpoint. This
is basically the chunk size in checkpoitning.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method_std: standard deviation of the init method which has
the form N(0, std).
use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers)
scaling for the output weights (
output of self attention and mlp).
"""
def __init__(self,
num_layers,
hidden_size,
num_attention_heads,
max_sequence_length,
max_memory_length,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
use_scaled_init_for_output_weights=True,
relative_encoding=False):
super(GPT2ParallelTransformer, self).__init__()
# Store activation checkpoiting flag.
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
self.max_memory_length = max_memory_length
output_layer_init_method = None
if use_scaled_init_for_output_weights:
output_layer_init_method = scaled_init_method(init_method_std,
num_layers)
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
self.relative_encoding = relative_encoding
if relative_encoding:
# Relative position embedding
self.position_embeddings = PositionalEmbedding(hidden_size)
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
self.r_w_bias = torch.nn.Parameter(
torch.Tensor(self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
self.r_w_bias.model_parallel = True
self.r_r_bias = torch.nn.Parameter(
torch.Tensor(self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
self.r_r_bias.model_parallel = True
# Always initialize bias to zero.
with torch.no_grad():
self.r_w_bias.zero_()
self.r_r_bias.zero_()
else:
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size)
# Initialize the position embeddings.
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
def get_layer():
return GPT2ParallelTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
unscaled_init_method(init_method_std),
output_layer_init_method=output_layer_init_method,
relative_encoding=relative_encoding)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer() for _ in range(num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
try:
import deepspeed
if deepspeed.checkpointing.is_configured():
global get_cuda_rng_tracker, checkpoint
get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
checkpoint = deepspeed.checkpointing.checkpoint
except ImportError:
pass
def forward(self, hidden_states, position_ids, attention_mask, *mems):
batch_size, query_length = hidden_states.size()[:2]
memory_length = mems[0].size(1) if mems else 0
key_length = query_length + memory_length
attention_mask = attention_mask[:, :, :, -query_length - memory_length:]
if self.relative_encoding:
hidden_states = self.embedding_dropout(hidden_states)
position_sequence = torch.arange(key_length - 1, -1, -1.0, device=hidden_states.device,
dtype=hidden_states.dtype)
position_embeddings = self.position_embeddings(position_sequence)
# Apply dropout
position_embeddings = self.embedding_dropout(position_embeddings)
hidden_states = self.embedding_dropout(hidden_states)
else:
position_embeddings = self.position_embeddings(position_ids)
hidden_states = hidden_states + position_embeddings
hidden_states = self.embedding_dropout(hidden_states)
if self.max_memory_length > 0:
mem_layers = [hidden_states.detach()]
else:
mem_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_, inputs = inputs[0], inputs[1:]
if self.relative_encoding:
inputs, mems_ = inputs[:4], inputs[4:]
else:
inputs, mems_ = inputs[:1], inputs[1:]
for i, layer in enumerate(layers_):
mem_i_ = mems_[i] if mems_ else None
x_ = layer(x_, *inputs, mem=mem_i_)
if self.max_memory_length > 0:
mem_layers.append(x_.detach())
return x_
return custom_forward
if self.checkpoint_activations:
l = 0
num_layers = len(self.layers)
chunk_length = self.checkpoint_num_layers
while l < num_layers:
args = [hidden_states, attention_mask]
if self.relative_encoding:
args += [position_embeddings, self.r_w_bias, self.r_r_bias]
if mems:
args += mems[l: l + chunk_length]
hidden_states = checkpoint(custom(l, l + chunk_length), *args)
l += chunk_length
else:
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
if self.relative_encoding:
args += [position_embeddings, self.r_w_bias, self.r_r_bias]
mem_i = mems[i] if mems else None
hidden_states = layer(*args, mem=mem_i)
if self.max_memory_length > 0:
mem_layers.append(hidden_states.detach())
# Final layer norm.
output = self.final_layernorm(hidden_states)
if self.max_memory_length > 0:
mem_layers = self.update_mems(mem_layers, mems)
return (output, *mem_layers)
def update_mems(self, hiddens, mems):
memory_length = mems[0].size(1) if mems else 0
query_length = hiddens[0].size(1)
new_memory_length = min(self.max_memory_length, memory_length + query_length)
new_mems = []
with torch.no_grad():
for i in range(len(hiddens)):
if new_memory_length <= query_length:
new_mems.append(hiddens[i][:, -new_memory_length:])
else:
new_mems.append(torch.cat((mems[i][:, -new_memory_length + query_length:], hiddens[i]), dim=1))
return new_mems
class BertParallelSelfAttention(torch.nn.Module):
"""Parallel self-attention layer for BERT.
Self-attention layer takes input with size [b, s, h] where b is
the batch size, s is the sequence lenght, and h is the hidden size
and creates output of the same size.
Arguments:
hidden_size: total hidden size of the layer (h).
num_attention_heads: number of attention heads (n). Note that we
require n to be divisible by number of GPUs
used to parallelize the model. Also, we
require hidden size be divisible by n.
dropout_prob: dropout probability for the attention scores.
output_parallel: If true, no all-gather is done on the output and
the output values will be per partition.
We use the following notation:
h: hidden_size
n: num_attention_heads
p: number of partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
"""
def __init__(self, hidden_size, num_attention_heads,
dropout_prob, output_parallel=False,
init_method=init.xavier_normal_):
super(BertParallelSelfAttention, self).__init__()
# Input configuration.
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.dropout_prob = dropout_prob
self.output_parallel = output_parallel
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3 * hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.dropout = torch.nn.Dropout(dropout_prob)
try:
import deepspeed
if deepspeed.checkpointing.is_configured():
global get_cuda_rng_tracker, checkpoint
get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
checkpoint = deepspeed.checkpointing.checkpoint
except ImportError:
pass
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
# Raw attention scores. [b, np, s, s]
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
attention_scores = torch.matmul(query_layer / norm_factor,
key_layer.transpose(-1, -2) / norm_factor)
# Apply the attention mask.
attention_scores += attention_mask
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_cuda_rng_tracker().fork():
attention_probs = self.dropout(attention_probs)
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
if self.output_parallel:
output = context_layer
else:
output = gather_from_model_parallel_region(context_layer)
return output
class BertParallelTransformerOutput(torch.nn.Module):
"""The output layer used after self attention and intermediate
parts of transformer layer."""
def __init__(self, input_size, output_size, dropout_prob,
layernorm_epsilon=1.0e-12, input_is_parallel=False,
init_method=init.xavier_normal_):
super(BertParallelTransformerOutput, self).__init__()
# Components.
self.dense = RowParallelLinear(input_size,
output_size,
input_is_parallel=input_is_parallel,
init_method=init_method)
self.dropout = torch.nn.Dropout(dropout_prob)
self.layernorm = LayerNorm(output_size, eps=layernorm_epsilon)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
layernorm_input = hidden_states + input_tensor
hidden_states = self.layernorm(layernorm_input)
return hidden_states
class BertParallelTransformerLayer(torch.nn.Module):
"""A single layer transformer for Bert.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
intermediate_size: size of the intermediate state after
self attention. In both BERT and GPT
this is set to be 4 times the hidden
size.
num_attention_heads: number of attention head in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
intermediate_activation_fn: activation function for output
of intermediate.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weight are initialized to one.
"""
def __init__(self,
hidden_size,
intermediate_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
intermediate_activation_fn,
layernorm_epsilon,
init_method=init.xavier_normal_):
super(BertParallelTransformerLayer, self).__init__()
# Self attention.
self.attention = BertParallelSelfAttention(hidden_size,
num_attention_heads,
attention_dropout_prob,
output_parallel=True,
init_method=init_method)
# Self attention output.
self.self_output = BertParallelTransformerOutput(
hidden_size, hidden_size, output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
input_is_parallel=True,
init_method=init_method)
# Intermediate.
self.intermediate = ColumnParallelLinear(hidden_size, intermediate_size,
gather_output=False,
init_method=init_method)
self.intermediate_activation_fn = intermediate_activation_fn
# Output.
self.output = BertParallelTransformerOutput(
intermediate_size, hidden_size, output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
input_is_parallel=True,
init_method=init_method)
def forward(self, hidden_states, attention_mask):
# [b, s, hp]
attention_output_parallel = self.attention(hidden_states,
attention_mask)
# [b, s, h]
attention_self_output = self.self_output(attention_output_parallel,
hidden_states)
# [b, s, ip]
intermediate_output_parallel = self.intermediate(attention_self_output)
intermediate_output_parallel = self.intermediate_activation_fn(
intermediate_output_parallel)
# [b, s, h]
layer_output = self.output(intermediate_output_parallel,
attention_self_output)
return layer_output
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions,
contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indecies in [fist, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size)
#!/bin/bash
CHECKPOINT_PATH="./checkpoint_model"
PREFIX_SCRIPT_PATH=$1
MPSIZE=1
NLAYERS=32
NHIDDEN=2560
NATT=32
MAXSEQLEN=1024
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
#SAMPLING ARGS
TEMP=0.9
#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
TOPK=0
TOPP=0.9
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
config_json="$script_dir/ds_config.json"
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="/desktop/hostfile"
STATE_DICT_PATH="./transformer-xl-cmrc-prefix-step-4000/mp_rank_00_model_states.pt"
RUN_OPTIONS="
--MASTER_PORT $MASTER_PORT
--model-parallel-size $MPSIZE \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--max-position-embeddings 1024 \
--tokenizer-type ChineseSPTokenizer \
--fp16 \
--cache-dir cache \
--out-seq-length $MAXSEQLEN \
--seq-length 512 \
--mem-length 256 \
--transformer-xl \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP \
--batch_size 1 \
--deepspeed \
--deepspeed-activation-checkpointing \
--deepspeed_config './scripts/ds_config_2.9B_finetune.json' \
--save_dir "./checkpoint_model" \
--preseqlen 200 \
--no_load_lr_scheduler \
--no_load_optim \
--finetune \
"
# --load_sd $STATE_DICT_PATH \
RUN_CMD="$OPTIONS_NCCL deepspeed --include="localhost:1,2,3,4" --hostfile=$HOST_FILE_PATH $PREFIX_SCRIPT_PATH ${RUN_OPTIONS}"
echo ${RUN_CMD}
eval ${RUN_CMD}
from generate_samples import *
from data_utils.sp_tokenizer import *
from data_utils.tokenization import *
from direct_beam_search import *
from utils import print_rank_0
import pickle
from transformers import Adafactor, get_cosine_with_hard_restarts_schedule_with_warmup
from data_utils.tokenization import CommandToken
from generate_samples import initialize_distributed, setup_model, set_random_seed
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import os
import time
import torch.nn.functional as F
import torch.nn as nn
class PrefixTuning(nn.Module):
def __init__(self, model, num_layer, embd_dim, device, preseqlen=100, prefix_dropout=0.0):
super(PrefixTuning, self).__init__()
self.device = device
self.match_n_layer = num_layer
self.n_embd = embd_dim
self.preseqlen = preseqlen
self.prefix_dropout = prefix_dropout
self.model = model
print('PrefixTuning')
print('preseqlen is {}, optimizing the prefix directly'.format(self.preseqlen))
# plus one because we need to take the embedding layer into account
self.soft_prompt_layers = torch.nn.ModuleList([nn.Linear(self.preseqlen, self.n_embd) for _ in range(self.match_n_layer + 1)])
# self.dropout = nn.Dropout(self.prefix_dropout)
# freeze all model params so that only prefixes are trained!!
for param in model.parameters():
param.requires_grad = False
def get_prompt(self, bsz):
soft_prompts = []
identity_matrix = torch.eye(self.preseqlen, device=self.device, dtype=torch.half)
for idx, layer in enumerate(self.soft_prompt_layers):
prefix_embeddings = layer(identity_matrix).unsqueeze(0).expand(bsz, -1, -1).to(self.device)
soft_prompts.append(prefix_embeddings)
return soft_prompts
def forward(self, input_ids, position_ids, attention_mask, *mems):
if not mems:
mems = self.get_prompt(bsz=input_ids.shape[0]) # num_layer * bsz * seq_len * hidden_dim
return self.model(input_ids, position_ids, attention_mask, *mems)
def setup_model(args):
"""Setup model and optimizer."""
model = GPT2Model(num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
max_memory_length=args.mem_length,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
parallel_output=False,
relative_encoding=args.transformer_xl)
if not args.load_sd:
# load the transformer-xl
args.deepspeed = False
_ = load_checkpoint(
model, None, None, args, load_optimizer_states=False)
args.deepspeed = True
model = PrefixTuning(model=model, num_layer=args.num_layers, embd_dim=args.hidden_size,
device=args.local_rank, preseqlen=args.preseqlen, prefix_dropout=0.1)
model.cuda(torch.cuda.current_device())
if args.load_sd:
print(f"loading state_dict {args.load_sd}")
sd = torch.load(args.load_sd, map_location='cpu')
model.load_state_dict(sd["module"])
if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
model.half()
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
if args.deepspeed:
print("DeepSpeed is enabled.")
model, _, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
args=args,
mpu=mpu,
dist_init_required=False
)
return model
def prepare_envs():
"""Prepare model for generation."""
print('Preparing Model')
args = get_args()
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Arguments.
args.mem_length = args.seq_length + args.mem_length - 1
# # Pytorch distributed.
initialize_distributed(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
# get the tokenizer
tokenizer = prepare_tokenizer(args)
# Model, optimizer, and learning rate.
model = setup_model(args)
return model, tokenizer, torch.cuda.current_device(), args
def get_processed_data(data_dir, max_seq_len=512, min_seq_len=256):
data = []
with os.scandir(data_dir) as it:
for idx, entry in enumerate(it):
if not entry.name.startswith('.') and entry.is_file():
with open(entry.path) as file:
lines = file.readlines()
lines = [line.strip() for line in lines]
doc_data = "".join(lines)
if len(doc_data) <= max_seq_len and len(doc_data)>= min_seq_len:
data.append(doc_data)
print(len(data))
return data
def get_processed_dataset(input_data, tokenizer):
dataset = []
for doc_data in input_data:
doc_tkns = tokenizer.EncodeAsIds(doc_data)
# important, we need to explicitly tell the model when to stop !
final_tkns = doc_tkns.append(tokenizer.get_command('eos'))
dataset.append(final_tkns.tokenization)
return dataset
class CMRCDataset(torch.utils.data.dataset.Dataset):
def __init__(self, dataset, device):
self.dataset = dataset
self.device = device
def __getitem__(self, index):
item = self.dataset[index]
x = torch.tensor(item, device=self.device)
return {"context":x}
def __len__(self):
return len(self.dataset)
class MyCollator(object):
def __init__(self, tokenizer, device, mem_length):
self.tokenizer = tokenizer
self.device = device
self.mem_length = mem_length
def __call__(self, batch):
# pad with specific pad id
data = [item["context"] for item in batch]
data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=self.tokenizer.get_command('eos').Id)
seq_length = data.shape[1]
# get loss mask, mask padding values
loss_mask = torch.ones(data.shape, dtype=torch.float, device=self.device)
# loss_mask[data == tokenizer.get_command('eos').Id] = 0.0 # comment out so that eos behavior can be learned
# get position ids
position_ids = torch.arange(seq_length, dtype=torch.long, device=self.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# get attention masks, no idea why it construct in this way, following code in
attention_mask = torch.ones((1, seq_length, seq_length + self.mem_length), device=self.device)
attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + self.mem_length), self.mem_length)
attention_mask = attention_mask.unsqueeze(1)
return data, attention_mask, loss_mask, position_ids
def init_model_with_no_grad_invoke():
with torch.no_grad():
with open("./checkpoint_model/tokens.pkl", "rb") as file:
tokens = pickle.load(file)
with open("./checkpoint_model/attn_mask.pkl", "rb") as file:
attention_mask = pickle.load(file)
with open("./checkpoint_model/pos_ids.pkl", "rb") as file:
position_ids = pickle.load(file)
mems = []
return model(tokens, position_ids, attention_mask, *mems)
def train(model, train_loader, args):
writer = SummaryWriter(log_dir=f"./prefix_tuning_nomlp_runs_{args.preseqlen}")
model.train()
for epoch in range(1):
batch_step = 0
prev = time.time()
for data, attention_mask, loss_mask, position_ids in train_loader:
batch_step += 1
mems = [] # useless in here
logits, *mems = model(data, position_ids, attention_mask, *mems)
# ignore padded tokens if batch_size is not 1
logits, data = logits[:, :-1], data[:, 1:] # IMPORTANT! If not offset by one the model will learn to predict itself!
r_logits, r_data = logits.reshape(logits.shape[0]*logits.shape[1], -1), data.reshape(data.shape[0]*data.shape[1])
loss = F.cross_entropy(r_logits, r_data) # should not have ignore_index=cmd_tokens["pad"].Id
model.backward(loss)
loss_item = loss.detach().item()
if np.isnan(loss_item):
raise("loss is nan, aborting .......... ")
writer.add_scalar("Loss/Train", loss_item, batch_step)
model.step()
if batch_step != 0 and batch_step % 100 == 0:
now = time.time()
print(f"Batch step : {batch_step}, elapsed time: {now-prev}")
prev = now
if batch_step % 1000 == 0:
ckpt_id = f"transformer-xl-prefix{args.preseqlen}-tuning-nomlp-step-{batch_step}"
model.save_checkpoint(args.save_dir, ckpt_id)
if __name__ == "__main__":
model, tokenizer, device, args = prepare_envs()
sport_data_dir = "/cognitive_comp/zengzhongshen/desktop/transformer_xl_chinese/THUCNews/体育"
sport_data = get_processed_data(sport_data_dir)
train_dataset = get_processed_dataset(sport_data, tokenizer)
my_collator = MyCollator(tokenizer, device, args.mem_length)
train_loader = torch.utils.data.DataLoader(dataset=CMRCDataset(train_dataset, device),
batch_size=args.batch_size, shuffle=True, collate_fn=my_collator)
# not sure if still needed # this is necessary, cuz if not init with no grad context, a reference error will appear
#res = init_model_with_no_grad_invoke()
train(model, train_loader, args)
\ No newline at end of file
{
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"contiguous_gradients": false,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 50000000,
"allgather_bucket_size": 500000000
},
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0001,
"betas": [
0.9,
0.95
],
"eps": 1e-8,
"weight_decay": 1e-2
}
},
"activation_checkpointing": {
"partition_activations": false,
"contiguous_memory_optimization": false
},
"wall_clock_breakdown": false
}
\ No newline at end of file
{
"train_micro_batch_size_per_gpu": 2 ,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true
},
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5,
"betas": [
0.9,
0.95
],
"eps": 1e-8,
"weight_decay": 1e-2
}
},
"activation_checkpointing": {
"partition_activations": false,
"contiguous_memory_optimization": false
},
"wall_clock_breakdown": false
}
\ No newline at end of file
#! /bin/bash
# Change for multinode config
NUM_WORKERS=1
NUM_GPUS_PER_WORKER=1
MP_SIZE=1
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="/root/code/config/hostfile"
config_json="$script_dir/ds_config_2.9B_finetune.json"
gpt_options=" \
--finetune \
--experiment-name txl-2.9b \
--model-parallel-size ${MP_SIZE} \
--num-layers 32 \
--hidden-size 2560 \
--num-attention-heads 32 \
--seq-length 512 \
--max-position-embeddings 512 \
--mem-length 256 \
--load ${1} \
--no-load-optim \
--save ./checkpoints \
--save-interval 2000 \
--train-iters 10000 \
--resume-dataloader \
--train-data ${2} \
--xl-dataset \
--lazy-loader \
--tokenizer-type ChineseSPTokenizer \
--split 949,50,1 \
--distributed-backend nccl \
--lr-decay-style cosine \
--lr-decay-ratio 0.05 \
--lr-decay-iters 10000 \
--warmup .1 \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--transformer-xl \
--fp16 \
"
gpt_options="${gpt_options}
--deepspeed \
--deepspeed_config ${config_json} \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x
\ No newline at end of file
#! /bin/bash
# Change for multinode config
NUM_WORKERS=8
NUM_GPUS_PER_WORKER=8
MP_SIZE=1
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_SOCKET_IFNAME=bond0 NCCL_IB_GID_INDEX=3 NCCL_NET_GDR_LEVEL=0"
HOST_FILE_PATH="/root/code/config/pre_hostfile"
#OPTIONS_NCCL=""
#HOST_FILE_PATH="/workspace/hostfile"
config_json="$script_dir/ds_config_2.9B.json"
gpt_options=" \
--experiment-name txl-2.9b \
--model-parallel-size ${MP_SIZE} \
--num-layers 32 \
--hidden-size 2560 \
--num-attention-heads 32 \
--seq-length 512 \
--max-position-embeddings 512 \
--mem-length 256 \
--load /root/data/checkpoints/txl-2.8b11-20-15-10 \
--save /root/data/checkpoints \
--save-interval 2000 \
--train-iters 300000 \
--resume-dataloader \
--train-data zhihu baike zhidao \
--lazy-loader \
--tokenizer-type ChineseSPTokenizer \
--split 949,50,1 \
--distributed-backend nccl \
--lr-decay-style cosine \
--lr-decay-ratio 0.1 \
--lr-decay-iters 300000 \
--warmup .01 \
--checkpoint-activations \
--deepspeed-activation-checkpointing \
--transformer-xl \
--fp16 \
"
gpt_options="${gpt_options}
--deepspeed \
--deepspeed_config ${config_json} \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_gpt2.py $@ ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}
set +x
#!/bin/bash
CHECKPOINT_PATH=$1
MPSIZE=1
NLAYERS=32
NHIDDEN=2560
NATT=32
MAXSEQLEN=1024
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
#SAMPLING ARGS
TEMP=0.9
#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
TOPK=0
TOPP=0
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
config_json="$script_dir/ds_config.json"
python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT generate_samples.py \
--model-parallel-size $MPSIZE \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--max-position-embeddings 1024 \
--tokenizer-type ChineseSPTokenizer \
--fp16 \
--cache-dir cache \
--out-seq-length $MAXSEQLEN \
--seq-length 512 \
--mem-length 256 \
--transformer-xl \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP
#!/bin/bash
CHECKPOINT_PATH=$1
MPSIZE=1
NLAYERS=24
NHIDDEN=1024
NATT=16
MAXSEQLEN=512
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
#SAMPLING ARGS
TEMP=0.9
#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p
TOPK=0
TOPP=0
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
config_json="$script_dir/ds_config.json"
python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT generate_samples.py \
--model-parallel-size $MPSIZE \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--max-position-embeddings 512 \
--tokenizer-type ChineseSPTokenizer \
--fp16 \
--cache-dir cache \
--out-seq-length $MAXSEQLEN \
--seq-length 512 \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP
\ No newline at end of file
"""
Usage:
python scripts/presplit_sentences_json.py <original loose json file> <output loose json file>
"""
import sys
import json
import nltk
nltk.download('punkt')
input_file = sys.argv[1]
output_file = sys.argv[2]
line_seperator = "\n"
with open(input_file, 'r') as ifile:
with open(output_file, "w") as ofile:
for doc in ifile.readlines():
parsed = json.loads(doc)
sent_list = []
for line in parsed['text'].split('\n'):
if line != '\n':
sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list)
ofile.write(json.dumps(parsed)+'\n')
"""
Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.
Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
for i, l in enumerate(f.readlines()):
l = l.strip()
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
#calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment