20220418012 / homework2_dialog_project / Commits

Commit 6f779620, authored 2 years ago by 20220418012
Parent commit: 50c86fd9

Showing 1 changed file with 62 additions and 0 deletions.

NLU/NLU_model.py  (new file, mode 0 → 100644)  +62 −0
from transformers import BertPreTrainedModel, BertModel
from torch import nn


class NLUModule(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_intent_labels = config.num_intent_labels
        self.num_slot_labels = config.num_slot_labels

        # Shared BERT encoder with two task-specific heads on top.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intent_classifier = nn.Linear(config.hidden_size, config.num_intent_labels)
        self.slot_classifier = nn.Linear(config.hidden_size, config.num_slot_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        #######################################################
        # TODO: Complete the following function.
        # This function should return the logits of intent and slot classification.
        # You can implement it with the following steps:
        # 1. Forward the input through the BERT model.
        # 2. Extract the representation of the whole sentence and of each token.
        # 3. Feed the sentence representation and the token representations
        #    to the corresponding classifiers.
        #######################################################
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # outputs[0]: per-token hidden states, shape (batch, seq_len, hidden_size).
        # outputs[1]: pooled [CLS] representation of the whole sentence, shape (batch, hidden_size).
        seq_encoding = outputs[0]
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)

        # Sentence-level representation -> intent logits; token-level -> slot logits.
        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(seq_encoding)

        return intent_logits, slot_logits
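For reference, a minimal smoke test of the module as committed. This is an illustrative sketch, not part of the commit: the label counts, batch size, and sequence length are hypothetical values, and it relies on `BertConfig` storing extra keyword arguments (`num_intent_labels`, `num_slot_labels`) as config attributes.

import torch
from transformers import BertConfig

# Hypothetical label counts chosen only for the shape check.
config = BertConfig(num_intent_labels=5, num_slot_labels=10)
model = NLUModule(config)  # randomly initialized weights; sufficient for a shape check
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2, 16 tokens each
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    intent_logits, slot_logits = model(input_ids=input_ids, attention_mask=attention_mask)

print(intent_logits.shape)  # torch.Size([2, 5]): one intent distribution per sentence
print(slot_logits.shape)    # torch.Size([2, 16, 10]): one slot distribution per token

In training, a cross-entropy loss would typically be taken over `intent_logits` per sentence and over `slot_logits` flattened to one prediction per token, with padded positions masked out; the committed module itself returns only the logits.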