Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
H
homework2_dialog_project
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
20220418012
homework2_dialog_project
Commits
27cd3f95
Commit
27cd3f95
authored
2 years ago
by
20220418012
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
5bbdec70
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
0 deletions
+37
-0
NLG/model/loss.py
+37
-0
No files found.
NLG/model/loss.py
0 → 100644
View file @
27cd3f95
import
torch
import
torch.nn
as
nn
class
LabelSmoothingLoss
(
nn
.
Module
):
def
__init__
(
self
,
n_labels
,
smoothing
=
0.0
,
ignore_index
=-
100
):
super
(
LabelSmoothingLoss
,
self
)
.
__init__
()
assert
0
<=
smoothing
<=
1
self
.
ignore_index
=
ignore_index
self
.
confidence
=
1
-
smoothing
if
smoothing
>
0
:
self
.
criterion
=
nn
.
KLDivLoss
(
reduction
=
'batchmean'
)
n_ignore_idxs
=
1
+
(
ignore_index
>=
0
)
# 1 for golden truth, later one for ignore_index
one_hot
=
torch
.
full
((
1
,
n_labels
),
fill_value
=
(
smoothing
/
(
n_labels
-
n_ignore_idxs
)))
if
ignore_index
>=
0
:
one_hot
[
0
,
ignore_index
]
=
0
self
.
register_buffer
(
'one_hot'
,
one_hot
)
else
:
self
.
criterion
=
nn
.
NLLLoss
(
reduction
=
'mean'
,
ignore_index
=
ignore_index
)
def
forward
(
self
,
log_inputs
,
targets
):
if
self
.
confidence
<
1
:
tdata
=
targets
.
data
tmp
=
self
.
one_hot
.
repeat
(
targets
.
shape
[
0
],
1
)
tmp
.
scatter_
(
1
,
tdata
.
unsqueeze
(
1
),
self
.
confidence
)
if
self
.
ignore_index
>=
0
:
mask
=
torch
.
nonzero
(
tdata
.
eq
(
self
.
ignore_index
))
.
squeeze
(
-
1
)
if
mask
.
numel
()
>
0
:
tmp
.
index_fill_
(
0
,
mask
,
0
)
targets
=
tmp
return
self
.
criterion
(
log_inputs
,
targets
)
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment