Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
H
homework2_dialog_project
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
20220418012
homework2_dialog_project
Commits
5bbdec70
Commit
5bbdec70
authored
Jul 15, 2022
by
20220418012
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
b26a7f86
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
163 additions
and
0 deletions
+163
-0
NLG/model/utils.py
+163
-0
No files found.
NLG/model/utils.py
0 → 100644
View file @
5bbdec70
import
re
import
os
import
json
import
random
import
torch
import
logging
import
numpy
as
np
import
argparse
from
scipy.interpolate
import
RectBivariateSpline
from
torch.utils.checkpoint
import
checkpoint
from
collections
import
namedtuple
,
Counter
from
attrdict
import
AttrDict
def get_logger(filename, print2screen=True):
    """Create an INFO-level logger that writes to *filename*.

    Args:
        filename: Path of the log file; also used as the logger's name.
        print2screen: When True, records are additionally echoed to the
            console via a StreamHandler.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(filename)
    logger.setLevel(logging.INFO)
    # Bug fix: logging.getLogger() caches loggers by name, so calling this
    # function twice with the same filename used to attach a second pair of
    # handlers and duplicate every emitted record. Reuse the configured
    # logger instead.
    if logger.handlers:
        return logger
    formatter = logging.Formatter(
        '[%(asctime)s][%(thread)d][%(filename)s][line:%(lineno)d]'
        '[%(levelname)s] >> %(message)s')
    fh = logging.FileHandler(filename)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    if print2screen:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    return logger
def str2bool(v):
    """Parse a command-line string into a boolean (argparse type helper).

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognized
            yes/no spelling.
    """
    truthy = {'yes', 'true', 't', 'y', '1'}
    falsy = {'no', 'false', 'f', 'n', '0'}
    token = v.lower()
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def load_config(config_file):
    """Read a JSON config file and wrap it for attribute-style access."""
    with open(config_file) as fp:
        cfg_dict = json.load(fp)
    return AttrDict(cfg_dict)
def set_seed(seed):
    """Seed every RNG the project relies on, for reproducible runs.

    Covers torch (CPU and all CUDA devices), numpy, and Python's
    ``random`` module.  The original version left numpy unseeded — even
    though the module imports numpy — and seeded only the current GPU.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU and is a safe no-op on
    # CPU-only machines.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
def
pad_sequence
(
sequences
,
batch_first
=
False
,
padding_value
=
0
):
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
max_size
=
sequences
[
0
]
.
size
()
trailing_dims
=
max_size
[
1
:]
max_len
=
max
([
s
.
size
(
0
)
for
s
in
sequences
])
if
batch_first
:
out_dims
=
(
len
(
sequences
),
max_len
)
+
trailing_dims
else
:
out_dims
=
(
max_len
,
len
(
sequences
))
+
trailing_dims
out_tensor
=
sequences
[
0
]
.
data
.
new
(
*
out_dims
)
.
fill_
(
padding_value
)
for
i
,
tensor
in
enumerate
(
sequences
):
length
=
tensor
.
size
(
0
)
# use index notation to prevent duplicate references to the tensor
if
batch_first
:
out_tensor
[
i
,
:
length
,
...
]
=
tensor
else
:
out_tensor
[:
length
,
i
,
...
]
=
tensor
return
out_tensor
def checkpoint_sequential(functions, segments, *inputs):
    """Run a chain of modules/callables with gradient checkpointing.

    Splits `functions` into `segments` chunks; every chunk except the last
    is run under ``torch.utils.checkpoint.checkpoint`` so its intermediate
    activations are recomputed during backward instead of stored.  The
    final chunk runs normally so the output participates in autograd as
    usual.

    Args:
        functions: A ``torch.nn.Sequential`` or a list of callables,
            applied in order.  NOTE(review): each callable's return value
            is star-unpacked into the next call, so callables are assumed
            to return tuples (or single-element results) — TODO confirm
            against the models that use this helper.
        segments: Number of checkpointed chunks to split into.
        *inputs: Arguments passed to the first callable.

    Returns:
        Whatever the last callable returns.
    """
    def run_function(start, end, functions):
        # Build a closure applying functions[start..end] (inclusive) in order.
        def forward(*inputs):
            for j in range(start, end + 1):
                inputs = functions[j](*inputs)
            return inputs
        return forward
    if isinstance(functions, torch.nn.Sequential):
        # Normalize a Sequential to a plain list of its child modules.
        functions = list(functions.children())
    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs)
        # checkpoint may return a bare tensor; normalize to a tuple so the
        # next segment can star-unpack it.
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    # Run the remaining functions (end + 1 .. last) without checkpointing,
    # which also covers any leftover functions when len % segments != 0.
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
def get_latest_ckpt(dir_name):
    """Return the checkpoint filename in *dir_name* with the largest epoch.

    Filenames are expected to look like ``<name>-<epoch>.ckpt``.

    Returns:
        The filename (not full path) with the highest epoch number, or
        None when the directory contains no ``.ckpt`` files.
    """
    ckpts = [name for name in os.listdir(dir_name) if '.ckpt' in name]
    if not ckpts:
        return None
    # Epoch is the integer between the last '-' and the following '.'.
    return max(ckpts, key=lambda name: int(name.split('-')[-1].split('.')[0]))
def get_epoch_from_ckpt(ckpt):
    """Extract the integer epoch from a ``<name>-<epoch>.ckpt`` filename."""
    stem = ckpt.split('-')[-1]
    return int(stem.split('.')[0])
def get_ckpt_filename(name, epoch):
    """Build the checkpoint filename ``<name>-<epoch>.ckpt``."""
    template = '{}-{}.ckpt'
    return template.format(name, epoch)
def f1_score(predictions, targets, average=True):
    """Compute token-overlap F1 between aligned prediction/target pairs.

    Args:
        predictions: Iterable of token sequences, one per example.
        targets: Iterable of gold token sequences, aligned with
            ``predictions``.
        average: If True return the mean F1 across pairs, otherwise the
            list of per-pair scores.

    Returns:
        A float (mean F1) or a list of per-pair scores; a pair with no
        overlapping tokens scores 0.
    """
    def _pair_f1(pred_tokens, gold_tokens):
        # Multiset intersection: each shared token counts at most as often
        # as it appears in both sequences.
        overlap = sum((Counter(gold_tokens) & Counter(pred_tokens)).values())
        if overlap == 0:
            return 0
        precision = overlap / len(pred_tokens)
        recall = overlap / len(gold_tokens)
        return (2 * precision * recall) / (precision + recall)

    scores = [_pair_f1(pred, gold) for pred, gold in zip(predictions, targets)]
    if average:
        return sum(scores) / len(scores)
    return scores
def openai_transformer_config():
    """Return the default OpenAI-GPT transformer hyper-parameters.

    The result is a dict subclass whose keys are also readable/writable
    as attributes; reading a missing attribute yields None.
    """
    class dotdict(dict):
        # Attribute reads fall back to dict lookup; unknown keys -> None.
        def __getattr__(self, name):
            return self.get(name)

        def __setattr__(self, name, value):
            self[name] = value

        def __delattr__(self, name):
            del self[name]

    return dotdict({
        'n_layers': 12,
        'n_embeddings': 40477,
        'n_pos_embeddings': 512,
        'embeddings_size': 768,
        'n_heads': 12,
        'dropout': 0.1,
        'embed_dropout': 0.1,
        'attn_dropout': 0.1,
        'ff_dropout': 0.1,
    })
def load_openai_weights_chinese(model, directory):
    """Load pretrained (Chinese GPT) weights into *model*.

    Reads a state dict saved with ``torch.save`` from *directory* (a file
    path despite the name), drops ``decoder.pre_softmax.weight``
    (presumably tied to the embedding matrix — verify against the model
    definition), prefixes every remaining key with ``'decoder.'`` to
    match this project's module layout, and loads the result.
    """
    state = torch.load(directory)
    state.pop('decoder.pre_softmax.weight')
    # Snapshot the keys first: the loop inserts renamed entries while it
    # removes the old ones.
    for key in list(state.keys()):
        state['decoder.' + key] = state.pop(key)
    model.load_state_dict(state)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment