Skip to content

Commit

Permalink
upgrade inference model
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Nov 22, 2024
1 parent f4fa153 commit 19d0c93
Show file tree
Hide file tree
Showing 36 changed files with 3,810 additions and 2 deletions.
2 changes: 1 addition & 1 deletion paddlex/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .models import create_predictor
from .new_models import create_predictor
from .pipelines import create_pipeline
from .utils.pp_option import PaddlePredictorOption
13 changes: 13 additions & 0 deletions paddlex/inference/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
129 changes: 129 additions & 0 deletions paddlex/inference/common/funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy as np
import cv2
import PIL
from PIL import Image, ImageFont, ImageDraw

from ...utils.fonts import PINGFANG_FONT_FILE_PATH
from ..utils.color_map import get_colormap, font_colormap


def create_font(txt, sz, font_path):
"""create font"""
font_size = int(sz[1] * 0.8)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
if int(PIL.__version__.split(".")[0]) < 10:
length = font.getsize(txt)[0]
else:
length = font.getlength(txt)

if length > sz[0]:
font_size = int(font_size * sz[0] / length)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
return font


def draw_box_txt_fine(img_size, box, txt, font_path):
"""draw box text"""
box_height = int(
math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
)
box_width = int(
math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
)

if box_height > 2 * box_width and box_height > 30:
img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
draw_text = ImageDraw.Draw(img_text)
if txt:
font = create_font(txt, (box_height, box_width), font_path)
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
img_text = img_text.transpose(Image.ROTATE_270)
else:
img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
draw_text = ImageDraw.Draw(img_text)
if txt:
font = create_font(txt, (box_width, box_height), font_path)
draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)

pts1 = np.float32(
[[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
)
pts2 = np.array(box, dtype=np.float32)
M = cv2.getPerspectiveTransform(pts1, pts2)

img_text = np.array(img_text, dtype=np.uint8)
img_right_text = cv2.warpPerspective(
img_text,
M,
img_size,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(255, 255, 255),
)
return img_right_text


def draw_box(img, boxes):
"""
Args:
img (PIL.Image.Image): PIL image
boxes (list): a list of dictionaries representing detection box information.
Returns:
img (PIL.Image.Image): visualized image
"""
font_size = int(0.024 * int(img.width)) + 2
font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")

draw_thickness = int(max(img.size) * 0.005)
draw = ImageDraw.Draw(img)
clsid2color = {}
catid2fontcolor = {}
color_list = get_colormap(rgb=True)

for i, dt in enumerate(boxes):
clsid, bbox, score = dt["cls_id"], dt["coordinate"], dt["score"]
if clsid not in clsid2color:
color_index = i % len(color_list)
clsid2color[clsid] = color_list[color_index]
catid2fontcolor[clsid] = font_colormap(color_index)
color = tuple(clsid2color[clsid])
font_color = tuple(catid2fontcolor[clsid])

xmin, ymin, xmax, ymax = bbox
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
width=draw_thickness,
fill=color,
)

# draw label
text = "{} {:.2f}".format(dt["label"], score)
if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
tw, th = draw.textsize(text, font=font)
else:
left, top, right, bottom = draw.textbbox((0, 0), text, font)
tw, th = right - left, bottom - top
if ymin < th:
draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
else:
draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)

return img
108 changes: 108 additions & 0 deletions paddlex/inference/new_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path
from typing import Any, Dict, Optional

from ...utils import errors
from ..utils.official_models import official_models
from .base import BasePredictor, BasicPredictor
from .image_classification import ClasPredictor

# from .text_detection import TextDetPredictor
# from .text_recognition import TextRecPredictor
# from .table_recognition import TablePredictor
# from .object_detection import DetPredictor

# from .instance_segmentation import InstanceSegPredictor
# from .semantic_segmentation import SegPredictor
# from .general_recognition import ShiTuRecPredictor
# from .ts_fc import TSFcPredictor
# from .ts_ad import TSAdPredictor
# from .ts_cls import TSClsPredictor
# from .image_unwarping import WarpPredictor
# from .multilabel_classification import MLClasPredictor
# from .anomaly_detection import UadPredictor
# from .formula_recognition import LaTeXOCRPredictor
# from .face_recognition import FaceRecPredictor


def _create_hp_predictor(
model_name, model_dir, device, config, hpi_params, *args, **kwargs
):
try:
from paddlex_hpi.models import HPPredictor
except ModuleNotFoundError:
raise RuntimeError(
"The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available."
) from None
try:
predictor = HPPredictor.get(model_name)(
model_dir=model_dir,
config=config,
device=device,
*args,
hpi_params=hpi_params,
**kwargs,
)
except errors.others.ClassNotFoundException:
raise ValueError(
f"{model_name} is not supported by the PaddleX HPI plugin."
) from None
return predictor


def create_predictor(
model: str,
device=None,
pp_option=None,
use_hpip: bool = False,
hpi_params: Optional[Dict[str, Any]] = None,
*args,
**kwargs,
) -> BasePredictor:
model_dir = check_model(model)
config = BasePredictor.load_config(model_dir)
model_name = config["Global"]["model_name"]
if use_hpip:
return _create_hp_predictor(
model_name=model_name,
model_dir=model_dir,
config=config,
hpi_params=hpi_params,
device=device,
*args,
**kwargs,
)
else:
return BasicPredictor.get(model_name)(
model_dir=model_dir,
config=config,
device=device,
pp_option=pp_option,
*args,
**kwargs,
)


def check_model(model):
if Path(model).exists():
return Path(model)
elif model in official_models:
return official_models[model]
else:
raise Exception(
f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
)
18 changes: 18 additions & 0 deletions paddlex/inference/new_models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_component import BaseComponent
from .base_predictor import BasePredictor
from .basic_predictor import BasicPredictor
from .result import BaseResult, CVResult
89 changes: 89 additions & 0 deletions paddlex/inference/new_models/base/base_component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from abc import ABC, abstractmethod
from types import GeneratorType

from ....utils.flags import INFER_BENCHMARK
from ....utils import logging
from ...utils.benchmark import Timer


class BaseComponent(ABC):

def __init__(self):
if INFER_BENCHMARK:
self.timer = Timer()
self.apply = self.timer.watch_func(self.apply)

def __call__(self, batch_data):
kwargs = {k: batch_data.get_by_key(v) for k, v in self.inputs.items()}
output = self.apply(**kwargs)
if not output:
return batch_data
batch_data.update(self.name, output)
return batch_data

def set_inputs(self, inputs):
assert isinstance(inputs, dict)
sig = inspect.signature(self.apply)
if len(sig.parameters) == 0:
if inputs:
raise Exception
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD:
logging.debug(
f"The apply function parameter of {self.__class__.__name__} is **kwargs, so would not inspect!"
)
continue
if param.default == inspect.Parameter.empty and param.name not in inputs:
raise Exception(
f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but only found in keys of {inputs}!"
)
self.inputs = inputs

@classmethod
def get_input_keys(cls) -> list:
return cls.input_keys

@property
def name(self):
return getattr(self, "NAME", self.__class__.__name__)

@abstractmethod
def apply(self, input):
raise NotImplementedError


class ComponentsEngine(object):
def __init__(self, cmpts, config):
self._cmpts = cmpts
self._config = config
self.keys = list(cmpts.keys())
self._set_inputs()

def _set_inputs(self):
for idx, name in enumerate(self._cmpts):
cmpt = self._cmpts[name]
cmpt_name = cmpt.__class__.__name__
input_cfg = self._config[cmpt_name]
cmpt.set_inputs(input_cfg)

def __call__(self, data, i=0):
data = self._cmpts[self.keys[i]](data)
if i + 1 < len(self._cmpts):
return self.__call__(data, i + 1)
else:
return data
Loading

0 comments on commit 19d0c93

Please sign in to comment.