20220418012 / homework2_dialog_project · Commits

Commit b1b0c4b3, authored 2 years ago by 20220418012 (parent f0d129ab)

Showing 1 changed file with 149 additions and 0 deletions:

NLU/optim.py  (new file, mode 100644)  +149 / -0
# transformer_chatbot
# Copyright (C) 2018 Golovanov, Tselousov
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import math

import torch


class Adam(torch.optim.Optimizer):
    """Implements Adam algorithm.

    This implementation is modified from torch.optim.Adam based on:
    `Fixed Weight Decay Regularization in Adam`
    (see https://arxiv.org/abs/1711.05101)

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, '
                                       'please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p.data.add_(p.data, alpha=-group['weight_decay'] * group['lr'])

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
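
# Usage sketch for the Adam variant above (illustrative only; the linear model
# and random tensors are hypothetical stand-ins):
#
#     model = torch.nn.Linear(16, 4)
#     optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=0.01)
#     loss = model(torch.randn(8, 16)).pow(2).mean()
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#
# Unlike torch.optim.Adam, the weight-decay term is applied here directly to
# p.data and scaled by the learning rate (decoupled weight decay in the spirit
# of https://arxiv.org/abs/1711.05101) rather than being folded into the gradient.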


class NoamOpt:
    def __init__(self, embeddings_size, factor, warmup, optimizer):
        self.embeddings_size = embeddings_size
        self.factor = factor
        self.warmup = warmup
        self.optimizer = optimizer
        self._step = 1

    def state_dict(self):
        return {'step': self._step,
                'optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        self._step = state_dict['step']
        self.optimizer.load_state_dict(state_dict['optimizer'])

    def zero_grad(self):
        return self.optimizer.zero_grad()

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def step(self):
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()

    def curr_step(self):
        return self._step

    def rate(self, step=None):
        if step is None:
            step = self._step
        return self.factor * (self.embeddings_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))
\ No newline at end of file
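
The two classes above are meant to be composed: NoamOpt wraps a base optimizer and, on every call to step(), overwrites each parameter group's learning rate with the Noam schedule factor * embeddings_size**-0.5 * min(step**-0.5, step * warmup**-1.5) before delegating to the wrapped optimizer. Below is a minimal sketch of that wiring; it assumes NLU/ is importable as a package, and the linear model, dummy data, and hyperparameters (factor=1.0, warmup=4000) are illustrative placeholders rather than the project's actual settings.

import torch

from NLU.optim import Adam, NoamOpt  # assumes NLU/ is importable as a package

d_model = 256                                      # illustrative embeddings size
model = torch.nn.Linear(d_model, d_model)          # stand-in for the real transformer
base_optimizer = Adam(model.parameters(), lr=0.0)  # lr is overwritten by NoamOpt on each step
opt = NoamOpt(embeddings_size=d_model, factor=1.0, warmup=4000, optimizer=base_optimizer)

for _ in range(3):                                 # dummy training loop
    loss = model(torch.randn(8, d_model)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()                                     # applies the scheduled lr, then calls Adam.step()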