上线 PyTorch MNIST 模型¶
在本教程中,我们将提供 PyTorch MNIST 模型。
它接收 Base64 编码的图像作为请求数据,并在响应中返回预测。
准备工作¶
访问 PyTorch 示例 - MNIST,下载文件。
运行以下命令来安装和训练模型:
pip install -r requirements.txt
python main.py --save-model
训练完成后,您将拥有如下文件夹结构。创建了一个 mnist_cnn.pt
文件
.
├── README.md
├── main.py
├── mnist_cnn.pt
└── requirements.txt
部署方法¶
有两种方法可以部署模型。
- 直接注册一个函数。
- 仅使用附加处理程序 Handler 注册模型路径。
我们将在本教程中逐步介绍这两种方法。
直接注册一个函数¶
创建应用程序¶
让我们在同一个文件夹中创建一个文件 func_app.py
。
func_app.py | |
---|---|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
|
- 确保您可以导入网络模型。
- 预处理转换代码。
- 示例脚本只保存
state_dict
。这里我们需要初始化模型并加载state_dict
。 - 准备好,3、2、1。GO!
启动服务¶
$ uvicorn func_app:service --reload
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [xxxxx] using statreload
INFO: Started server process [xxxxx]
INFO: Waiting for application startup.
INFO: Application startup complete.
$ pinfer func_app:service --reload
Pinferencia: Frontend component streamlit is starting...
Pinferencia: Backend component uvicorn is starting...
测试服务¶
测试数据那里来?
因为我们的输入是 base64 编码的 MNIST 图像,我们从哪里可以获得这些数据?
您可以使用 PyTorch 的数据集。在同一文件夹中创建一个文件名为 get-base64-img.oy
。
import base64
import random
from io import BytesIO
from PIL import Image
from torchvision import datasets
dataset = datasets.MNIST( # (1)
"./data",
train=True,
download=True,
transform=None,
)
index = random.randint(0, len(dataset.data)) # (2)
img = dataset.data[index]
img = Image.fromarray(img.numpy(), mode="L")
buffered = BytesIO()
img.save(buffered, format="JPEG")
base64_img_str = base64.b64encode(buffered.getvalue()).decode()
print("Base64 String:", base64_img_str) # (3)
print("target:", dataset.targets[index].tolist())
- 这是训练期间使用的 MNIST 数据集。
- 让我们使用随机图像。
- 字符串和目标被打印到标准输出。
运行脚本并复制字符串。
python get-base64-img.py
输出:
Base64 String: /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+uhfwXqy2Ph25VYnPiB3SzhUkPlXCfNkAAEsCCCeOeKx9RsLjStUu9Ou1C3NpM8Eqg5AdSVIz35FVqK9xl0HXhb/C20sdMubjTLMQXs11AhkRXmmDsCwzgAYPpz+XI/GrSLrTfiVqNzPapbw3xE8AWQNvUAKXOOmWVjg+teeUV2fgXxd4hsPE2hWEGuX8Vh9uhja3Fw3lbGcBhtzjGCad8XI7iL4p68twHDGcMm45+QqCuPbBFcVRRU97fXepXb3d9dT3VzJjfNPIXdsAAZY8nAAH4VBX/9k=
target: 4
前端界面¶
打开http://127.0.0.1:8501,会自动选择模板Image to Text
。
使用下图:
你会得到:
后端API¶
让我们创建一个文件test.py
test.py | |
---|---|
1 2 3 4 5 6 7 |
|
运行测试:
$ python test.py
Prediction: 4
您可以尝试使用更多图像来测试,甚至可以使用交互式 API 文档页面 http://127.0.0.1:8000
使用 Handler 注册模型路径¶
创建应用程序¶
让我们在同一个文件夹中创建一个文件 func_app.py 。
下面的代码被重构为 MNISTHandler 。看起来更干净!
path_app.py | |
---|---|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
|
-
我们将加载模型的代码移到
load_model
函数中。模型路径可以通过self.model_path
访问。 -
我们将预测代码移到
predict
函数中。该模型可以通过self.model
访问。 -
model_dir
是Pinferencia
查找模型文件的地方。将 model_dir 设置为包含mnist_cnn.pt
和此脚本的文件夹。 -
load_now
确定模型是否会在注册期间立即加载。默认值为“真”。如果设置为False
,则需要调用load
API 加载模型才能进行预测。
启动服务¶
$ uvicorn func_app:service --reload
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [xxxxx] using statreload
INFO: Started server process [xxxxx]
INFO: Waiting for application startup.
INFO: Application startup complete.
$ pinfer func_app:service --reload
Pinferencia: Frontend component streamlit is starting...
Pinferencia: Backend component uvicorn is starting...
测试服务¶
运行测试:
$ python test.py
Prediction: 4
不出意外,结果一样。
最后¶
使用 Pinferencia,您可以为任何模型提供服务。
您可以自己加载模型,就像您在进行离线预测时所做的那样。 这部分代码你早就已经写好了。
然后,只需使用 Pinferencia 注册模型,您的模型就会生效。
或者,您可以选择将代码重构为 Handler Class。旧的经典方式也适用于 Pinferencia。
这两个世界都适用于您的模型,经典音乐 和 摇滚乐。
是不是很棒!
现在您已经掌握了如何使用 Pinferencia 来:
- 注册任何模型、任何函数并把它们上线。
- 使用您的自定义处理程序为您的机器学习模型提供服务。