CloudAnts's picture
my
32652fd
raw
history blame
22.7 kB
#![allow(clippy::type_complexity)]
use anyhow::Result;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use ndarray::{s, Array, Axis, IxDyn};
use rand::{thread_rng, Rng};
use std::path::PathBuf;
use crate::{
check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
};
pub struct YOLOv8 {
// YOLOv8 model for all yolo-tasks
engine: OrtBackend,
nc: u32,
nk: u32,
nm: u32,
height: u32,
width: u32,
batch: u32,
task: YOLOTask,
conf: f32,
kconf: f32,
iou: f32,
names: Vec<String>,
color_palette: Vec<(u8, u8, u8)>,
profile: bool,
plot: bool,
}
impl YOLOv8 {
pub fn new(config: Args) -> Result<Self> {
// execution provider
let ep = if config.trt {
OrtEP::Trt(config.device_id)
} else if config.cuda {
OrtEP::Cuda(config.device_id)
} else {
OrtEP::Cpu
};
// batch
let batch = Batch {
opt: config.batch,
min: config.batch_min,
max: config.batch_max,
};
// build ort engine
let ort_args = OrtConfig {
ep,
batch,
f: config.model,
task: config.task,
trt_fp16: config.fp16,
image_size: (config.height, config.width),
};
let engine = OrtBackend::build(ort_args)?;
// get batch, height, width, tasks, nc, nk, nm
let (batch, height, width, task) = (
engine.batch(),
engine.height(),
engine.width(),
engine.task(),
);
let nc = engine.nc().or(config.nc).unwrap_or_else(|| {
panic!("Failed to get num_classes, make it explicit with `--nc`");
});
let (nk, nm) = match task {
YOLOTask::Pose => {
let nk = engine.nk().or(config.nk).unwrap_or_else(|| {
panic!("Failed to get num_keypoints, make it explicit with `--nk`");
});
(nk, 0)
}
YOLOTask::Segment => {
let nm = engine.nm().or(config.nm).unwrap_or_else(|| {
panic!("Failed to get num_masks, make it explicit with `--nm`");
});
(0, nm)
}
_ => (0, 0),
};
// class names
let names = engine.names().unwrap_or(vec!["Unknown".to_string()]);
// color palette
let mut rng = thread_rng();
let color_palette: Vec<_> = names
.iter()
.map(|_| {
(
rng.gen_range(0..=255),
rng.gen_range(0..=255),
rng.gen_range(0..=255),
)
})
.collect();
Ok(Self {
engine,
names,
conf: config.conf,
kconf: config.kconf,
iou: config.iou,
color_palette,
profile: config.profile,
plot: config.plot,
nc,
nk,
nm,
height,
width,
batch,
task,
})
}
pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
let r = (w1 / w0).min(h1 / h0);
(r, (w0 * r).round(), (h0 * r).round())
}
pub fn preprocess(&mut self, xs: &Vec<DynamicImage>) -> Result<Array<f32, IxDyn>> {
let mut ys =
Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn();
ys.fill(144.0 / 255.0);
for (idx, x) in xs.iter().enumerate() {
let img = match self.task() {
YOLOTask::Classify => x.resize_exact(
self.width(),
self.height(),
image::imageops::FilterType::Triangle,
),
_ => {
let (w0, h0) = x.dimensions();
let w0 = w0 as f32;
let h0 = h0 as f32;
let (_, w_new, h_new) =
self.scale_wh(w0, h0, self.width() as f32, self.height() as f32); // f32 round
x.resize_exact(
w_new as u32,
h_new as u32,
if let YOLOTask::Segment = self.task() {
image::imageops::FilterType::CatmullRom
} else {
image::imageops::FilterType::Triangle
},
)
}
};
for (x, y, rgb) in img.pixels() {
let x = x as usize;
let y = y as usize;
let [r, g, b, _] = rgb.0;
ys[[idx, 0, y, x]] = (r as f32) / 255.0;
ys[[idx, 1, y, x]] = (g as f32) / 255.0;
ys[[idx, 2, y, x]] = (b as f32) / 255.0;
}
}
Ok(ys)
}
pub fn run(&mut self, xs: &Vec<DynamicImage>) -> Result<Vec<YOLOResult>> {
// pre-process
let t_pre = std::time::Instant::now();
let xs_ = self.preprocess(xs)?;
if self.profile {
println!("[Model Preprocess]: {:?}", t_pre.elapsed());
}
// run
let t_run = std::time::Instant::now();
let ys = self.engine.run(xs_, self.profile)?;
if self.profile {
println!("[Model Inference]: {:?}", t_run.elapsed());
}
// post-process
let t_post = std::time::Instant::now();
let ys = self.postprocess(ys, xs)?;
if self.profile {
println!("[Model Postprocess]: {:?}", t_post.elapsed());
}
// plot and save
if self.plot {
self.plot_and_save(&ys, xs, Some(&SKELETON));
}
Ok(ys)
}
pub fn postprocess(
&self,
xs: Vec<Array<f32, IxDyn>>,
xs0: &[DynamicImage],
) -> Result<Vec<YOLOResult>> {
if let YOLOTask::Classify = self.task() {
let mut ys = Vec::new();
let preds = &xs[0];
for batch in preds.axis_iter(Axis(0)) {
ys.push(YOLOResult::new(
Some(Embedding::new(batch.into_owned())),
None,
None,
None,
));
}
Ok(ys)
} else {
const CXYWH_OFFSET: usize = 4; // cxcywh
const KPT_STEP: usize = 3; // xyconf
let preds = &xs[0];
let protos = {
if xs.len() > 1 {
Some(&xs[1])
} else {
None
}
};
let mut ys = Vec::new();
for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
// [bs, 4 + nc + nm, anchors]
// input image
let width_original = xs0[idx].width() as f32;
let height_original = xs0[idx].height() as f32;
let ratio = (self.width() as f32 / width_original)
.min(self.height() as f32 / height_original);
// save each result
let mut data: Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)> = Vec::new();
for pred in anchor.axis_iter(Axis(1)) {
// split preds for different tasks
let bbox = pred.slice(s![0..CXYWH_OFFSET]);
let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]);
let kpts = {
if let YOLOTask::Pose = self.task() {
Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..]))
} else {
None
}
};
let coefs = {
if let YOLOTask::Segment = self.task() {
Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec())
} else {
None
}
};
// confidence and id
let (id, &confidence) = clss
.into_iter()
.enumerate()
.reduce(|max, x| if x.1 > max.1 { x } else { max })
.unwrap(); // definitely will not panic!
// confidence filter
if confidence < self.conf {
continue;
}
// bbox re-scale
let cx = bbox[0] / ratio;
let cy = bbox[1] / ratio;
let w = bbox[2] / ratio;
let h = bbox[3] / ratio;
let x = cx - w / 2.;
let y = cy - h / 2.;
let y_bbox = Bbox::new(
x.max(0.0f32).min(width_original),
y.max(0.0f32).min(height_original),
w,
h,
id,
confidence,
);
// kpts
let y_kpts = {
if let Some(kpts) = kpts {
let mut kpts_ = Vec::new();
// rescale
for i in 0..self.nk() as usize {
let kx = kpts[KPT_STEP * i] / ratio;
let ky = kpts[KPT_STEP * i + 1] / ratio;
let kconf = kpts[KPT_STEP * i + 2];
if kconf < self.kconf {
kpts_.push(Point2::default());
} else {
kpts_.push(Point2::new_with_conf(
kx.max(0.0f32).min(width_original),
ky.max(0.0f32).min(height_original),
kconf,
));
}
}
Some(kpts_)
} else {
None
}
};
// data merged
data.push((y_bbox, y_kpts, coefs));
}
// nms
non_max_suppression(&mut data, self.iou);
// decode
let mut y_bboxes: Vec<Bbox> = Vec::new();
let mut y_kpts: Vec<Vec<Point2>> = Vec::new();
let mut y_masks: Vec<Vec<u8>> = Vec::new();
for elem in data.into_iter() {
if let Some(kpts) = elem.1 {
y_kpts.push(kpts)
}
// decode masks
if let Some(coefs) = elem.2 {
let proto = protos.unwrap().slice(s![idx, .., .., ..]);
let (nm, nh, nw) = proto.dim();
// coefs * proto -> mask
let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
// build image from ndarray
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
Some(image) => image,
None => panic!("can not create image from ndarray"),
};
let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
// rescale masks
let (_, w_mask, h_mask) =
self.scale_wh(width_original, height_original, nw as f32, nh as f32);
let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
let mask_original = mask_cropped.resize_exact(
// resize_to_fill
width_original as u32,
height_original as u32,
match self.task() {
YOLOTask::Segment => image::imageops::FilterType::CatmullRom,
_ => image::imageops::FilterType::Triangle,
},
);
// crop-mask with bbox
let mut mask_original_cropped = mask_original.into_luma8();
for y in 0..height_original as usize {
for x in 0..width_original as usize {
if x < elem.0.xmin() as usize
|| x > elem.0.xmax() as usize
|| y < elem.0.ymin() as usize
|| y > elem.0.ymax() as usize
{
mask_original_cropped.put_pixel(
x as u32,
y as u32,
image::Luma([0u8]),
);
}
}
}
y_masks.push(mask_original_cropped.into_raw());
}
y_bboxes.push(elem.0);
}
// save each result
let y = YOLOResult {
probs: None,
bboxes: if !y_bboxes.is_empty() {
Some(y_bboxes)
} else {
None
},
keypoints: if !y_kpts.is_empty() {
Some(y_kpts)
} else {
None
},
masks: if !y_masks.is_empty() {
Some(y_masks)
} else {
None
},
};
ys.push(y);
}
Ok(ys)
}
}
pub fn plot_and_save(
&self,
ys: &[YOLOResult],
xs0: &[DynamicImage],
skeletons: Option<&[(usize, usize)]>,
) {
// check font then load
let font = check_font("Arial.ttf");
for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
let mut img = img0.to_rgb8();
// draw for classifier
if let Some(probs) = y.probs() {
for (i, k) in probs.topk(5).iter().enumerate() {
let legend = format!("{} {:.2}%", self.names[k.0], k.1);
let scale = 32;
let legend_size = img.width().max(img.height()) / scale;
let x = img.width() / 20;
let y = img.height() / 20 + i as u32 * legend_size;
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([0, 255, 0]),
x as i32,
y as i32,
rusttype::Scale::uniform(legend_size as f32 - 1.),
&font,
&legend,
);
}
}
// draw bboxes & keypoints
if let Some(bboxes) = y.bboxes() {
for (_idx, bbox) in bboxes.iter().enumerate() {
// rect
imageproc::drawing::draw_hollow_rect_mut(
&mut img,
imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32)
.of_size(bbox.width() as u32, bbox.height() as u32),
image::Rgb(self.color_palette[bbox.id()].into()),
);
// text
let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence());
let scale = 40;
let legend_size = img.width().max(img.height()) / scale;
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb(self.color_palette[bbox.id()].into()),
bbox.xmin() as i32,
(bbox.ymin() - legend_size as f32) as i32,
rusttype::Scale::uniform(legend_size as f32 - 1.),
&font,
&legend,
);
}
}
// draw kpts
if let Some(keypoints) = y.keypoints() {
for kpts in keypoints.iter() {
for kpt in kpts.iter() {
// filter
if kpt.confidence() < self.kconf {
continue;
}
// draw point
imageproc::drawing::draw_filled_circle_mut(
&mut img,
(kpt.x() as i32, kpt.y() as i32),
2,
image::Rgb([0, 255, 0]),
);
}
// draw skeleton if has
if let Some(skeletons) = skeletons {
for &(idx1, idx2) in skeletons.iter() {
let kpt1 = &kpts[idx1];
let kpt2 = &kpts[idx2];
if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf {
continue;
}
imageproc::drawing::draw_line_segment_mut(
&mut img,
(kpt1.x(), kpt1.y()),
(kpt2.x(), kpt2.y()),
image::Rgb([233, 14, 57]),
);
}
}
}
}
// draw mask
if let Some(masks) = y.masks() {
for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) {
let mask_nd: ImageBuffer<image::Luma<_>, Vec<u8>> =
match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) {
Some(image) => image,
None => panic!("can not crate image from ndarray"),
};
for _x in 0..img.width() {
for _y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y);
if mask_p.0[0] > 0 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y);
// img_p.0[2] = self.color_palette[bbox.id()].2 / 2;
// img_p.0[1] = self.color_palette[bbox.id()].1 / 2;
// img_p.0[0] = self.color_palette[bbox.id()].0 / 2;
img_p.0[2] /= 2;
img_p.0[1] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p)
}
}
}
}
}
// mkdir and save
let mut runs = PathBuf::from("runs");
if !runs.exists() {
std::fs::create_dir_all(&runs).unwrap();
}
runs.push(gen_time_string("-"));
let saveout = format!("{}.jpg", runs.to_str().unwrap());
let _ = img.save(saveout);
}
}
pub fn summary(&self) {
println!(
"\nSummary:\n\
> Task: {:?}{}\n\
> EP: {:?} {}\n\
> Dtype: {:?}\n\
> Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\
> nc: {} nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\
",
self.task(),
match self.engine.author().zip(self.engine.version()) {
Some((author, ver)) => format!(" ({} {})", author, ver),
None => String::from(""),
},
self.engine.ep(),
if let OrtEP::Cpu = self.engine.ep() {
""
} else {
"(May still fall back to CPU)"
},
self.engine.dtype(),
self.batch(),
if self.engine.is_batch_dynamic() {
"Dynamic"
} else {
"Const"
},
self.height(),
if self.engine.is_height_dynamic() {
"Dynamic"
} else {
"Const"
},
self.width(),
if self.engine.is_width_dynamic() {
"Dynamic"
} else {
"Const"
},
self.nc(),
self.nk(),
self.nm(),
self.conf,
self.kconf,
self.iou,
);
}
pub fn engine(&self) -> &OrtBackend {
&self.engine
}
pub fn conf(&self) -> f32 {
self.conf
}
pub fn set_conf(&mut self, val: f32) {
self.conf = val;
}
pub fn conf_mut(&mut self) -> &mut f32 {
&mut self.conf
}
pub fn kconf(&self) -> f32 {
self.kconf
}
pub fn iou(&self) -> f32 {
self.iou
}
pub fn task(&self) -> &YOLOTask {
&self.task
}
pub fn batch(&self) -> u32 {
self.batch
}
pub fn width(&self) -> u32 {
self.width
}
pub fn height(&self) -> u32 {
self.height
}
pub fn nc(&self) -> u32 {
self.nc
}
pub fn nk(&self) -> u32 {
self.nk
}
pub fn nm(&self) -> u32 {
self.nm
}
pub fn names(&self) -> &Vec<String> {
&self.names
}
}