Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
H
homework2_dialog_project
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
20220418012
homework2_dialog_project
Commits
6b0ec419
Commit
6b0ec419
authored
2 years ago
by
20220418012
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
b1b0c4b3
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
204 additions
and
0 deletions
+204
-0
NLU/trainer.py
+204
-0
No files found.
NLU/trainer.py
0 → 100644
View file @
6b0ec419
from
optim
import
Adam
,
NoamOpt
import
torch
import
os
import
torch.nn
as
nn
import
torch.distributed
# import torch._tensor
from
dataset
import
PadBatchSeq
from
torch.utils.data
import
DataLoader
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
class Trainer:
    """Train and validate a joint intent-classification / slot-filling model.

    The trainer owns the model, the loss criterion, a ``NoamOpt``-wrapped
    ``Adam`` optimizer, the train/valid data loaders and two TensorBoard
    writers.  When ``distributed`` is true the process rank is read from
    ``torch.distributed`` and only rank 0 writes logs; otherwise the rank is
    ``-1`` and the single process logs everything.
    """

    def __init__(self, args, model, tokz, train_dataset, valid_dataset, log_dir, logger,
                 device=torch.device('cuda'), valid_writer=None, distributed=False):
        """Build optimizer, samplers, data loaders and summary writers.

        Args:
            args: Configuration object; this class reads ``lr``, ``lr_warmup``,
                ``bs``, ``n_jobs``, ``batch_split`` and ``eval_steps`` from it.
            model: Model called as ``model(input_ids=..., attention_mask=...,
                token_type_ids=...)`` and returning ``(intent_logits, slot_logits)``.
            tokz: Tokenizer; only ``pad_token_id`` is used.
            train_dataset: Dataset used for training.
            valid_dataset: Dataset used for validation.
            log_dir: Directory under which ``train`` / ``valid`` event files go.
            logger: Object with an ``info(str)`` method.
            device: Device the model and batches are moved to.
            valid_writer: Optional pre-built writer for validation scalars; a
                new ``SummaryWriter`` is created when it is ``None``.
            distributed: Whether ``torch.distributed`` is in use.
        """
        self.config = args
        self.device = device
        self.logger = logger
        self.log_dir = log_dir
        self.tokz = tokz
        # -1 marks the non-distributed (single process) case.
        self.rank = torch.distributed.get_rank() if distributed else -1

        self.train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
        if valid_writer is None:
            self.valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))
        else:
            self.valid_writer = valid_writer

        self.model = model.to(device, non_blocking=True)
        # reduction='none' keeps per-element losses so padding can be masked
        # out explicitly by the callers.
        # NOTE(review): ignore_index is the tokenizer pad id and the same
        # criterion is used for intents - confirm no real label shares that id.
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=tokz.pad_token_id, reduction='none').to(device)

        base_optimizer = Adam(self.model.parameters(), lr=self.config.lr, weight_decay=0.01)
        # The model config is either on the model itself or on ``model.module``
        # (presumably when the model is wrapped for parallel training).
        if hasattr(self.model, 'config'):
            model_config = self.model.config
        else:
            model_config = self.model.module.config
        self.optimizer = NoamOpt(
            model_config.hidden_size, 0.1, self.config.lr_warmup, base_optimizer)

        if distributed:
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset)
        else:
            self.train_sampler = torch.utils.data.RandomSampler(train_dataset)
            self.valid_sampler = None

        self.train_dataloader = DataLoader(
            train_dataset, sampler=self.train_sampler, batch_size=self.config.bs,
            num_workers=self.config.n_jobs, pin_memory=True,
            collate_fn=PadBatchSeq(self.tokz.pad_token_id))
        self.valid_dataloader = DataLoader(
            valid_dataset, sampler=self.valid_sampler, batch_size=self.config.bs,
            num_workers=self.config.n_jobs, pin_memory=True,
            collate_fn=PadBatchSeq(self.tokz.pad_token_id))

    def state_dict(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped model."""
        self.model.load_state_dict(state_dict)

    def _batch_to_device(self, data):
        """Move one collated batch to ``self.device``.

        Returns:
            Tuple ``(text, intent_labels, slot_labels, mask, token_type)`` taken
            from the batch keys ``utt``, ``intent``, ``slot``, ``mask`` and
            ``token_type``.
        """
        return tuple(
            data[key].to(self.device, non_blocking=True)
            for key in ('utt', 'intent', 'slot', 'mask', 'token_type')
        )

    def _eval_train(self, epoch):
        """Run one training epoch with gradient accumulation.

        Gradients are accumulated over ``config.batch_split`` batches before
        each optimizer update.  After every update the running metrics are
        averaged, logged (rank -1/0 only) and reset, and validation runs
        whenever the optimizer step counter is a multiple of
        ``config.eval_steps``.

        Args:
            epoch: Epoch number, used for the progress-bar label and logs.
        """
        self.model.train()

        intent_loss, slot_loss, intent_acc, slot_acc, step_count = 0, 0, 0, 0, 0
        total = len(self.train_dataloader)
        if self.rank in [-1, 0]:
            TQDM = tqdm(enumerate(self.train_dataloader),
                        desc='Train (epoch #{})'.format(epoch),
                        dynamic_ncols=True, total=total)
        else:
            TQDM = enumerate(self.train_dataloader)

        for i, data in TQDM:
            # 1. Move the batch to the training device.
            text, intent_labels, slot_labels, mask, token_type = self._batch_to_device(data)

            # 2. Forward pass.
            intent_logits, slot_logits = self.model(
                input_ids=text, attention_mask=mask, token_type_ids=token_type)

            # 3. Losses and accuracies.
            batch_intent_loss = self.criterion(intent_logits, intent_labels).mean()

            # 1.0 on real slot labels, 0.0 on padding.
            slot_mask = 1 - slot_labels.eq(self.tokz.pad_token_id).float()
            # Guard against a batch made only of padding labels.
            slot_token_count = slot_mask.sum().clamp(min=1)
            # Keep the per-token losses: averaging before masking would divide
            # by the padded length and make the mask a no-op.
            token_slot_loss = self.criterion(
                slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
            batch_slot_loss = (token_slot_loss * slot_mask.view(-1)).sum() / slot_token_count
            batch_loss = batch_intent_loss + batch_slot_loss

            batch_intent_acc = (torch.argmax(intent_logits, dim=-1) == intent_labels).float().mean()
            slot_hits = (torch.argmax(slot_logits, dim=-1) == slot_labels)
            batch_slot_acc = torch.sum(slot_hits * slot_mask) / slot_token_count

            # 4. Backward; scale so accumulated gradients average over the split.
            full_loss = batch_loss / self.config.batch_split
            full_loss.backward()

            intent_loss += batch_intent_loss.item()
            slot_loss += batch_slot_loss.item()
            intent_acc += batch_intent_acc.item()
            slot_acc += batch_slot_acc.item()
            step_count += 1

            # 5. Update parameters once per accumulation group.  The step
            # counter is read before the update, as in the original code.
            curr_step = self.optimizer.curr_step()
            if (i + 1) % self.config.batch_split == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                intent_loss /= step_count
                slot_loss /= step_count
                intent_acc /= step_count
                slot_acc /= step_count

                if self.rank in [-1, 0]:
                    self.train_writer.add_scalar('loss/intent_loss', intent_loss, curr_step)
                    self.train_writer.add_scalar('loss/slot_loss', slot_loss, curr_step)
                    self.train_writer.add_scalar('acc/intent_acc', intent_acc, curr_step)
                    self.train_writer.add_scalar('acc/slot_acc', slot_acc, curr_step)
                    TQDM.set_postfix({'intent_loss': intent_loss, 'intent_acc': intent_acc,
                                      'slot_loss': slot_loss, 'slot_acc': slot_acc})
                intent_loss, slot_loss, intent_acc, slot_acc, step_count = 0, 0, 0, 0, 0

                # 6. Periodic validation.
                if curr_step % self.config.eval_steps == 0:
                    self._eval_test(epoch=epoch, step=curr_step)

    def _eval_test(self, epoch, step):
        """Evaluate on the validation loader and log the averaged metrics.

        Per-sample losses and accuracies are summed over the loader, summed
        across processes when running distributed, and divided by the number
        of samples seen.  The model is put back in training mode afterwards.

        Args:
            epoch: Epoch number, used in the log line only.
            step: Global step used as the x-value for the logged scalars.
        """
        self.model.eval()
        with torch.no_grad():
            def _zero():
                return torch.tensor(0.0, dtype=torch.float32, device=self.device)

            dev_intent_loss = _zero()
            dev_slot_loss = _zero()
            dev_intent_acc = _zero()
            dev_slot_acc = _zero()
            count = _zero()

            for data in self.valid_dataloader:
                text, intent_labels, slot_labels, mask, token_type = self._batch_to_device(data)
                intent_logits, slot_logits = self.model(
                    input_ids=text, attention_mask=mask, token_type_ids=token_type)

                batch_intent_loss = self.criterion(intent_logits, intent_labels)

                slot_mask = 1 - slot_labels.eq(self.tokz.pad_token_id).float()
                # Per-sample count of real slot labels; clamped so an
                # all-padding sample yields 0 instead of NaN.
                tokens_per_sample = slot_mask.sum(dim=-1).clamp(min=1)
                token_slot_loss = self.criterion(
                    slot_logits.view(-1, slot_logits.shape[-1]), slot_labels.view(-1))
                batch_slot_loss = (
                    (token_slot_loss * slot_mask.view(-1)).view(text.shape[0], -1).sum(dim=-1)
                    / tokens_per_sample
                )

                dev_intent_loss += batch_intent_loss.sum()
                dev_slot_loss += batch_slot_loss.sum()

                batch_intent_acc = (torch.argmax(intent_logits, dim=-1) == intent_labels).sum()
                slot_hits = (torch.argmax(slot_logits, dim=-1) == slot_labels)
                batch_slot_acc = torch.sum(slot_hits * slot_mask, dim=-1) / tokens_per_sample

                dev_intent_acc += batch_intent_acc
                dev_slot_acc += batch_slot_acc.sum()
                count += text.shape[0]

            if self.rank != -1:
                # ReduceOp replaces the deprecated ``reduce_op`` alias.
                for tensor in (dev_intent_loss, dev_slot_loss,
                               dev_intent_acc, dev_slot_acc, count):
                    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)

            dev_intent_loss /= count
            dev_slot_loss /= count
            dev_intent_acc /= count
            dev_slot_acc /= count

            if self.rank in [-1, 0]:
                self.valid_writer.add_scalar('loss/intent_loss', dev_intent_loss, step)
                self.valid_writer.add_scalar('loss/slot_loss', dev_slot_loss, step)
                self.valid_writer.add_scalar('acc/intent_acc', dev_intent_acc, step)
                self.valid_writer.add_scalar('acc/slot_acc', dev_slot_acc, step)
                log_str = 'epoch {:>3}, step {}'.format(epoch, step)
                log_str += ', dev_intent_loss {:>4.4f}'.format(dev_intent_loss)
                log_str += ', dev_slot_loss {:>4.4f}'.format(dev_slot_loss)
                log_str += ', dev_intent_acc {:>4.4f}'.format(dev_intent_acc)
                log_str += ', dev_slot_acc {:>4.4f}'.format(dev_slot_acc)
                self.logger.info(log_str)
        self.model.train()

    def train(self, start_epoch, epochs, after_epoch_funcs=None, after_step_funcs=None):
        """Run training for epochs ``start_epoch + 1`` up to ``epochs - 1``.

        Args:
            start_epoch: Last completed epoch; training resumes at the next one.
            epochs: Exclusive upper bound of the epoch range.
            after_epoch_funcs: Optional iterable of callables invoked as
                ``func(epoch, device)`` after each epoch.
            after_step_funcs: Accepted for interface compatibility; this
                implementation does not invoke them.
        """
        # None instead of a shared mutable default list.
        epoch_hooks = after_epoch_funcs if after_epoch_funcs is not None else ()

        for epoch in range(start_epoch + 1, epochs):
            self.logger.info('Training on epoch {}'.format(epoch))
            # Only samplers exposing set_epoch (e.g. DistributedSampler) need it.
            if hasattr(self.train_sampler, 'set_epoch'):
                self.train_sampler.set_epoch(epoch)
            self._eval_train(epoch)
            for func in epoch_hooks:
                func(epoch, self.device)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment