*** Wartungsfenster jeden ersten Mittwoch vormittag im Monat ***

Skip to content
Snippets Groups Projects
test_zs6d.py 2.66 KiB
Newer Older
St333fan's avatar
St333fan committed
from zs6d import ZS6D
import os
import json
St333fan's avatar
St333fan committed
from croco.models.croco import CroCoNet
St333fan's avatar
St333fan committed
import cv2
from PIL import Image
import pose_utils.img_utils as img_utils
import pose_utils.vis_utils as vis_utils
import numpy as np
import time
import matplotlib.pyplot as plt
St333fan's avatar
St333fan committed
import torch
St333fan's avatar
St333fan committed
from croco.models.croco import CroCoNet
import sys
import random
St333fan's avatar
St333fan committed
sys.path.append("croco")

#ckpt = torch.load('/home/imw-mmi/PycharmProjects/ZS6D/pretrained_models/CroCo.pth')
#model = CroCoNet(**ckpt.get('croco_kwargs', {}))
# Make runs repeatable: seed every random number generator the script touches
# with the same fixed value and force cuDNN into its deterministic mode.
seed = 1  # found by checking the saliency map

# Python's and NumPy's global generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch generators (CPU, current CUDA device, all CUDA devices).
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Disable cuDNN autotuning and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

St333fan's avatar
St333fan committed

# Read the JSON inference configuration; the path is relative to the working
# directory the script is started from.
with open("./zs6d_configs/bop_eval_configs/cfg_ycbv_inference_bop.json", mode="r") as f:
    config = json.loads(f.read())

# Build the pose estimator from the template ground-truth file and the object
# norm factors named in the configuration.
# Alternative setting kept from the original line: max_crop_size=80, stride=4
pose_estimator = ZS6D(
    config['templates_gt_path'],
    config['norm_factor_path'],
    model_type='crocov1',
    subset_templates=5,
    max_crop_size=80,
    stride=16,
)
St333fan's avatar
St333fan committed

# Read the ground-truth annotations; they supply the segmentation masks (and
# camera matrices) used below to exercise ZS6D.
with open(os.fspath(config['gt_path']), mode='r') as f:
    data_gt = json.loads(f.read())

# Key of the image in the ground-truth file whose objects are processed.
img_id = '000048_1'
#img_id = '8'

# Estimate and visualise the pose of every annotated object instance of the
# selected image: one get_pose call, one plot and one timing report per object.
for obj_number, obj_gt in enumerate(data_gt[img_id]):
    obj_id = obj_gt['obj_id']
    # Camera matrix from the annotation, reshaped to 3x3.
    cam_K = np.array(obj_gt['cam_K']).reshape((3, 3))
    # Read for reference only: get_pose is called with bbox=None below.
    bbox = obj_gt['bbox_visib']

    img_path = os.path.join(config['dataset_path'], obj_gt['img_name'].split("./")[-1])

    # NOTE(review): img_path is computed above, but the image is loaded from a
    # hard-coded, machine-specific file instead. This looks like a debugging
    # override -- confirm whether Image.open(img_path) is what is intended.
    img = Image.open('/home/stefan/PycharmProjects/ZS6D/test/000001.png')

    # Decode the stored 'mask_sam' entry with rle_to_mask and cast it to uint8.
    mask = img_utils.rle_to_mask(obj_gt['mask_sam']).astype(np.uint8)

    # To estimate the object's rotation R and translation t, the input image,
    # the object id, a segmentation mask and the camera matrix are necessary.
    # Only the estimator call itself is timed.
    start_time = time.time()
    R_est, t_est = pose_estimator.get_pose(img, str(obj_id), mask, cam_K, bbox=None)
    end_time = time.time()

    # Identity check on purpose: `R_est == None` would be evaluated
    # element-wise if R_est is an array, and `if` on that result raises
    # ValueError instead of detecting the failure case.
    if R_est is None:
        print("Failed to find R_est... ")
        continue

    # Draw the 3D bounding box for the estimated pose onto a copy of the image.
    out_img = vis_utils.draw_3D_bbox_on_image(np.array(img), R_est, t_est, cam_K,
                                              obj_gt['model_info'], factor=1.0)

    plt.imshow(out_img)
    plt.show()
    print(f"Pose estimation time: {end_time - start_time}")
    print(f"R_est: {R_est}")
    print(f"t_est: {t_est}")