Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
H
homework2_dialog_project
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
20220418012
homework2_dialog_project
Commits
5aac183c
Commit
5aac183c
authored
2 years ago
by
20220418012
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
eba021df
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
0 deletions
+69
-0
NLG/interact.py
+69
-0
No files found.
NLG/interact.py
0 → 100644
View file @
5aac183c
import
os
import
torch
import
random
import
traceback
import
model.utils
as
utils
import
model.dataset
as
dataset
from
model.model_multi_input
import
MultiInputModel
from
model.trainer_multi_input
import
Trainer
from
model.text
import
Vocab
import
argparse
class
mylog
:
def
info
(
self
,
text
):
print
(
text
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--config'
,
help
=
'config file'
,
default
=
'config.json'
)
parser
.
add_argument
(
'--gpu'
,
help
=
'which gpu to use'
,
type
=
str
,
default
=
'3'
)
parser
.
add_argument
(
'--epoch'
,
help
=
'which epoch to use'
,
type
=
int
,
default
=-
1
)
args
=
parser
.
parse_args
()
config
=
utils
.
load_config
(
args
.
config
)
config_path
=
os
.
path
.
dirname
(
args
.
config
)
train_dir
=
os
.
path
.
join
(
config_path
,
config
[
'train_dir'
])
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
try
:
print
(
'pytorch version: {}'
.
format
(
torch
.
__version__
))
if
args
.
epoch
==
-
1
:
model_path
=
os
.
path
.
join
(
train_dir
,
utils
.
get_latest_ckpt
(
train_dir
))
else
:
model_path
=
os
.
path
.
join
(
train_dir
,
utils
.
get_ckpt_filename
(
'model'
,
args
.
epoch
))
if
not
os
.
path
.
isfile
(
model_path
):
print
(
'cannot find {}'
.
format
(
model_path
))
exit
(
0
)
if
len
(
args
.
gpu
)
!=
0
:
device
=
torch
.
device
(
"cuda"
)
else
:
device
=
torch
.
device
(
"cpu"
)
vocab
=
Vocab
(
config
.
vocab_path
)
print
(
'Building models'
)
model
=
MultiInputModel
(
config
,
vocab
)
.
to
(
device
)
print
(
'Loading weights from {}'
.
format
(
model_path
))
state_dict
=
torch
.
load
(
model_path
,
map_location
=
device
)[
'model'
]
for
i
in
list
(
state_dict
.
keys
()):
state_dict
[
i
.
replace
(
'.module.'
,
'.'
)]
=
state_dict
.
pop
(
i
)
model
.
load_state_dict
(
state_dict
)
model
.
eval
()
while
True
:
post
=
input
(
'>> '
)
post
=
' '
.
join
(
list
(
post
.
replace
(
' '
,
''
)))
# print('post_str', post)
post
=
[
vocab
.
eos_id
]
+
vocab
.
string2ids
(
post
)
+
[
vocab
.
eos_id
]
# print('post', post)
contexts
=
[
torch
.
tensor
([
post
],
dtype
=
torch
.
long
,
device
=
device
)]
# print('contexts', contexts)
prediction
=
model
.
predict
(
contexts
)[
0
]
pred_str
=
vocab
.
ids2string
(
prediction
)
print
(
'>> {}'
.
format
(
pred_str
))
except
:
print
(
traceback
.
format_exc
())
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment