# convert_org.py
import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert


def _tf_checkpoint_prefix(path):
    """Reduce a TensorFlow checkpoint file path to the checkpoint prefix.

    A checkpoint named ``X`` is stored on disk as ``X.index`` plus one or more
    ``X.data-NNNNN-of-NNNNN`` shards (and optionally ``X.meta``). TensorFlow's
    checkpoint readers take the bare prefix ``X``; given a shard file they try
    to open it as a checkpoint itself and fail. Paths that already are a
    prefix (e.g. ``bert_model.ckpt`` or ``model.ckpt-1000``) come back
    unchanged.
    """
    import os
    import re

    path = os.fspath(path)
    if isinstance(path, bytes):
        # Leave bytes paths untouched; the suffix pattern below is a str regex.
        return path
    return re.sub(r"\.(?:index|meta|data-\d+-of-\d+)$", "", path)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a PyTorch ``state_dict`` file.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint prefix (e.g.
            ``.../bert_model.ckpt``). A path to one of that checkpoint's
            ``.index``, ``.meta`` or ``.data-NNNNN-of-NNNNN`` files is also
            accepted and is reduced to the prefix before loading.
        bert_config_file: JSON file read by ``BertConfig.from_json_file`` to
            build the model architecture.
        pytorch_dump_path: Destination file written with
            ``torch.save(model.state_dict(), ...)``.

    Returns:
        The ``BertForPreTraining`` model with the TensorFlow weights loaded.
    """
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForPreTraining(config)

    # Load weights from the tf checkpoint. The loader needs the checkpoint
    # prefix, not an individual shard file, so normalise the path first.
    load_tf_weights_in_bert(model, config, _tf_checkpoint_prefix(tf_checkpoint_path))

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
    return model


if __name__ == "__main__":
    import argparse

    # Directory holding the Chinese RoBERTa (L-4, H-312, A-12) release files;
    # the defaults below reproduce the original hard-coded conversion.
    model_dir = "./data/chinese_roberta_L-4_H-312_A-12"

    parser = argparse.ArgumentParser(
        description="Convert a TensorFlow BERT checkpoint to a PyTorch state_dict."
    )
    parser.add_argument(
        "--tf_checkpoint_path",
        # The checkpoint *prefix*: TensorFlow resolves the .index and
        # .data-* files from it; pointing at the shard file itself fails.
        default=f"{model_dir}/bert_model.ckpt",
        help="TensorFlow checkpoint prefix (path without the .index/.data-* suffix).",
    )
    parser.add_argument(
        "--bert_config_file",
        default=f"{model_dir}/bert_config.json",
        help="JSON config file describing the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path",
        default=f"{model_dir}/roberta_zh_L-4_H-312_A-12.bin",
        help="Output path for the PyTorch state_dict.",
    )
    args = parser.parse_args()

    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path
    )