自定义模板
尽管有内置模板,但它永远不足以涵盖所有场景。
Pinferencia 支持自定义模板。自定义模板并在您的服务中使用它很容易。
首先让我们尝试创建一个简单的模板:
- 输入数字列表。
- 显示数字的平均值、最大值和最小值。
模型
模型很简单,服务可以定义为:
app.py |
---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 | from typing import List
from pinferencia import Server
def stat(data: List[float]) -> dict:
return {
"mean": sum(data) / len(data),
"max": max(data),
"min": min(data),
}
service = Server()
service.register(model_name="stat", model=stat, metadata={"task": "Stat"})
|
模板
Pinferencia 提供了 BaseTemplate
来扩展以构建自定义模板。
JSON 输入
首先,我们将页面分为两列,并分别创建一个 JSON 输入字段和显示字段。
frontend.py |
---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27 | import streamlit as st
from pinferencia.frontend.app import Server
from pinferencia.frontend.templates.base import BaseTemplate
class StatTemplate(BaseTemplate):
title = (
'<span style="color:salmon;">Numbers</span> '
'<span style="color:slategray;">Statistics</span>'
)
def render(self):
super().render()
json_template = "[]"
col1, col2 = st.columns(2)
col2.write("Request Preview")
raw_text = col1.text_area("Raw Data", value=json_template, height=150)
col2.json(raw_text)
backend_address = "http://127.0.0.1:8000"
service = Server(
backend_server=f"{backend_address}",
custom_templates={"Stat": StatTemplate},
)
|
启动服务
$ pinfer sum_mnist:service --frontend-script=sum_mnist_frontend.py
Pinferencia: Frontend component streamlit is starting...
Pinferencia: Backend component uvicorn is starting...
并打开浏览器你会看到:
调用后端并显示结果
添加以下高亮显示的代码以将请求发送到后端并显示结果。
frontend.py |
---|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 | import json
import streamlit as st
from pinferencia.frontend.app import Server
from pinferencia.frontend.templates.base import BaseTemplate
class StatTemplate(BaseTemplate):
title = (
'<span style="color:salmon;">Numbers</span> '
'<span style="color:slategray;">Statistics</span>'
)
def render(self):
super().render()
json_template = "[]"
col1, col2 = st.columns(2)
col2.write("Request Preview")
raw_text = col1.text_area("Raw Data", value=json_template, height=150)
col2.json(raw_text)
pred_btn = st.button("Run") # (1)
if pred_btn:
with st.spinner("Wait for result"): # (2)
prediction = self.predict(json.loads(raw_text)) # (3)
st.write("Statistics")
result_col1, result_col2, result_col3 = st.columns(3) # (4)
result_col1.metric(label="Max", value=prediction.get("max"))
result_col2.metric(label="Min", value=prediction.get("min"))
result_col3.metric(label="Mean", value=prediction.get("mean"))
backend_address = "http://127.0.0.1:8000"
service = Server(
backend_server=f"{backend_address}",
custom_templates={"Stat": StatTemplate},
)
|
- 提供一个按钮来触发预测。
- 发送请求时显示一个等待效果。
- 将数据发送到后端。
- 将结果分为三列展示。
再次启动服务,您将看到: