24 lines
721 B
Python
24 lines
721 B
Python
import cv2
|
|
import numpy as np
|
|
from ultralytics import YOLO
|
|
|
|
# Input image, output mask and segmentation weights used by this script.
IMAGE_PATH = 'data/test.jpg'
OUTPUT_MASK = 'data/bubble_mask.png'
MODEL_PATH = 'model/yolo8_seg-speed-bubble.pt'


def merge_instance_masks(instance_masks, height: int, width: int) -> np.ndarray:
    """Union per-instance masks into a single 0/255 uint8 mask.

    Args:
        instance_masks: Iterable of 2-D NumPy arrays, one per detected
            instance. Each array is cast to uint8 (fractional values
            truncate toward zero) and multiplied by 255.
        height: Number of rows of the output mask.
        width: Number of columns of the output mask.

    Returns:
        A ``(height, width)`` uint8 array that is the bitwise OR of all
        instance masks; all zeros when ``instance_masks`` is empty.
    """
    combined = np.zeros((height, width), dtype=np.uint8)
    for instance in instance_masks:
        binary = instance.astype(np.uint8) * 255
        if binary.shape != (height, width):
            # Fallback for masks that are not already at image resolution;
            # nearest-neighbour keeps the values strictly 0 or 255.
            binary = cv2.resize(
                binary, (width, height), interpolation=cv2.INTER_NEAREST
            )
        np.bitwise_or(combined, binary, out=combined)
    return combined


def main() -> None:
    """Segment IMAGE_PATH with the YOLO model and save the union mask.

    Raises:
        FileNotFoundError: If the input image is missing or cannot be decoded.
        OSError: If the mask cannot be written to OUTPUT_MASK.
    """
    image = cv2.imread(IMAGE_PATH)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            f"Image is missing or not readable: {IMAGE_PATH}"
        )
    height, width = image.shape[:2]

    model = YOLO(MODEL_PATH)
    # retina_masks=True asks for masks at the original image resolution.
    # The default masks are at the letterboxed inference size, so stretching
    # them to (width, height) would shift them wherever padding was added.
    results = model(IMAGE_PATH, retina_masks=True)

    instance_masks = []
    for result in results:
        masks = getattr(result, 'masks', None)
        if masks is None:
            # Nothing was segmented in this result.
            continue
        for tensor in masks.data:
            instance_masks.append(tensor.cpu().numpy())

    mask = merge_instance_masks(instance_masks, height, width)

    # cv2.imwrite reports failure through its return value, so only announce
    # success when the file was actually written.
    if not cv2.imwrite(OUTPUT_MASK, mask):
        raise OSError(f"Could not write mask to {OUTPUT_MASK}")
    print(f"Đã lưu mask tại {OUTPUT_MASK}")


if __name__ == "__main__":
    main()