20220418012 / homework2_dialog_project · Commits

Commit 020bcac2 authored Jul 15, 2022 by 20220418012 (parent 0748d570)

Showing 1 changed file with 212 additions and 0 deletions

NLG/model/transformer_module.py  0 → 100644  (+212 −0)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import checkpoint_sequential

class MultiheadAttention(nn.Module):
    @classmethod
    def _get_future_mask(cls, size, device):
        if not hasattr(cls, '_future_mask') or cls._future_mask.device != device or cls._future_mask.shape < size:
            cls._future_mask = torch.triu(torch.ones(size[0], size[1], dtype=torch.uint8, device=device), 1)
        mask = cls._future_mask[:size[0], :size[1]]
        return mask
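    # Example: for size == (3, 3) the returned mask is
    #   [[0, 1, 1],
    #    [0, 0, 1],
    #    [0, 0, 0]]
    # i.e. entry (i, j) is 1 exactly when j > i, marking future positions that must not be attended to.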
    def __init__(self, n_features, n_heads, dropout):
        super(MultiheadAttention, self).__init__()
        assert n_features % n_heads == 0

        self.n_features = n_features
        self.n_heads = n_heads
        self.qkv_proj = nn.Linear(n_features, 3 * n_features)
        self.out_proj = nn.Linear(n_features, n_features)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.qkv_proj.weight, std=0.02)
        nn.init.normal_(self.out_proj.weight, std=0.02)
    def _split_heads(self, x, is_key=False):
        x = x.view(x.shape[0], x.shape[1], self.n_heads, self.n_features // self.n_heads)
        x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3)

        return x
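    # Shape example: with n_features = 8 and n_heads = 2, an input of shape [bs, seq_len, 8]
    # is reshaped to [bs, seq_len, 2, 4] and then permuted to [bs, 2, seq_len, 4] for
    # queries/values, or to [bs, 2, 4, seq_len] for keys (keys are pre-transposed so that
    # torch.matmul(q, k) directly yields the attention scores).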
    def _attn(self, q, k, v, apply_future_mask=True, padding_mask=None):
        #######################################################
        # TODO: Complete the following function.
        # The following code should perform the self-attention operation.
        # q: queries, shape: [bs, n_heads, seq_len, head_dim] (head_dim = n_features // n_heads)
        # k: keys, shape: [bs, n_heads, head_dim, seq_len] (already transposed by `_split_heads`)
        # v: values, shape: [bs, n_heads, seq_len, head_dim]
        # apply_future_mask: True means uni-directional attention should be used. The future
        #   mask can be obtained from the function `_get_future_mask`.
        # padding_mask: either 1 or 0; 1 means the corresponding token is the pad_token,
        #   0 otherwise. Pad tokens should not be attended to. shape: [bs, seq_len]
        #
        # This function should return the attention output with shape [bs, n_heads, seq_len, head_dim].
        #
        # The function can be implemented with the following steps:
        # 1. Compute the attention weights by multiplying q and k and scaling by sqrt(head_dim)
        # 2. Apply the future mask (if apply_future_mask is True) and the padding mask to the weights
        # 3. Normalize the weights with softmax and apply dropout
        # 4. Multiply the weights by v to obtain the attention output
        #######################################################
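        # The implementation below follows standard scaled dot-product attention:
        #   Attention(Q, K, V) = softmax(Q K^T / sqrt(head_dim)) V
        # with masked (future or padding) positions set to -inf before the softmax so they
        # receive zero attention weight.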
        w = torch.matmul(q, k) / math.sqrt(self.n_features // self.n_heads)

        if apply_future_mask:
            future_mask = MultiheadAttention._get_future_mask(w.shape[-2:], w.device).unsqueeze(0).unsqueeze(0).bool()
            w.masked_fill_(future_mask, float('-inf'))

        if padding_mask is not None:
            w.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        w = F.softmax(w, dim=-1)
        w = self.dropout(w)

        if padding_mask is not None:
            # Zero out rows that attend only to padding: softmax over a row of all -inf yields NaN.
            w.masked_fill_(padding_mask.all(dim=-1).unsqueeze(1).unsqueeze(2).unsqueeze(3), 0)

        out = torch.matmul(w, v)

        return out
    def _merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(x.shape[0], x.shape[1], self.n_features)

        return x
    def forward(self, query, key, value, padding_mask):
        qkv_same = (query.data_ptr() == key.data_ptr() == value.data_ptr())
        kv_same = (key.data_ptr() == value.data_ptr())

        if qkv_same:
            query, key, value = self.qkv_proj(query).split(self.n_features, dim=-1)
            apply_future_mask = True  # self-attention
        elif kv_same:
            q_w, q_b = self.qkv_proj.weight[:self.n_features, :], self.qkv_proj.bias[:self.n_features]
            query = F.linear(query, q_w, q_b)

            kv_w, kv_b = self.qkv_proj.weight[self.n_features:, :], self.qkv_proj.bias[self.n_features:]
            key, value = F.linear(key, kv_w, kv_b).split(self.n_features, dim=-1)

            apply_future_mask = False
        else:
            assert False

        query = self._split_heads(query)
        key = self._split_heads(key, is_key=True)
        value = self._split_heads(value)

        x = self._attn(query, key, value, apply_future_mask, padding_mask)
        x = self._merge_heads(x)

        x = self.out_proj(x)

        return x
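    # Usage sketch (hypothetical values, not part of this file): with n_features=512 and n_heads=8,
    #   attn = MultiheadAttention(512, 8, dropout=0.1)
    #   out = attn(x, x, x, padding_mask)          # self-attention: qkv_same branch, future mask applied
    #   out = attn(x, memory, memory, mem_mask)    # encoder-decoder attention: kv_same branch, no future mask
    # Passing the very same tensor for query/key/value triggers the qkv_same branch above;
    # passing the same tensor only for key and value triggers the kv_same branch.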
class FeedForward(nn.Module):
    @staticmethod
    def gelu(x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    def __init__(self, in_features, middle_features, dropout):
        super(FeedForward, self).__init__()

        self.layer_1 = nn.Linear(in_features, middle_features)
        self.layer_2 = nn.Linear(middle_features, in_features)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.layer_1.weight, std=0.02)
        nn.init.normal_(self.layer_2.weight, std=0.02)
    def forward(self, x):
        x = FeedForward.gelu(self.layer_1(x))
        x = self.dropout(x)
        x = self.layer_2(x)

        return x
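    # This is the standard position-wise feed-forward network,
    #   FFN(x) = layer_2(dropout(gelu(layer_1(x)))),
    # where gelu is the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # of x * Phi(x); TransformerBlock below sets middle_features = 4 * n_features.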
class TransformerBlock(nn.Module):
    def __init__(self, n_features, n_heads, dropout, attn_dropout, ff_dropout):
        super(TransformerBlock, self).__init__()

        self.attn = MultiheadAttention(n_features, n_heads, attn_dropout)
        self.attn_norm = nn.LayerNorm(n_features)
        self.ff = FeedForward(n_features, 4 * n_features, ff_dropout)
        self.ff_norm = nn.LayerNorm(n_features)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, padding_mask, *contexts):
        '''contexts = [(context1, padding_mask1), ...]'''

        inputs = (x, padding_mask) + contexts

        full_attn = 0
        n_attn = len(inputs) // 2
        for i in range(0, len(inputs), 2):
            c, m = inputs[i], inputs[i + 1].bool()
            a = self.attn(x, c, c, m)
            full_attn += (a / n_attn)

        full_attn = self.dropout(full_attn)
        x = self.attn_norm(x + full_attn)

        f = self.ff(x)
        f = self.dropout(f)
        x = self.ff_norm(x + f)

        return (x, padding_mask) + contexts
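    # Note: the block attends over the sequence itself (x, padding_mask) and over every
    # encoder context passed in *contexts, averaging the attention outputs before the
    # residual connection; both sub-layers use post-layer-norm, i.e. LayerNorm(x + sublayer(x)).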
class TransformerModule(nn.Module):
    def __init__(self, n_layers, n_embeddings, n_pos_embeddings, embeddings_size,
                 padding_idx, n_heads, dropout, embed_dropout, attn_dropout, ff_dropout,
                 n_segments=None):
        super(TransformerModule, self).__init__()

        self.embeddings = nn.Embedding(n_embeddings, embeddings_size, padding_idx=padding_idx)
        self.pos_embeddings = nn.Embedding(n_pos_embeddings + 1, embeddings_size, padding_idx=0)
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.layers = nn.ModuleList([TransformerBlock(embeddings_size, n_heads, dropout, attn_dropout, ff_dropout) for _ in range(n_layers)])
        self.n_segments = n_segments

        self._init_weights()
    def _init_weights(self):
        nn.init.normal_(self.embeddings.weight, std=0.02)
        nn.init.normal_(self.pos_embeddings.weight, std=0.02)
    def forward(self, x, enc_contexts=[]):
        '''contexts = [(context1, padding_mask1), ...]'''

        padding_mask = x.eq(self.embeddings.padding_idx).bool()

        positions = torch.cumsum(~padding_mask, dim=-1, dtype=torch.long)
        positions.masked_fill_(padding_mask, self.pos_embeddings.padding_idx)

        x = self.embeddings(x) * math.sqrt(self.embeddings.embedding_dim) + self.pos_embeddings(positions)
        x = self.embed_dropout(x)

        enc_contexts = sum(enc_contexts, ())

        if self.n_segments is not None:
            # The following part is used to trade compute for memory.
            padding_mask = padding_mask.float()  # required by checkpoint_sequential
            padding_mask.requires_grad_()  # required by checkpoint_sequential
            out = checkpoint_sequential(self.layers, self.n_segments, x, padding_mask, *enc_contexts)
            x = out[0]
        else:
            for layer in self.layers:
                out = layer(x, padding_mask, *enc_contexts)
                x = out[0]

        return x, padding_mask
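
# Usage sketch (hypothetical hyperparameters, for illustration only):
#   module = TransformerModule(n_layers=6, n_embeddings=50000, n_pos_embeddings=512,
#                              embeddings_size=768, padding_idx=0, n_heads=12,
#                              dropout=0.1, embed_dropout=0.1, attn_dropout=0.1,
#                              ff_dropout=0.1)
#   hidden, padding_mask = module(token_ids)   # token_ids: LongTensor of shape [bs, seq_len]
# `hidden` has shape [bs, seq_len, embeddings_size]; `padding_mask` marks pad positions and
# can be paired with `hidden` as an (context, padding_mask) entry of enc_contexts for a decoder.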