import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a PyTorch state-dict file.

    Args:
        tf_checkpoint_path: Path *prefix* of the TF checkpoint
            (e.g. ".../bert_model.ckpt" — NOT an individual ".data-*" shard;
            TF checkpoint readers resolve the ".index"/".data" files from
            the prefix).
        bert_config_file: Path to the ``bert_config.json`` describing the
            model architecture.
        pytorch_dump_path: Destination file for the converted PyTorch
            weights.
    """
    # Initialise an (untrained) PyTorch model from the JSON config.
    config = BertConfig.from_json_file(bert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = BertForPreTraining(config)

    # Copy the TF checkpoint variables into the PyTorch model in place.
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Persist only the weights (state dict), not the full pickled module.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    # FIX: the original pointed at the ".data-00000-of-00001" shard file.
    # TF V2 checkpoints must be addressed by their prefix ("bert_model.ckpt")
    # so the loader can find the companion ".index" file; passing the shard
    # makes tf.train.list_variables / load_variable fail.
    tf_checkpoint_path = "data/chinese_roberta_L-4_H-312_A-12/bert_model.ckpt"
    bert_config_file = "./data/chinese_roberta_L-4_H-312_A-12/bert_config.json"
    pytorch_dump_path = "./data/chinese_roberta_L-4_H-312_A-12/roberta_zh_L-4_H-312_A-12.bin"
    convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path)