-
Notifications
You must be signed in to change notification settings - Fork 966
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f4fa153
commit 19d0c93
Showing
36 changed files
with
3,810 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
import numpy as np | ||
import cv2 | ||
import PIL | ||
from PIL import Image, ImageFont, ImageDraw | ||
|
||
from ...utils.fonts import PINGFANG_FONT_FILE_PATH | ||
from ..utils.color_map import get_colormap, font_colormap | ||
|
||
|
||
def create_font(txt, sz, font_path):
    """Create a TrueType font sized so that ``txt`` fits inside a box.

    The font starts at 80% of the box height and is scaled down
    proportionally when the rendered text would be wider than the box.

    Args:
        txt (str): text that will be rendered with the returned font.
        sz (tuple): ``(width, height)`` of the target box, in pixels.
        font_path (str): path of the TrueType font file to load.

    Returns:
        The font object returned by ``ImageFont.truetype``.
    """
    # Clamp to 1: very small (or degenerate) boxes would otherwise request a
    # font of size 0, which cannot be loaded.
    font_size = max(int(sz[1] * 0.8), 1)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    # ``getsize`` is only used on Pillow releases older than 10.
    if int(PIL.__version__.split(".")[0]) < 10:
        length = font.getsize(txt)[0]
    else:
        length = font.getlength(txt)

    if length > sz[0]:
        # Shrink by the overflow ratio, again never going below size 1.
        font_size = max(int(font_size * sz[0] / length), 1)
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font
|
||
|
||
def draw_box_txt_fine(img_size, box, txt, font_path):
    """Render ``txt`` on a white canvas and warp it onto the quadrilateral ``box``.

    Args:
        img_size: output size handed to ``cv2.warpPerspective``.
        box: four corner points; points 0, 1 and 3 are used to measure the
            box width (0 -> 1) and height (0 -> 3).
        txt (str): text to draw; nothing is drawn when it is empty.
        font_path (str): path of the TrueType font file to use.

    Returns:
        numpy.ndarray: the warped text image, white where nothing was drawn.
    """
    corner0, corner1, corner3 = box[0], box[1], box[3]
    height = int(
        math.sqrt((corner0[0] - corner3[0]) ** 2 + (corner0[1] - corner3[1]) ** 2)
    )
    width = int(
        math.sqrt((corner0[0] - corner1[0]) ** 2 + (corner0[1] - corner1[1]) ** 2)
    )

    # Tall, narrow boxes are drawn on a sideways canvas and rotated afterwards.
    is_vertical = height > 2 * width and height > 30
    canvas_size = (height, width) if is_vertical else (width, height)

    canvas = Image.new("RGB", canvas_size, (255, 255, 255))
    painter = ImageDraw.Draw(canvas)
    if txt:
        fitted_font = create_font(txt, canvas_size, font_path)
        painter.text([0, 0], txt, fill=(0, 0, 0), font=fitted_font)
    if is_vertical:
        canvas = canvas.transpose(Image.ROTATE_270)

    # Map the upright width x height rectangle onto the four box corners.
    src_pts = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
    dst_pts = np.array(box, dtype=np.float32)
    transform = cv2.getPerspectiveTransform(src_pts, dst_pts)

    canvas_arr = np.array(canvas, dtype=np.uint8)
    return cv2.warpPerspective(
        canvas_arr,
        transform,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255),
    )
|
||
|
||
def draw_box(img, boxes):
    """
    Draw detection boxes and their "label score" captions onto ``img``.

    The drawing is done in place on ``img`` and the same object is returned.

    Args:
        img (PIL.Image.Image): PIL image
        boxes (list): a list of dictionaries representing detection box information.
            Each dictionary is read through the keys "cls_id", "coordinate"
            (unpacked as xmin, ymin, xmax, ymax), "score" and "label".
    Returns:
        img (PIL.Image.Image): visualized image
    """
    font_size = int(0.024 * int(img.width)) + 2
    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")

    draw_thickness = int(max(img.size) * 0.005)
    draw = ImageDraw.Draw(img)
    clsid2color = {}
    catid2fontcolor = {}
    color_list = get_colormap(rgb=True)

    # ``ImageDraw.textsize`` no longer exists from Pillow 10.0.0 on, so it may
    # only be used for major versions below 10. Only the major component is
    # parsed, which also tolerates non-numeric suffixes such as "dev0".
    use_textsize = int(PIL.__version__.split(".")[0]) < 10

    for i, dt in enumerate(boxes):
        clsid, bbox, score = dt["cls_id"], dt["coordinate"], dt["score"]
        # The first box seen for a class fixes that class's colours.
        if clsid not in clsid2color:
            color_index = i % len(color_list)
            clsid2color[clsid] = color_list[color_index]
            catid2fontcolor[clsid] = font_colormap(color_index)
        color = tuple(clsid2color[clsid])
        font_color = tuple(catid2fontcolor[clsid])

        xmin, ymin, xmax, ymax = bbox
        # draw bbox
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
            width=draw_thickness,
            fill=color,
        )

        # draw label
        text = "{} {:.2f}".format(dt["label"], score)
        if use_textsize:
            tw, th = draw.textsize(text, font=font)
        else:
            left, top, right, bottom = draw.textbbox((0, 0), text, font)
            tw, th = right - left, bottom - top
        if ymin < th:
            # Not enough room above the box: put the label inside its top edge.
            draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
            draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
        else:
            draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
            draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)

    return img
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from pathlib import Path | ||
from typing import Any, Dict, Optional | ||
|
||
from ...utils import errors | ||
from ..utils.official_models import official_models | ||
from .base import BasePredictor, BasicPredictor | ||
from .image_classification import ClasPredictor | ||
|
||
# from .text_detection import TextDetPredictor | ||
# from .text_recognition import TextRecPredictor | ||
# from .table_recognition import TablePredictor | ||
# from .object_detection import DetPredictor | ||
|
||
# from .instance_segmentation import InstanceSegPredictor | ||
# from .semantic_segmentation import SegPredictor | ||
# from .general_recognition import ShiTuRecPredictor | ||
# from .ts_fc import TSFcPredictor | ||
# from .ts_ad import TSAdPredictor | ||
# from .ts_cls import TSClsPredictor | ||
# from .image_unwarping import WarpPredictor | ||
# from .multilabel_classification import MLClasPredictor | ||
# from .anomaly_detection import UadPredictor | ||
# from .formula_recognition import LaTeXOCRPredictor | ||
# from .face_recognition import FaceRecPredictor | ||
|
||
|
||
def _create_hp_predictor( | ||
model_name, model_dir, device, config, hpi_params, *args, **kwargs | ||
): | ||
try: | ||
from paddlex_hpi.models import HPPredictor | ||
except ModuleNotFoundError: | ||
raise RuntimeError( | ||
"The PaddleX HPI plugin is not properly installed, and the high-performance model inference features are not available." | ||
) from None | ||
try: | ||
predictor = HPPredictor.get(model_name)( | ||
model_dir=model_dir, | ||
config=config, | ||
device=device, | ||
*args, | ||
hpi_params=hpi_params, | ||
**kwargs, | ||
) | ||
except errors.others.ClassNotFoundException: | ||
raise ValueError( | ||
f"{model_name} is not supported by the PaddleX HPI plugin." | ||
) from None | ||
return predictor | ||
|
||
|
||
def create_predictor(
    model: str,
    device=None,
    pp_option=None,
    use_hpip: bool = False,
    hpi_params: Optional[Dict[str, Any]] = None,
    *args,
    **kwargs,
) -> BasePredictor:
    """Create a predictor for a local model directory or an official model name.

    Args:
        model: path of a local model directory, or a model name found in
            ``official_models`` (resolved by ``check_model``).
        device: forwarded to the predictor as ``device``.
        pp_option: forwarded to the basic predictor as ``pp_option``; not used
            when ``use_hpip`` is true.
        use_hpip: when true, build the predictor through the HPI plugin.
        hpi_params: forwarded to the HPI predictor; not used otherwise.
        *args: extra positional arguments for the predictor constructor.
        **kwargs: extra keyword arguments for the predictor constructor.

    Returns:
        The constructed predictor.

    Raises:
        Exception: from ``check_model`` when ``model`` is neither an existing
            path nor an official model name.
    """
    model_dir = check_model(model)
    config = BasePredictor.load_config(model_dir)
    model_name = config["Global"]["model_name"]
    if use_hpip:
        # The leading parameters are passed positionally on purpose: passing
        # them by keyword together with ``*args`` would bind the first extra
        # positional argument to ``model_name`` and raise a "multiple values"
        # TypeError.
        return _create_hp_predictor(
            model_name, model_dir, device, config, hpi_params, *args, **kwargs
        )
    return BasicPredictor.get(model_name)(
        model_dir=model_dir,
        config=config,
        device=device,
        pp_option=pp_option,
        *args,
        **kwargs,
    )
|
||
|
||
def check_model(model):
    """Resolve ``model`` to a model location.

    Args:
        model: a filesystem path, or a key of ``official_models``.

    Returns:
        ``Path(model)`` when that path exists; otherwise the
        ``official_models`` entry stored under ``model``.

    Raises:
        Exception: when ``model`` is neither an existing path nor a key of
            ``official_models``.
    """
    # An existing local path takes precedence over an official model name.
    local_path = Path(model)
    if local_path.exists():
        return local_path
    if model in official_models:
        return official_models[model]
    raise Exception(
        f"The model ({model}) is no exists! Please using directory of local model files or model name supported by PaddleX!"
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .base_component import BaseComponent | ||
from .base_predictor import BasePredictor | ||
from .basic_predictor import BasicPredictor | ||
from .result import BaseResult, CVResult |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import inspect | ||
from abc import ABC, abstractmethod | ||
from types import GeneratorType | ||
|
||
from ....utils.flags import INFER_BENCHMARK | ||
from ....utils import logging | ||
from ...utils.benchmark import Timer | ||
|
||
|
||
class BaseComponent(ABC):
    """Abstract base of a processing component.

    A component is called with a batch-data object. It gathers the arguments
    of ``apply`` from that object according to ``self.inputs`` (set through
    ``set_inputs``), runs ``apply`` and stores a truthy result back into the
    batch data under ``self.name``.
    """

    def __init__(self):
        # When the benchmark flag is on, ``apply`` is replaced on the instance
        # by the version returned from the timer.
        if INFER_BENCHMARK:
            self.timer = Timer()
            self.apply = self.timer.watch_func(self.apply)

    def __call__(self, batch_data):
        """Run the component on ``batch_data`` and return ``batch_data``.

        ``self.inputs`` maps each ``apply`` parameter name to the key that is
        fetched with ``batch_data.get_by_key``. A falsy output leaves
        ``batch_data`` untouched.
        """
        kwargs = {k: batch_data.get_by_key(v) for k, v in self.inputs.items()}
        output = self.apply(**kwargs)
        if not output:
            return batch_data
        batch_data.update(self.name, output)
        return batch_data

    def set_inputs(self, inputs):
        """Validate ``inputs`` against the signature of ``apply`` and store it.

        Args:
            inputs (dict): maps ``apply`` parameter names to batch-data keys.

        Raises:
            AssertionError: if ``inputs`` is not a dict.
            Exception: if ``apply`` takes no parameters but ``inputs`` is not
                empty, or if a parameter without default is missing from
                ``inputs``.
        """
        assert isinstance(
            inputs, dict
        ), f"The inputs of {self.__class__.__name__} must be a dict, but got {type(inputs).__name__}!"
        sig = inspect.signature(self.apply)
        if not sig.parameters and inputs:
            raise Exception(
                f"The apply function of {self.__class__.__name__} takes no parameters, but inputs ({inputs}) were provided!"
            )
        for param in sig.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                logging.debug(
                    f"The apply function parameter of {self.__class__.__name__} is **kwargs, so would not inspect!"
                )
                continue
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                # ``apply`` is only ever called with keyword arguments, so a
                # ``*args`` parameter can never be a required named input.
                continue
            if param.default == inspect.Parameter.empty and param.name not in inputs:
                raise Exception(
                    f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but only found in keys of {inputs}!"
                )
        self.inputs = inputs

    @classmethod
    def get_input_keys(cls) -> list:
        """Return the class attribute ``input_keys``.

        The attribute is not defined in this base class; presumably each
        subclass provides it.
        """
        return cls.input_keys

    @property
    def name(self):
        """Component name: the ``NAME`` attribute if present, else the class name."""
        return getattr(self, "NAME", self.__class__.__name__)

    @abstractmethod
    def apply(self, input):
        """Do the actual work of the component; must be overridden."""
        raise NotImplementedError
|
||
|
||
class ComponentsEngine(object):
    """Chain components: each one receives the data returned by the previous one."""

    def __init__(self, cmpts, config):
        """Store the components and configure their inputs.

        Args:
            cmpts: mapping of name -> component, run in iteration order.
            config: mapping of component class name -> inputs configuration
                handed to that component's ``set_inputs``.
        """
        self._cmpts = cmpts
        self._config = config
        self.keys = list(cmpts.keys())
        self._set_inputs()

    def _set_inputs(self):
        """Give every component the inputs configured under its class name."""
        for key in self._cmpts:
            component = self._cmpts[key]
            class_name = component.__class__.__name__
            component.set_inputs(self._config[class_name])

    def __call__(self, data, i=0):
        """Run the components starting at position ``i`` and return the final data."""
        position = i
        while True:
            data = self._cmpts[self.keys[position]](data)
            position += 1
            if position >= len(self._cmpts):
                return data
Oops, something went wrong.