Commit 50c86fd9 by 20220418012

Upload New File

parent 77d2abdb
import argparse
import json
import logging
import os
import random
from collections import Counter

import torch
from torch.utils.checkpoint import checkpoint

from attrdict import AttrDict
def get_logger(filename, print2screen=True):
    """Create a logger that writes INFO-level records to *filename*.

    The logger is registered under the file path, so repeated calls with
    the same path return the same logger object.

    Args:
        filename: Log file path; doubles as the logger's registry name.
        print2screen: If True, also mirror records to the console via a
            StreamHandler.

    Returns:
        A configured ``logging.Logger``.
    """
    logger = logging.getLogger(filename)
    logger.setLevel(logging.INFO)
    # Fix: logging.getLogger caches loggers by name, so the original code
    # attached a fresh FileHandler/StreamHandler on every call, emitting
    # each record multiple times. Attach handlers only once.
    if logger.handlers:
        return logger
    formatter = logging.Formatter(
        '[%(asctime)s][%(thread)d][%(filename)s][line: %(lineno)d][%(levelname)s] >> %(message)s')
    fh = logging.FileHandler(filename)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    if print2screen:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    return logger
def load_vocab(vocab_file):
    """Read a vocabulary file with one token per line.

    Blank lines are skipped and every token is lowercased.

    Returns:
        A 3-tuple ``(tokens, token2index, index2token)``.
    """
    tokens = []
    with open(vocab_file) as fin:
        for line in fin:
            stripped = line.strip()
            if stripped:
                tokens.append(stripped.lower())
    token2index = {tok: idx for idx, tok in enumerate(tokens)}
    index2token = {idx: tok for idx, tok in enumerate(tokens)}
    return tokens, token2index, index2token
def str2bool(v):
    """Parse a human-friendly boolean string (argparse ``type=`` helper).

    Raises:
        argparse.ArgumentTypeError: if *v* is not a recognised boolean word.
    """
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def load_config(config_file):
    """Load a JSON config file and wrap it for attribute-style access."""
    with open(config_file) as fin:
        cfg = json.load(fin)
    return AttrDict(cfg)
def set_seed(seed):
    """Seed Python's random module and torch (CPU and CUDA) for reproducibility."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
def pad_sequence(sequences, batch_first=False, padding_value=0):
    """Pad a list of variable-length tensors into one stacked tensor.

    All tensors are assumed to share dtype and trailing dimensions; both
    are taken from ``sequences[0]``.

    Args:
        sequences: non-empty list of tensors shaped ``(len_i, *trailing)``.
        batch_first: output ``(batch, max_len, *trailing)`` when True,
            else ``(max_len, batch, *trailing)``.
        padding_value: fill value for positions past each sequence's end.

    Returns:
        A single padded tensor containing every input sequence.
    """
    longest = max(seq.size(0) for seq in sequences)
    trailing_dims = sequences[0].size()[1:]
    if batch_first:
        out_shape = (len(sequences), longest) + trailing_dims
    else:
        out_shape = (longest, len(sequences)) + trailing_dims
    padded = sequences[0].data.new(*out_shape).fill_(padding_value)
    # Copy each sequence into its slot; index notation avoids keeping
    # extra references to the source tensors.
    for idx, seq in enumerate(sequences):
        n = seq.size(0)
        if batch_first:
            padded[idx, :n, ...] = seq
        else:
            padded[:n, idx, ...] = seq
    return padded
def checkpoint_sequential(functions, segments, *inputs):
    """Run *functions* in order, gradient-checkpointing all but the last segment.

    Args:
        functions: an ``nn.Sequential`` or a list of callables applied in order.
        segments: number of chunks to split the pipeline into; the first
            ``segments - 1`` chunks run under ``torch.utils.checkpoint`` so
            their activations are recomputed in backward instead of stored.
        *inputs: tensors fed to the first function.

    Returns:
        Whatever the final function returns.
    """
    def run_function(start, end, functions):
        # Build a closure that applies functions[start..end] inclusive.
        def forward(*inputs):
            for j in range(start, end + 1):
                # Fix: the original always did ``functions[j](*inputs)``.
                # After the first step ``inputs`` is usually a single
                # Tensor, and ``*tensor`` unpacks its first dimension,
                # breaking any segment longer than one function. Only
                # unpack genuine tuples.
                if isinstance(inputs, tuple):
                    inputs = functions[j](*inputs)
                else:
                    inputs = functions[j](inputs)
            return inputs
        return forward
    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())
    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile (run outside checkpoint so the
    # output participates in autograd normally)
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
def get_latest_ckpt(dir_name):
    """Return the '.ckpt' filename in *dir_name* with the highest trailing
    number (e.g. 'model-12.ckpt'), or None if no '.ckpt' files exist.

    Files whose trailing segment is not an integer are ignored; if every
    candidate is unparseable, an empty string is returned.
    """
    candidates = [f for f in os.listdir(dir_name) if '.ckpt' in f]
    if not candidates:
        return None
    best_name = ''
    best_num = -1
    for fname in candidates:
        stem = fname.split('-')[-1].split('.')[0]
        try:
            value = int(stem)
        except ValueError:
            continue
        if value > best_num:
            best_num = value
            best_name = fname
    return best_name
def get_epoch_from_ckpt(ckpt):
    """Extract the integer epoch from a name like 'model-7.ckpt'."""
    trailing = ckpt.rpartition('-')[2]
    return int(trailing.partition('.')[0])
def get_ckpt_filename(name, epoch):
    """Build the per-epoch checkpoint filename, e.g. 'model-3.ckpt'."""
    return f'{name}-{epoch}.ckpt'
def get_ckpt_step_filename(name, step):
    """Build the per-step checkpoint filename, e.g. 'model-500-step.ckpt'."""
    return f'{name}-{step}-step.ckpt'
def f1_score(predictions, targets, average=True):
    """Token-overlap F1 between paired prediction/target token sequences.

    Args:
        predictions: iterable of token sequences, one per example.
        targets: iterable of gold token sequences, aligned with predictions.
        average: if True, return the mean F1 over all examples; otherwise
            return the per-example list of scores.

    Returns:
        A float (mean F1) when ``average`` is True, else a list of floats.
    """
    def f1_score_items(pred_items, gold_items):
        # Multiset intersection counts shared tokens with multiplicity.
        # (Requires ``from collections import Counter``, which the original
        # module was missing.)
        common = Counter(gold_items) & Counter(pred_items)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = num_same / len(pred_items)
        recall = num_same / len(gold_items)
        return (2 * precision * recall) / (precision + recall)

    scores = [f1_score_items(p, t) for p, t in zip(predictions, targets)]
    if average:
        # Guard empty input: the original raised ZeroDivisionError here.
        return sum(scores) / len(scores) if scores else 0.0
    return scores
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment