Commit 91204af0 (parent 5059efba), authored 2 years ago by 20220418012
New file: NLG/model/model_multi_input.py (0 → 100644), +230 -0
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer_module import TransformerModule


class MultiInputModel(nn.Module):
    def __init__(self, config, vocab, n_segments=None):
        super(MultiInputModel, self).__init__()

        self.config = config
        self.vocab = vocab

        self.transformer_module = TransformerModule(config.n_layers, len(vocab), config.n_pos_embeddings,
                                                    config.embeddings_size, vocab.pad_id, config.n_heads,
                                                    config.dropout, config.embed_dropout, config.attn_dropout,
                                                    config.ff_dropout, n_segments)
        self.pre_softmax = nn.Linear(config.embeddings_size, len(vocab), bias=False)
        self.pre_softmax.weight = self.transformer_module.embeddings.weight
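    # Note: pre_softmax.weight is tied to the token embedding matrix, so one
    # table serves both input lookup and output projection. As assumptions
    # inferred from the code (not documented in the commit): `config` must
    # expose the hyperparameters read above, and `vocab` must provide
    # pad_id, bos_id, eos_id and __len__.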
    def forward(self, x, contexts=[]):
        # Encode every context sequence, then decode x against all of them.
        enc_contexts = [self.encode(c) for c in contexts]
        return self.decode(x, enc_contexts)

    def encode(self, x):
        return self.transformer_module(x)

    def generate(self, enc_x):
        # Project hidden states to vocabulary logits via the tied embeddings.
        return self.pre_softmax(enc_x)

    def decode(self, x, enc_contexts=[]):
        x, _ = self.transformer_module(x, enc_contexts)
        return self.generate(x)
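    # A minimal training-step sketch with hypothetical tensors, assuming `x`,
    # `targets`, and each context are LongTensors of ids, (batch, seq_len):
    #
    #   logits = model(x, contexts=[persona_ids, history_ids])
    #   loss = F.cross_entropy(logits.view(-1, len(model.vocab)),
    #                          targets.view(-1),
    #                          ignore_index=model.vocab.pad_id)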
    def predict(self, contexts=[]):
        enc_contexts = [self.encode(c) for c in contexts]
        prediction = self.beam_search(enc_contexts)
        return prediction

    def predict_beam(self, contexts=[]):
        enc_contexts = [self.encode(c) for c in contexts]
        prediction = self.beam_search(enc_contexts, return_beams=True)
        return prediction
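    # predict returns one decoded token-id sequence per batch element, while
    # predict_beam keeps all config.beam_size hypotheses per element (see
    # beam_search below for the exact return shapes).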
    def _length_penalty(self, sequence_lengths):
        """https://arxiv.org/abs/1609.08144"""
        return (5 + sequence_lengths) ** self.config.length_penalty \
            / (5 + 1) ** self.config.length_penalty
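    # This is the GNMT length normalization lp(Y) = (5 + |Y|)^a / (5 + 1)^a
    # with a = config.length_penalty. For example, assuming a = 0.6, a 9-token
    # hypothesis has its summed log-probability divided by
    # (5 + 9)**0.6 / 6**0.6 ≈ 1.66, letting longer beams compete with short ones.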
    def predict_next(self, enc_contexts=[], return_beams=False, prefix=[]):
        with torch.no_grad():
            if len(enc_contexts) == 0:
                return []

            batch_size = enc_contexts[0][0].shape[0]
            device = next(self.parameters()).device

            ind = len(prefix)
            if ind:
                # A non-empty prefix only makes sense for a single example.
                assert batch_size == 1
                prefix_sentence = [self.vocab.bos_id] + prefix
                prevs = torch.LongTensor(prefix_sentence).to(device)
                prevs = prevs.expand(self.config.beam_size, ind + 1)
            else:
                prevs = torch.full((batch_size * self.config.beam_size, 1),
                                   fill_value=self.vocab.bos_id,
                                   dtype=torch.long, device=device)

            # Tile each encoded context beam_size times along the batch axis.
            beam_enc_contexts = []
            for c, p in enc_contexts:
                c = c.unsqueeze(1).repeat(1, self.config.beam_size, 1, 1)
                c = c.view(-1, c.shape[2], c.shape[3])
                # Was `self.beam_size` in the original, which this class never
                # defines and would raise AttributeError.
                p = p.unsqueeze(1).repeat(1, self.config.beam_size, 1)
                p = p.view(-1, p.shape[2])
                beam_enc_contexts.append((c, p))

            outputs, _ = self.transformer_module(prevs, beam_enc_contexts)
            logits = self.generate(outputs[:, -1, :])
            probs = F.softmax(logits, dim=-1)
            return probs[0].tolist()
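    # Interactive sketch (hypothetical ids): given one encoded context and a
    # partial reply, predict_next returns the next-token distribution of the
    # first beam as a plain Python list.
    #
    #   enc = [model.encode(context_ids)]        # context_ids: (1, seq_len)
    #   probs = model.predict_next(enc, prefix=[tok_a, tok_b])
    #   next_id = max(range(len(probs)), key=probs.__getitem__)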
    def beam_search(self, enc_contexts=[], return_beams=False):
        with torch.no_grad():
            if len(enc_contexts) == 0:
                return []

            batch_size = enc_contexts[0][0].shape[0]
            device = next(self.parameters()).device

            prevs = torch.full((batch_size * self.config.beam_size, 1),
                               fill_value=self.vocab.bos_id,
                               dtype=torch.long, device=device)

            beam_scores = torch.zeros(batch_size, self.config.beam_size, device=device)
            beam_lens = torch.ones(batch_size, self.config.beam_size, dtype=torch.long, device=device)
            # Boolean throughout (the original used uint8 and converted with
            # .bool() after the first gather; the behavior is the same).
            is_end = torch.zeros(batch_size, self.config.beam_size, dtype=torch.bool, device=device)

            # Tile each encoded context beam_size times along the batch axis.
            beam_enc_contexts = []
            for c, p in enc_contexts:
                c = c.unsqueeze(1).repeat(1, self.config.beam_size, 1, 1)
                c = c.view(-1, c.shape[2], c.shape[3])
                p = p.unsqueeze(1).repeat(1, self.config.beam_size, 1)
                p = p.view(-1, p.shape[2])
                beam_enc_contexts.append((c, p))

            current_sample_prob = 1
            group_size = self.config.beam_size // self.config.diversity_groups
            diversity_penalty = torch.zeros((batch_size, len(self.vocab)), device=device)
            # Per hypothesis, map a (token[-3], token[-2]) bigram to the tokens
            # that have already followed it, for repetition blocking.
            repeat = [{} for i in range(batch_size * self.config.beam_size)]

            for i in range(self.config.max_seq_len):
                outputs, _ = self.transformer_module(prevs, beam_enc_contexts)

                logits = self.generate(outputs[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)

                # Repetition blocking: suppress every token recorded for this
                # hypothesis so far. Note this bans all recorded continuations,
                # not only those of the current trailing bigram.
                for idx in range(batch_size * self.config.beam_size):
                    for key in repeat[idx]:
                        for value in repeat[idx][key]:
                            log_probs[idx][value] = -1000

                log_probs = log_probs.view(batch_size, self.config.beam_size, -1)

                # Finished beams contribute nothing to the running score.
                beam_scores = beam_scores.unsqueeze(-1) + log_probs * (1 - is_end.float().unsqueeze(-1))

                # If a beam's scores are constant across the vocabulary (a
                # finished beam), keep a single entry and mask the rest so
                # topk cannot select duplicates of it.
                ba, be, dim = beam_scores.shape
                for ba_idx in range(ba):
                    for be_idx in range(be):
                        if int(torch.max(beam_scores[ba_idx][be_idx]) == torch.min(beam_scores[ba_idx][be_idx])):
                            temp = float(beam_scores[ba_idx][be_idx][0])
                            beam_scores[ba_idx][be_idx] = -float('inf')
                            beam_scores[ba_idx][be_idx][0] = temp

                penalty = self._length_penalty(beam_lens.float() + 1 - is_end.float())
                penalty = penalty.unsqueeze(-1).repeat(1, 1, len(self.vocab))
                beam_scores = beam_scores / penalty

                if i == 0:
                    # First step: all beams are identical copies of <bos>, so
                    # expand from the first beam only.
                    penalty = penalty[:, 0, :]
                    beam_scores = beam_scores[:, 0, :]

                    beam_scores, idxs = beam_scores.topk(self.config.beam_size, dim=-1)
                    beam_idxs = torch.zeros((batch_size, self.config.beam_size), dtype=torch.long, device=device)
                else:
                    # Diverse beam search: split beams into groups and penalize
                    # tokens already chosen by earlier groups in this step.
                    penalty = penalty.view(batch_size, self.config.diversity_groups, group_size, -1)
                    beam_scores = beam_scores.view(batch_size, self.config.diversity_groups, group_size, -1)

                    all_scores, all_idxs = [], []
                    for g in range(self.config.diversity_groups):
                        g_beam_scores = beam_scores[:, g, :, :]
                        g_penalty = penalty[:, g, :, :]
                        g_beam_scores -= self.config.diversity_coef * diversity_penalty.unsqueeze(1) / g_penalty
                        g_beam_scores = g_beam_scores.view(batch_size, -1)

                        if random.random() < current_sample_prob:
                            # Annealed sampling instead of greedy topk.
                            beam_probas = F.softmax(g_beam_scores / self.config.temperature, dim=-1)
                            if self.config.annealing_topk is not None:
                                beam_probas, sample_idxs = beam_probas.topk(self.config.annealing_topk, dim=-1)
                                g_idxs = torch.multinomial(beam_probas, group_size)
                                g_idxs = torch.gather(sample_idxs, 1, g_idxs)
                            else:
                                g_idxs = torch.multinomial(beam_probas, group_size)
                        else:
                            _, g_idxs = g_beam_scores.topk(group_size, dim=-1)

                        g_scores = torch.gather(beam_scores[:, g, :, :].view(batch_size, -1), 1, g_idxs)
                        g_idxs += g * group_size * len(self.vocab)

                        all_scores.append(g_scores)
                        all_idxs.append(g_idxs)

                        diversity_penalty.scatter_add_(1,
                                                       torch.fmod(g_idxs, len(self.vocab)),
                                                       torch.ones((batch_size, group_size), device=device))

                    diversity_penalty.fill_(0)
                    penalty = penalty.view(batch_size, -1)
                    beam_scores = torch.cat(all_scores, dim=-1)
                    idxs = torch.cat(all_idxs, dim=-1)

                    # Recover the source beam from the flattened (beam, token)
                    # index (integer division replaces the original
                    # float-divide-then-truncate, same result).
                    beam_idxs = idxs // len(self.vocab)

                penalty = torch.gather(penalty, 1, idxs)
                sym_idxs = torch.fmod(idxs, log_probs.shape[-1])
                is_end = torch.gather(is_end, 1, beam_idxs)
                beam_lens = torch.gather(beam_lens, 1, beam_idxs)

                sym_idxs[is_end] = self.vocab.pad_id
                beam_lens[~is_end] += 1
                is_end[sym_idxs == self.vocab.eos_id] = 1

                # Reorder the running hypotheses and append the new symbols.
                sym_idxs = sym_idxs.view(batch_size * self.config.beam_size, 1)
                prevs = prevs.view(batch_size, self.config.beam_size, -1)
                prevs = torch.gather(prevs, 1, beam_idxs.unsqueeze(-1).repeat(1, 1, prevs.shape[-1]))
                prevs = prevs.view(batch_size * self.config.beam_size, -1)
                prevs = torch.cat([prevs, sym_idxs], dim=1)

                # Record the token that followed each trailing bigram.
                prevs_list = prevs.tolist()
                for b in range(batch_size * self.config.beam_size):
                    b_list = prevs_list[b]
                    if (len(b_list) > 2 and b_list[-1] != self.vocab.pad_id
                            and b_list[-1] != self.vocab.eos_id):
                        key = (int(b_list[-3]), int(b_list[-2]))
                        if key in repeat[b]:
                            repeat[b][key].append(int(b_list[-1]))
                        else:
                            repeat[b][key] = [int(b_list[-1])]

                if all(is_end.view(-1)):
                    break

                # Undo the length normalization before the next accumulation
                # step, and anneal the sampling probability.
                beam_scores *= penalty
                current_sample_prob *= self.config.annealing

            predicts = []
            result = prevs.view(batch_size, self.config.beam_size, -1)

            if return_beams:
                # Return every beam, best first, stripped of <bos> and <eos>.
                bests = torch.argsort(beam_scores, dim=-1, descending=True)
                for i in range(batch_size):
                    temp = []
                    for j in range(self.config.beam_size):
                        best_len = beam_lens[i, bests[i][j]]
                        best_seq = result[i, bests[i][j], 1:best_len - 1]
                        temp.append(best_seq.tolist())
                    predicts.append(temp)
                return predicts

            if self.config.sample:
                # Sample the returned beam in proportion to its score.
                probs = F.softmax(beam_scores, dim=-1)
                bests = torch.multinomial(probs, 1).view(-1)
            else:
                bests = beam_scores.argmax(dim=-1)

            for i in range(batch_size):
                best_len = beam_lens[i, bests[i]]
                best_seq = result[i, bests[i], 1:best_len - 1]
                predicts.append(best_seq.tolist())

            return predicts
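A minimal end-to-end sketch of how this model appears to be driven. The config
fields and the vocab interface below are inferred from the code above, not
documented in the commit, so treat them as assumptions:

    from types import SimpleNamespace

    # Hypothetical hyperparameters covering every config attribute the model reads.
    config = SimpleNamespace(
        n_layers=6, n_pos_embeddings=512, embeddings_size=256, n_heads=8,
        dropout=0.1, embed_dropout=0.1, attn_dropout=0.1, ff_dropout=0.1,
        beam_size=5, max_seq_len=128, length_penalty=0.6,
        diversity_groups=1, diversity_coef=0.0,
        temperature=0.8, annealing_topk=None, annealing=0.9, sample=False)

    # `vocab` is the assumed external tokenizer object: pad_id, bos_id,
    # eos_id, and __len__.
    model = MultiInputModel(config, vocab)
    model.eval()

    # Contexts are LongTensors of token ids, one per source, (batch, seq_len).
    dialog_history_ids = torch.randint(0, len(vocab), (1, 20))
    replies = model.predict(contexts=[dialog_history_ids])
    all_beams = model.predict_beam(contexts=[dialog_history_ids])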