Skip to content

Commit

Permalink
upgrade inference model
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Nov 11, 2024
1 parent 5987efe commit 4ae6b29
Show file tree
Hide file tree
Showing 74 changed files with 1,202 additions and 1,199 deletions.
2 changes: 1 addition & 1 deletion paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Global:
model: DLinear_ad
mode: check_dataset # check_dataset/train/evaluate/predict
mode: predict # check_dataset/train/evaluate/predict
dataset_dir: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/"
device: gpu:0
output: "output"
Expand Down
File renamed without changes.
129 changes: 129 additions & 0 deletions paddlex/inference/common/funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy as np
import cv2
import PIL
from PIL import Image, ImageFont, ImageDraw

from ...utils.fonts import PINGFANG_FONT_FILE_PATH
from ..utils.color_map import get_colormap, font_colormap


def create_font(txt, sz, font_path):
    """Load a TrueType font sized so that `txt` fits inside a box of size `sz`.

    Args:
        txt (str): text that will be rendered with the returned font.
        sz (tuple): target box as (width, height). The initial font size is
            80% of the height.
        font_path (str): path of the TrueType font file to load.

    Returns:
        The font object returned by `ImageFont.truetype`. When the text
        rendered at the initial size is wider than `sz[0]`, the size is
        scaled down proportionally and the font is reloaded.
    """
    # Clamp to 1: a box less than 2 px high would otherwise give size 0,
    # which recent Pillow releases refuse to load.
    font_size = max(1, int(sz[1] * 0.8))
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    # FreeTypeFont.getsize was removed in Pillow 10; getlength replaces it.
    if int(PIL.__version__.split(".")[0]) < 10:
        length = font.getsize(txt)[0]
    else:
        length = font.getlength(txt)

    if length > sz[0]:
        # Shrink proportionally so the text width matches the box width,
        # again never going below a loadable size of 1.
        font_size = max(1, int(font_size * sz[0] / length))
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font


def draw_box_txt_fine(img_size, box, txt, font_path):
    """Render `txt` on a white canvas and warp it into the quadrilateral `box`.

    Args:
        img_size (tuple): output size handed to `cv2.warpPerspective` as its
            `dsize` argument (presumably (width, height) -- confirm against
            the caller).
        box: four (x, y) points. Points 0, 1 and 3 determine the canvas width
            and height; all four are the destination corners, matched in
            order to the canvas corners top-left, top-right, bottom-right,
            bottom-left.
        txt (str): text to draw; nothing is drawn when it is empty.
        font_path (str): path of the TrueType font file passed to
            `create_font`.

    Returns:
        numpy.ndarray: the warped text image; pixels outside the box are
        filled with white.
    """

    def _edge_length(p, q):
        # Integer (truncated) Euclidean distance between two box corners.
        return int(math.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2))

    box_height = _edge_length(box[0], box[3])
    box_width = _edge_length(box[0], box[1])

    # Tall, narrow boxes are treated as vertical text: the text is drawn on a
    # landscape canvas and then rotated back into the box's orientation.
    is_vertical = box_height > 2 * box_width and box_height > 30
    if is_vertical:
        canvas_size = (box_height, box_width)
    else:
        canvas_size = (box_width, box_height)

    canvas = Image.new("RGB", canvas_size, (255, 255, 255))
    pen = ImageDraw.Draw(canvas)
    if txt:
        fitted_font = create_font(txt, canvas_size, font_path)
        pen.text([0, 0], txt, fill=(0, 0, 0), font=fitted_font)
    if is_vertical:
        canvas = canvas.transpose(Image.ROTATE_270)

    # Map the canvas corners onto the four box corners.
    src_corners = np.float32(
        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
    )
    dst_corners = np.array(box, dtype=np.float32)
    transform = cv2.getPerspectiveTransform(src_corners, dst_corners)

    return cv2.warpPerspective(
        np.array(canvas, dtype=np.uint8),
        transform,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255),
    )


def draw_box(img, boxes):
    """Draw detection boxes and their "label score" captions onto an image.

    Args:
        img (PIL.Image.Image): PIL image; it is drawn on in place.
        boxes (list): a list of dictionaries representing detection box
            information. Each dictionary is read for the keys "cls_id",
            "coordinate" (unpacked as xmin, ymin, xmax, ymax), "score" and
            "label".
    Returns:
        img (PIL.Image.Image): visualized image (the same object as the input)
    """
    font_size = int(0.024 * int(img.width)) + 2
    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")

    draw_thickness = int(max(img.size) * 0.005)
    draw = ImageDraw.Draw(img)
    clsid2color = {}
    catid2fontcolor = {}
    color_list = get_colormap(rgb=True)

    # ImageDraw.textsize was removed in Pillow 10.0.0, so only releases before
    # major version 10 may use it. Comparing the major version alone (as
    # create_font does) also avoids int() failing on suffixed versions such as
    # "10.1.0.dev0". The check is loop-invariant, so it is evaluated once.
    use_legacy_textsize = int(PIL.__version__.split(".")[0]) < 10

    for i, dt in enumerate(boxes):
        clsid, bbox, score = dt["cls_id"], dt["coordinate"], dt["score"]
        if clsid not in clsid2color:
            # A class keeps the colour chosen at the index of its first box.
            color_index = i % len(color_list)
            clsid2color[clsid] = color_list[color_index]
            catid2fontcolor[clsid] = font_colormap(color_index)
        color = tuple(clsid2color[clsid])
        font_color = tuple(catid2fontcolor[clsid])

        xmin, ymin, xmax, ymax = bbox
        # draw bbox
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
            width=draw_thickness,
            fill=color,
        )

        # draw label
        text = "{} {:.2f}".format(dt["label"], score)
        if use_legacy_textsize:
            tw, th = draw.textsize(text, font=font)
        else:
            left, top, right, bottom = draw.textbbox((0, 0), text, font)
            tw, th = right - left, bottom - top

        if ymin < th:
            # Not enough room above the box: put the caption inside its top edge.
            draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
            draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
        else:
            # Caption sits directly above the box.
            draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
            draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)

    return img
Loading

0 comments on commit 4ae6b29

Please sign in to comment.