Register a Model
Registering a model is simple:
```python
service.register(
    model_name="mymodel",
    model=model,
    entrypoint="predict",
)
```
What if I have multiple models, or multiple versions?
You can register multiple models, and each model can have multiple versions:
```python
service.register(
    model_name="my-model",
    model=my_model,
    entrypoint="predict",
)
service.register(
    model_name="my-model",
    model=my_model_v1,
    entrypoint="predict",
    version_name="v1",
)
service.register(
    model_name="your-model",
    model=your_model,
    entrypoint="predict",
)
service.register(
    model_name="your-model",
    model=your_model_v1,
    entrypoint="predict",
    version_name="v1",
)
service.register(
    model_name="your-model",
    model=your_model_v2,
    entrypoint="predict",
    version_name="v2",
)
```
Parameters

| Parameter | Type | Default (if any) | Details |
| --- | --- | --- | --- |
| model_name | str | | Name of the model |
| model | object | | The model Python object, or the path to the model file |
| version_name | str | None | Version name |
| entrypoint | str | None | Name of the function used for prediction |
| metadata | dict | None | Basic information about the model |
| handler | object | None | Handler class |
| load_now | bool | True | Whether to load the model immediately |
Examples
Model Name
```python
from pinferencia import Server


def predict(data):
    return sum(data)


service = Server()
service.register(
    model_name="mymodel",
    model=predict,
)
```
Model
Version Name
A model registered without a version name uses the default version name.
```python
from pinferencia import Server


def add(data):
    return data[0] + data[1]


def substract(data):
    return data[0] - data[1]


service = Server()
service.register(
    model_name="mymodel",
    model=add,
    version_name="add",  # (1)
)
service.register(
    model_name="mymodel",
    model=substract,
    version_name="substract",  # (2)
)
```
- The prediction endpoint is http://127.0.0.1/v1/models/mymodel/versions/add/predict
- The prediction endpoint is http://127.0.0.1/v1/models/mymodel/versions/substract/predict
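Once the service is running, you can call these endpoints directly. Below is a minimal sketch using the requests library; the port 8000 and the `{"data": ...}` request body are assumptions based on the default API conventions, so adjust them to your deployment. A model registered without a version_name (like mymodel in the Model Name example above) would similarly be reachable under the default version, i.e. `.../versions/default/predict`.

```python
import requests

# A minimal sketch: call the "add" version of mymodel.
# Assumes the service is running locally on port 8000 (uvicorn's default)
# and that the default API expects a JSON body of the form {"data": ...}.
response = requests.post(
    "http://127.0.0.1:8000/v1/models/mymodel/versions/add/predict",
    json={"data": [1, 2]},
)
print(response.json())  # the prediction result, e.g. containing 3
```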
Entrypoint
```python
from pinferencia import Server


class MyModel:
    def add(self, data):
        return data[0] + data[1]

    def substract(self, data):
        return data[0] - data[1]


model = MyModel()

service = Server()
service.register(
    model_name="mymodel",
    model=model,
    version_name="add",  # (1)
    entrypoint="add",  # (3)
)
service.register(
    model_name="mymodel",
    model=model,
    version_name="substract",  # (2)
    entrypoint="substract",  # (4)
)
```
- The prediction endpoint is http://127.0.0.1/v1/models/mymodel/versions/add/predict
- The prediction endpoint is http://127.0.0.1/v1/models/mymodel/versions/substract/predict
- The add function will be used for prediction.
- The substract function will be used for prediction.
Default API
Pinferencia's default metadata schema supports platform and device.
These fields are for display only.
```python
from pinferencia import Server


def predict(data):
    return sum(data)


service = Server()
service.register(
    model_name="mymodel",
    model=predict,
    metadata={
        "platform": "Linux",
        "device": "CPU+GPU",
    },
)
```
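Because platform and device are for display only, they do not affect prediction; they are simply returned when you query the model's management endpoints. A hedged sketch, assuming the default API exposes model metadata at `/v1/models/{model_name}` on port 8000:

```python
import requests

# A sketch of reading back the registered metadata.
# The path /v1/models/mymodel and port 8000 are assumptions about the
# default API's management endpoint; adjust to your deployment.
response = requests.get("http://127.0.0.1:8000/v1/models/mymodel")
print(response.json())  # expected to include "platform" and "device"
```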
Kserve API
Pinferencia also supports the Kserve API.
For Kserve V2, the model metadata supports:
- platform
- inputs
- outputs
inputs and outputs determine the data the model receives and the data types it returns.
```python
from pinferencia import Server


def predict(data):
    return sum(data)


service = Server(api="kserve")  # (1)
service.register(
    model_name="mymodel",
    model=predict,
    metadata={
        "platform": "mac os",
        "inputs": [
            {
                "name": "integers",  # (2)
                "datatype": "int64",
                "shape": [1],
                "data": [1, 2, 3],
            }
        ],
        "outputs": [
            {"name": "sum", "datatype": "int64", "shape": -1, "data": 6},  # (3)
            {"name": "product", "datatype": "int64", "shape": -1, "data": 6},
        ],
    },
)
```
- To use the Kserve API, set api="kserve" when instantiating the service.
- If the request contains multiple sets of data, only the integers data will be passed to the model.
- The output data will be converted to int64. The datatype field only supports numpy data types. If the conversion fails, an error field will be added to the response.
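With the inputs/outputs metadata above, prediction requests follow the Kserve V2 inference protocol. Below is a hedged sketch of such a request; the `/v2/models/{name}/infer` path and port 8000 are assumptions based on the Kserve V2 convention, not something stated above.

```python
import requests

# A sketch of a Kserve V2-style inference request.
# The path and port are assumptions following the Kserve V2 convention;
# only the "integers" input is passed to the model (see the second note above).
response = requests.post(
    "http://127.0.0.1:8000/v2/models/mymodel/infer",
    json={
        "id": "some-request-id",
        "inputs": [
            {
                "name": "integers",
                "shape": [3],
                "datatype": "int64",
                "data": [1, 2, 3],
            }
        ],
    },
)
print(response.json())  # outputs are converted to int64 per the metadata
```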
Handler
For details about handlers, see Handlers.
```python
from pinferencia import Server
from pinferencia.handlers import PickleHandler


class MyPrintHandler(PickleHandler):
    def predict(self, data):
        # print the incoming data before delegating to the model
        print(data)
        return self.model.predict(data)


def predict(data):
    return sum(data)


service = Server()
service.register(
    model_name="mymodel",
    model=predict,
    handler=MyPrintHandler,
)
```
Load Now
```python
import joblib

from pinferencia import Server
from pinferencia.handlers import BaseHandler


class JoblibHandler(BaseHandler):
    def load_model(self):
        # load the model file from its registered path
        return joblib.load(self.model_path)


service = Server(model_dir="/opt/models")
service.register(
    model_name="mymodel",
    model="/path/to/model.joblib",
    entrypoint="predict",
    handler=JoblibHandler,
    load_now=True,
)
```
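load_now defaults to True, which loads the model as soon as it is registered. Setting it to False skips loading at registration time, with the model presumably loaded later before serving predictions. A minimal sketch, reusing the JoblibHandler from the example above:

```python
# A minimal sketch: register the model without loading it right away.
# With load_now=False the model file is not read at registration time.
service.register(
    model_name="mymodel",
    model="/path/to/model.joblib",
    entrypoint="predict",
    handler=JoblibHandler,
    load_now=False,
)
```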