1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
| from torchvision import datasets from torchvision import transforms from torch.utils.data import DataLoader import torch import socket import numpy as np import cv2 import time from threading import Thread
# Runtime switches: retrain from scratch instead of loading the saved model,
# and whether to run the test-set evaluation before serving.
train_new_model = False
should_i_test = False

# Server address the model connects to.
host = 'localhost'
port = 14159

# Training epoch count.
training_epoch = 13

# TCP client socket used to talk to the server.
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# Buffer that accumulates received pixel values between calls.
datas = []


class MNISTModel(torch.nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10.

    A ReLU follows every hidden layer; the last layer returns raw logits
    (no softmax), which is what ``CrossEntropyLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(784, 512)
        self.linear2 = torch.nn.Linear(512, 256)
        self.linear3 = torch.nn.Linear(256, 128)
        self.linear4 = torch.nn.Linear(128, 64)
        self.linear5 = torch.nn.Linear(64, 32)
        self.linear6 = torch.nn.Linear(32, 10)
        self.activate = torch.nn.ReLU()

    def forward(self, x):
        """Flatten ``x`` into rows of 784 values and return per-row logits of width 10."""
        hidden = x.view(-1, 784)
        hidden_layers = (self.linear1, self.linear2, self.linear3, self.linear4, self.linear5)
        for layer in hidden_layers:
            hidden = self.activate(layer(hidden))
        return self.linear6(hidden)
def connect():
    """Connect the module-level ``client`` socket to ``(host, port)``, retrying until it succeeds.

    After every failed attempt the socket is closed and replaced with a fresh
    one, because a TCP socket whose connect() failed (or that was connected
    before and then lost its peer) generally cannot be reused for another
    connect() call. Between attempts it waits three seconds.
    """
    global client
    while True:
        try:
            client.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            client.connect((host, port))
        except OSError:
            # Refused/timeout/DNS errors are all OSError subclasses; anything
            # else (e.g. a malformed address type) is a programming error.
            print("连接失败,三秒后重连")
            client.close()
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            time.sleep(3)
        else:
            print("模型已成功连接至服务器")
            return
def receive(model, datas):
    """Read one chunk from ``client``; when "OK" arrives, classify the buffered image and reply.

    Every character received before the "OK" marker is parsed with
    ``np.float32`` and appended to ``datas``. Once "OK" is seen, the 784
    buffered values are reshaped to a 1x1x28x28 tensor and run through
    ``model``. The predicted class is printed and sent back as the string form
    of the numpy prediction array (e.g. ``"[7]"``), and the buffer is cleared.

    Args:
        model: callable returning class logits of shape (1, 10) for one image.
        datas: list used as the pixel buffer across calls; mutated in place.

    Raises:
        ConnectionError: the server closed the connection (``recv`` returned b"").
            This is an ``OSError`` (i.e. ``socket.error``), so a reconnect handler
            catching ``socket.error`` sees it.
        ValueError: "OK" arrived but the buffer did not hold exactly 784 values.
    """
    data = client.recv(1024)
    if not data:
        # An empty read means the peer closed; without this the caller would
        # loop forever on a dead socket.
        raise ConnectionError("server closed the connection")

    # Pixel characters may share a chunk with the "OK" marker, so keep the part
    # before it. Anything after "OK" in the same chunk is discarded.
    pixels, marker, _ = data.decode().partition("OK")
    datas.extend(np.float32(ch) for ch in pixels)
    if not marker:
        return

    print("数据接收完成")
    if len(datas) != 784:
        count = len(datas)
        datas.clear()  # don't let a malformed image poison the next one
        raise ValueError(f"expected 784 pixel values, got {count}")

    # NOTE(review): the training pipeline normalizes images with
    # Normalize((0.1307,), (0.3081,)); confirm the sender applies the same
    # scaling, otherwise predictions will be skewed.
    x = torch.from_numpy(np.asarray(datas, dtype=np.float32)).view(1, 1, 28, 28)
    with torch.no_grad():  # inference only; no autograd graph needed
        y_hat = model(x)
    _, predicted = torch.max(y_hat, dim=1)
    for y_pred in predicted.tolist():
        print(y_pred)
    datas.clear()
    client.sendall(str(predicted.numpy()).encode("UTF-8"))
batch_size = 64

# Turn images into tensors, then normalize with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training and test splits of MNIST under ../dataset/mnist (download=True
# fetches them if they are not already present).
train_dataset, test_dataset = (
    datasets.MNIST(root='../dataset/mnist', train=split, download=True, transform=transform)
    for split in (True, False)
)

# Both loaders shuffle and use the same batch size.
train_loader, test_loader = (
    DataLoader(dataset, shuffle=True, batch_size=batch_size)
    for dataset in (train_dataset, test_dataset)
)
_MODEL_PATH = "HandWritingNumberModel.pkl"


def _train_model(loader, epochs):
    """Train a fresh MNISTModel with SGD + cross-entropy for ``epochs`` passes over ``loader``."""
    net = MNISTModel()
    criteria = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.15)
    for epoch in range(epochs):
        for batch_id, (x, y) in enumerate(loader):
            y_hat = net(x)
            loss = criteria(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("epoch:", epoch, "batch:", batch_id, "loss:", loss.item())
    return net


def _load_model(path):
    """Load a whole pickled model saved with ``torch.save(model, path)``."""
    # NOTE(security): this unpickles arbitrary objects; only load trusted files.
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects a fully
        # pickled nn.Module, so opt out explicitly.
        return torch.load(path, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        return torch.load(path)


def _evaluate(net, loader):
    """Print every (true, predicted) label pair from ``loader``; return accuracy in percent."""
    total = 0
    correct = 0
    with torch.no_grad():  # evaluation only; skip building autograd graphs
        for x, y in loader:
            _, predicted = torch.max(net(x), dim=1)
            total += y.size(0)
            for y_true, y_pred in zip(y.tolist(), predicted.tolist()):
                print("真实值", y_true, "预测值", y_pred)
                if y_true == y_pred:
                    correct += 1
    return 100 * correct / total if total else 0.0


if train_new_model:
    model = _train_model(train_loader, training_epoch)
    # Save the whole module (not a state_dict) so the load path below keeps
    # working with previously saved files.
    torch.save(model, _MODEL_PATH)
else:
    print("加载模型...")
    model = _load_model(_MODEL_PATH)
    print("加载模型完成!")

# Everything from here on is inference.
model.eval()

if should_i_test:
    print("正确率:", _evaluate(model, test_loader))

connect()
while True:
    try:
        # Handle messages sequentially on this thread: the socket and the
        # shared `datas` buffer must not be read by several threads at once.
        receive(model=model, datas=datas)
    except socket.error:
        print("模型与服务器断开连接,正在重新连接")
        connect()
|