import pickle
import sys
from contextlib import redirect_stdout

from utils import load_data
import kmodel

# Names of the artifacts this script writes; change them here rather than
# hunting through the script body.
SUMMARY_PATH = "net.txt"
HISTORY_PATH = "log.pickle"
MODEL_NAME = "my_model"

# Load the training data.
X_train, y_train = load_data()

# Build the network and write its architecture summary to a text file.
my_model = kmodel.create_model()
# redirect_stdout restores sys.stdout even if summary() raises, unlike a
# manual save/assign/restore of sys.stdout. UTF-8 is used explicitly because
# the summary may contain non-ASCII table characters (depending on the model
# library version) that the platform default encoding cannot represent.
with open(SUMMARY_PATH, mode="w", encoding="utf-8") as f, redirect_stdout(f):
    my_model.summary()

# Compile the network model.
kmodel.compile_model(my_model)

# NOTE(review): to continue from a previously saved model instead of the
# freshly built one, uncomment the line below. Presumably it loads what
# kmodel.save_model writes at the end of this script — verify before use.
# my_model = kmodel.load_trained_model(MODEL_NAME)

# Train the network model and persist its training history.
history = kmodel.train_model(my_model, X_train, y_train)
# history.history is presumably a Keras-style History dict (metric name ->
# per-epoch values) — confirm against kmodel.train_model. Only unpickle the
# resulting file from a trusted source.
with open(HISTORY_PATH, mode="wb") as f:
    pickle.dump(history.history, f)

# Save the trained network model.
kmodel.save_model(my_model, MODEL_NAME)