Commit 710ed1c9 by 20200318029

homework5

parent 78c3c26b
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from copy import deepcopy
import re
try:
    import psyco
    psyco.full()
except:
    pass

try:
    from zh_wiki import zh2Hant, zh2Hans
except ImportError:
    from zhtools.zh_wiki import zh2Hant, zh2Hans

import sys
py3k = sys.version_info >= (3, 0, 0)

if py3k:
    UEMPTY = ''
else:
    _zh2Hant, _zh2Hans = {}, {}
    for old, new in ((zh2Hant, _zh2Hant), (zh2Hans, _zh2Hans)):
        for k, v in old.items():
            new[k.decode('utf8')] = v.decode('utf8')
    zh2Hant = _zh2Hant
    zh2Hans = _zh2Hans
    UEMPTY = ''.decode('utf8')
# states
(START, END, FAIL, WAIT_TAIL) = list(range(4))
# conditions
(TAIL, ERROR, MATCHED_SWITCH, UNMATCHED_SWITCH, CONNECTOR) = list(range(5))
MAPS = {}
class Node(object):
    """A lookup result from ConvertMap: is_tail means a match can end on this
    key, have_child means some longer key in the map starts with it, and
    is_original means the key is passed through unchanged."""

    def __init__(self, from_word, to_word=None, is_tail=True,
            have_child=False):
        self.from_word = from_word
        if to_word is None:
            self.to_word = from_word
            self.data = (is_tail, have_child, from_word)
            self.is_original = True
        else:
            self.to_word = to_word or from_word
            self.data = (is_tail, have_child, to_word)
            self.is_original = False
        self.is_tail = is_tail
        self.have_child = have_child

    def is_original_long_word(self):
        return self.is_original and len(self.from_word) > 1

    def is_follow(self, chars):
        return chars != self.from_word[:-1]

    def __str__(self):
        return '<Node, %s, %s, %s, %s>' % (repr(self.from_word),
                repr(self.to_word), self.is_tail, self.have_child)

    __repr__ = __str__
class ConvertMap(object):
    """Preprocessed conversion table: every key, and every prefix of a key,
    maps to (is_tail, have_child, to_word) so matching can proceed one
    character at a time."""

    def __init__(self, name, mapping=None):
        self.name = name
        self._map = {}
        if mapping:
            self.set_convert_map(mapping)

    def set_convert_map(self, mapping):
        convert_map = {}
        have_child = {}
        max_key_length = 0
        for key in sorted(mapping.keys()):
            if len(key) > 1:
                # mark every proper prefix of a multi-character key
                for i in range(1, len(key)):
                    parent_key = key[:i]
                    have_child[parent_key] = True
            have_child[key] = False
            max_key_length = max(max_key_length, len(key))
        for key in sorted(have_child.keys()):
            convert_map[key] = (key in mapping, have_child[key],
                    mapping.get(key, UEMPTY))
        self._map = convert_map
        self.max_key_length = max_key_length

    def __getitem__(self, k):
        try:
            is_tail, have_child, to_word = self._map[k]
            return Node(k, to_word, is_tail, have_child)
        except:
            return Node(k)

    def __contains__(self, k):
        return k in self._map

    def __len__(self):
        return len(self._map)
class StatesMachineException(Exception): pass


class StatesMachine(object):
    """A small state machine that consumes the input one character at a time,
    buffering partial matches in `pool` and appending finished words to
    `final`; clone() forks a branch when a longer match may still follow."""

    def __init__(self):
        self.state = START
        self.final = UEMPTY
        self.len = 0
        self.pool = UEMPTY

    def clone(self, pool):
        new = deepcopy(self)
        new.state = WAIT_TAIL
        new.pool = pool
        return new

    def feed(self, char, map):
        node = map[self.pool + char]

        if node.have_child:
            if node.is_tail:
                if node.is_original:
                    cond = UNMATCHED_SWITCH
                else:
                    cond = MATCHED_SWITCH
            else:
                cond = CONNECTOR
        else:
            if node.is_tail:
                cond = TAIL
            else:
                cond = ERROR

        new = None
        if cond == ERROR:
            self.state = FAIL
        elif cond == TAIL:
            if self.state == WAIT_TAIL and node.is_original_long_word():
                self.state = FAIL
            else:
                self.final += node.to_word
                self.len += 1
                self.pool = UEMPTY
                self.state = END
        elif self.state == START or self.state == WAIT_TAIL:
            if cond == MATCHED_SWITCH:
                new = self.clone(node.from_word)
                self.final += node.to_word
                self.len += 1
                self.state = END
                self.pool = UEMPTY
            elif cond == UNMATCHED_SWITCH or cond == CONNECTOR:
                if self.state == START:
                    new = self.clone(node.from_word)
                    self.final += node.to_word
                    self.len += 1
                    self.state = END
                else:
                    if node.is_follow(self.pool):
                        self.state = FAIL
                    else:
                        self.pool = node.from_word
        elif self.state == END:
            # END is a new START
            self.state = START
            new = self.feed(char, map)
        elif self.state == FAIL:
            raise StatesMachineException('Translate States Machine '
                    'have error with input data %s' % node)
        return new

    def __len__(self):
        return self.len + 1

    def __str__(self):
        return '<StatesMachine %s, pool: "%s", state: %s, final: %s>' % (
                id(self), self.pool, self.state, self.final)

    __repr__ = __str__
class Converter(object):
    """Drives a set of StatesMachine branches over the input and keeps the
    result of the branch that converted the text with the fewest, and hence
    longest, dictionary matches."""

    def __init__(self, to_encoding):
        self.to_encoding = to_encoding
        self.map = MAPS[to_encoding]
        self.start()

    def feed(self, char):
        branches = []
        for fsm in self.machines:
            new = fsm.feed(char, self.map)
            if new:
                branches.append(new)
        if branches:
            self.machines.extend(branches)
        self.machines = [fsm for fsm in self.machines if fsm.state != FAIL]
        all_ok = True
        for fsm in self.machines:
            if fsm.state != END:
                all_ok = False
        if all_ok:
            self._clean()
        return self.get_result()

    def _clean(self):
        if len(self.machines):
            # keep the branch that emitted the fewest (longest) words
            self.machines.sort(key=lambda x: len(x))
            # self.machines.sort(cmp=lambda x,y: cmp(len(x), len(y)))
            self.final += self.machines[0].final
        self.machines = [StatesMachine()]

    def start(self):
        self.machines = [StatesMachine()]
        self.final = UEMPTY

    def end(self):
        self.machines = [fsm for fsm in self.machines
                if fsm.state == FAIL or fsm.state == END]
        self._clean()

    def convert(self, string):
        self.start()
        for char in string:
            self.feed(char)
        self.end()
        return self.get_result()

    def get_result(self):
        return self.final
def registery(name, mapping):
    global MAPS
    MAPS[name] = ConvertMap(name, mapping)

registery('zh-hant', zh2Hant)
registery('zh-hans', zh2Hans)
del zh2Hant, zh2Hans
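# Illustrative usage (a minimal sketch, not part of the original module): after the
# two registery() calls above, a Converter can be built for either direction, e.g.
#
#     Converter('zh-hans').convert(u'漢字')   # traditional -> simplified
#     Converter('zh-hant').convert(u'汉字')   # simplified -> traditional
#
# The exact output depends on the zh_wiki mapping tables shipped with zhtools.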
def run():
    import sys
    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option('-e', type='string', dest='encoding',
            help='encoding')
    parser.add_option('-f', type='string', dest='file_in',
            help='input file (- for stdin)')
    parser.add_option('-t', type='string', dest='file_out',
            help='output file')
    (options, args) = parser.parse_args()
    if not options.encoding:
        parser.error('encoding must be set')
    if options.file_in:
        if options.file_in == '-':
            file_in = sys.stdin
        else:
            file_in = open(options.file_in)
    else:
        file_in = sys.stdin
    if options.file_out:
        if options.file_out == '-':
            file_out = sys.stdout
        else:
            file_out = open(options.file_out, 'wb')
    else:
        file_out = sys.stdout
    c = Converter(options.encoding)
    for line in file_in:
        # print >> file_out, c.convert(line.rstrip('\n').decode(
        file_out.write(c.convert(line.rstrip('\n').decode(
            'utf8')).encode('utf8'))

if __name__ == '__main__':
    run()
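Given the options defined in run() above, the module can presumably also be invoked as a script, along the lines of (file names here are placeholders):

    python langconv.py -e zh-hans -f input.txt -t output.txt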
...@@ -85,7 +85,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pdb"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
...@@ -100,12 +109,23 @@
     "\n",
     "from nltk import word_tokenize\n",
     "from collections import Counter\n",
-    "from torch.autograd import Variable"
+    "from torch.autograd import Variable\n",
+    "\n",
+    "from langconv import *"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jieba"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
...@@ -129,7 +149,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
...@@ -146,24 +166,30 @@
     "        np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X\n",
     "    ])\n",
     "\n",
+    "def cht_to_chs(sent):\n",
+    "    sent = Converter('zh-hans').convert(sent)\n",
+    "    sent.encode('utf-8')\n",
+    "    return sent\n",
     "\n",
     "class PrepareData:\n",
     "    def __init__(self, train_file, dev_file):\n",
-    "        # 读取数据 并分词\n",
+    "        # 读取数据并分词\n",
     "        self.train_en, self.train_cn = self.load_data(train_file)\n",
     "        self.dev_en, self.dev_cn = self.load_data(dev_file)\n",
     "\n",
     "        # 构建单词表\n",
-    "        self.en_word_dict, self.en_total_words, self.en_index_dict = self.build_dict(self.train_en)\n",
-    "        self.cn_word_dict, self.cn_total_words, self.cn_index_dict = self.build_dict(self.train_cn)\n",
+    "        self.en_word_dict, self.en_total_words, self.en_index_dict = \\\n",
+    "            self.build_dict(self.train_en)\n",
+    "        self.cn_word_dict, self.cn_total_words, self.cn_index_dict = \\\n",
+    "            self.build_dict(self.train_cn)\n",
     "\n",
     "        # id化\n",
-    "        self.train_en, self.train_cn = self.wordToID(self.train_en, self.train_cn, self.en_word_dict, self.cn_word_dict)\n",
-    "        self.dev_en, self.dev_cn = self.wordToID(self.dev_en, self.dev_cn, self.en_word_dict, self.cn_word_dict)\n",
+    "        self.train_en, self.train_cn = self.word2id(self.train_en, self.train_cn, self.en_word_dict, self.cn_word_dict)\n",
+    "        self.dev_en, self.dev_cn = self.word2id(self.dev_en, self.dev_cn, self.en_word_dict, self.cn_word_dict)\n",
     "\n",
     "        # 划分batch + padding + mask\n",
-    "        self.train_data = self.splitBatch(self.train_en, self.train_cn, BATCH_SIZE)\n",
-    "        self.dev_data = self.splitBatch(self.dev_en, self.dev_cn, BATCH_SIZE)\n",
+    "        self.train_data = self.split_batch(self.train_en, self.train_cn, BATCH_SIZE)\n",
+    "        self.dev_data = self.split_batch(self.dev_en, self.dev_cn, BATCH_SIZE)\n",
     "\n",
     "    def load_data(self, path):\n",
     "        \"\"\"\n",
...@@ -175,9 +201,16 @@
     "        en = []\n",
     "        cn = []\n",
     "        # TODO ...\n",
-    "        \n",
+    "        with open(path, mode=\"r\", encoding=\"utf-8\") as f:\n",
     "        \n",
-    "        \n",
+    "            for line in f.readlines():\n",
+    "                sent_en, sent_cn = line.strip().split(\"\\t\")\n",
+    "                sent_en = sent_en.lower()\n",
+    "                sent_cn = cht_to_chs(sent_cn)\n",
+    "                sent_en = [\"BOS\"] + word_tokenize(sent_en) + [\"EOS\"]\n",
+    "                sent_cn = [\"BOS\"] + [word for word in jieba.cut(sent_cn)] + [\"EOS\"]\n",
+    "                en.append(sent_en)\n",
+    "                cn.append(sent_cn)\n",
     "\n",
     "        return en, cn\n",
     "        \n",
...@@ -197,7 +230,7 @@
     "        ls = word_count.most_common(max_words)\n",
     "        # 统计词典的总词数\n",
     "        total_words = len(ls) + 2\n",
-    "\n",
+    "        \n",
     "        word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}\n",
     "        word_dict['UNK'] = UNK\n",
     "        word_dict['PAD'] = PAD\n",
...@@ -206,7 +239,7 @@
     "\n",
     "        return word_dict, total_words, index_dict\n",
     "\n",
-    "    def wordToID(self, en, cn, en_dict, cn_dict, sort=True):\n",
+    "    def word2id(self, en, cn, en_dict, cn_dict, sort=True):\n",
     "        \"\"\"\n",
     "        该方法可以将翻译前(英文)数据和翻译后(中文)数据的单词列表表示的数据\n",
     "        均转为id列表表示的数据\n",
...@@ -217,9 +250,9 @@
     "        length = len(en)\n",
     "        \n",
     "        # TODO: 将翻译前(英文)数据和翻译后(中文)数据都转换为id表示的形式\n",
-    "        out_en_ids = \n",
-    "        out_cn_ids = \n",
+    "        out_en_ids = [[en_dict.get(word, 0) for word in sent] for sent in en]\n",
+    "        out_cn_ids = [[cn_dict.get(word, 0) for word in sent] for sent in cn]\n",
-    "\n",
+    "        \n",
     "        # 构建一个按照句子长度排序的函数\n",
     "        def len_argsort(seq):\n",
     "            \"\"\"\n",
...@@ -234,12 +267,12 @@
     "        sorted_index = len_argsort(out_en_ids)\n",
     "        \n",
     "        # TODO: 对翻译前(英文)数据和翻译后(中文)数据都按此基准进行排序\n",
-    "        out_en_ids = \n",
-    "        out_cn_ids = \n",
+    "        out_en_ids = [out_en_ids[idx] for idx in sorted_index]\n",
+    "        out_cn_ids = [out_cn_ids[idx] for idx in sorted_index]\n",
     "        \n",
     "        return out_en_ids, out_cn_ids\n",
     "\n",
-    "    def splitBatch(self, en, cn, batch_size, shuffle=True):\n",
+    "    def split_batch(self, en, cn, batch_size, shuffle=True):\n",
     "        \"\"\"\n",
     "        将以单词id列表表示的翻译前(英文)数据和翻译后(中文)数据\n",
     "        按照指定的batch_size进行划分\n",
...
...@@ -64,7 +64,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "** 以下为一个Transformer Encoder Block结构示意图**\n",
+    "**以下为一个Transformer Encoder Block结构示意图**\n",
+    "\n",
     "> 注意: 为方便查看, 下面各部分的内容分别对应着图中第1, 2, 3, 4个方框的序号:"
    ]
   },
...
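For reference, a minimal sketch (not part of the commit) of the preprocessing that the new load_data code above applies to a single line of the dataset, assuming nltk (with its punkt tokenizer data), jieba and the langconv module from this commit are importable; the sample sentence pair is made up:

    from nltk import word_tokenize
    import jieba
    from langconv import Converter

    def cht_to_chs(sent):
        # traditional -> simplified, as in the notebook cell above
        return Converter('zh-hans').convert(sent)

    line = "Machine learning is fun.\t機器學習很有趣。"   # hypothetical tab-separated pair (en \t cn)
    sent_en, sent_cn = line.strip().split("\t")
    sent_en = ["BOS"] + word_tokenize(sent_en.lower()) + ["EOS"]
    sent_cn = ["BOS"] + list(jieba.cut(cht_to_chs(sent_cn))) + ["EOS"]
    print(sent_en)   # e.g. ['BOS', 'machine', 'learning', 'is', 'fun', '.', 'EOS']
    print(sent_cn)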