Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import matplotlib
matplotlib.use('TkAgg') # Use TkAgg backend
import matplotlib.pyplot as plt
import argparse
import json
import os
import torch
from tqdm import tqdm
import numpy as np
from src.pose_extractor import PoseViTExtractor
from pose_utils.data_utils import ImageContainer_masks
import pose_utils.img_utils as img_utils
from PIL import Image
import cv2
import pose_utils.utils as utils
import pose_utils.vis_utils as vis_utils
import time
import pose_utils.eval_utils as eval_utils
import csv
import logging
import croco_match
# Configure the root logger: records at INFO and above are written to
# pose_estimation.log, each prefixed with a timestamp and its level name.
logging.basicConfig(
    filename="pose_estimation.log",
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-scoped logger; the main loop uses it to report per-object failures.
logger = logging.getLogger(__name__)
if __name__=="__main__":
parser = argparse.ArgumentParser(description='Test pose estimation inference on test set')
parser.add_argument('--config_file', default="./zs6d_configs/bop_eval_configs/cfg_ycbv_inference_bop_myset.json")
args = parser.parse_args()
with open(os.path.join(args.config_file), 'r') as f:
config = json.load(f)
# Loading ground truth files:
with open(os.path.join(config['templates_gt_path']), 'r') as f:
templates_gt = json.load(f)
with open(os.path.join(config['gt_path']), 'r') as f:
data_gt = json.load(f)
with open(os.path.join(config['norm_factor_path']), 'r') as f:
norm_factors = json.load(f)
# Set up a results csv file:
csv_file = os.path.join('./results', config['results_file'])
# Column names for the CSV file
headers = ['scene_id', 'im_id', 'obj_id', 'score', 'R', 't', 'time']
# Create a new CSV file and write the headers
with open(csv_file, mode='w', newline='') as csvfile:
csv_writer = csv.writer(csvfile)
csv_writer.writerow(headers)
if config['debug_imgs']:
debug_img_path = os.path.join("./debug_imgs",config['results_file'].split(".csv")[0])
if not os.path.exists(debug_img_path):
os.makedirs(debug_img_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractor = PoseViTExtractor(model_type='dino_vits8', stride=4, device=device)
print("Loading PoseViTExtractor is done!")
matches = []
print("Processing input images:")
for all_id, img_labels in tqdm(data_gt.items()):
# enter the image which should be checked
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
continue
scene_id = all_id.split("_")[0]
img_id = all_id.split("_")[-1]
# get data and crops for a single image
img_path = os.path.join(config['dataset_path'], img_labels[0]['img_name'].split("./")[-1])
img_name = img_path.split("/")[-1].split(".png")[0]
img = Image.open(img_path)
cam_K = np.array(img_labels[0]['cam_K']).reshape((3,3))
img_data = ImageContainer_masks(img = img,
img_name = img_name,
scene_id = scene_id,
cam_K = cam_K,
crops = [],
descs = [],
x_offsets = [],
y_offsets = [],
obj_names = [],
obj_ids = [],
model_infos = [],
t_gts = [],
R_gts = [],
masks = [])
for obj_index, img_label in enumerate(img_labels):
bbox_gt = img_label[config['bbox_type']]
if bbox_gt[2] == 0 or bbox_gt[3] == 0:
continue
if bbox_gt != [-1,-1,-1,-1]:
img_data.t_gts.append(np.array(img_label['cam_t_m2c']) * config['scale_factor'])
img_data.R_gts.append(np.array(img_label['cam_R_m2c']).reshape((3,3)))
img_data.obj_ids.append(str(img_label['obj_id']))
img_data.model_infos.append(img_label['model_info'])
try:
mask = img_utils.rle_to_mask(img_label['mask_sam'])
mask_3_channel = np.stack([mask] * 3, axis=-1)
bbox = img_utils.get_bounding_box_from_mask(mask)
img_crop, y_offset, x_offset = img_utils.make_quadratic_crop(np.array(img), bbox)
mask_crop,_,_ = img_utils.make_quadratic_crop(mask, bbox)
img_crop = cv2.bitwise_and(img_crop, img_crop, mask=mask_crop)
img_data.crops.append(Image.fromarray(img_crop))
img_prep, img_crop,_ = extractor.preprocess(Image.fromarray(img_crop), load_size=224)
mask_array = [
[0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1],
[1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1],
[1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1],
[1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1],
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
]
assets_folder = '/home/stefan/PycharmProjects/ZS6D/templates/ycbv_desc/'+'obj_'+ str(img_label['obj_id'])
croco_match.process(ref_image=img_crop,
ckpt_path='/home/stefan/PycharmProjects/ZS6D/pretrained_models/CroCo.pth',
# _V2_ViTLarge_BaseDecoder
output_folder='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
assets_folder=assets_folder,
mask_array=mask_array)
best_match = croco_match.find_match(ref_image=img_crop,
decoded_images_dir='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
best_match = best_match.replace("decoded_", "")
best_temp = Image.open('/home/stefan/PycharmProjects/ZS6D/templates/ycbv_desc/'+'obj_'+
str(img_label['obj_id']) +'/'+ best_match)
plt.imshow(best_temp)
plt.show()
print(best_match)
matches.append(all_id+'|'+ str(img_label['obj_id']) +'|'+best_match)
img_data.y_offsets.append(y_offset)
img_data.x_offsets.append(x_offset)
img_data.masks.append(mask_3_channel)
except Exception as e:
matches.append(all_id + '|' + str(img_label['obj_id']) + '|' + '000000.png')
print(f"Warning: 'mask_sam' not found or bad defined in img_label. Skipping this iteration.")
logger.warning(f"Loading mask and extracting descriptor failed for img {img_path} and object_id {obj_index}: {e}")
img_data.crops.append(None)
img_data.descs.append(None)
img_data.y_offsets.append(None)
img_data.x_offsets.append(None)
img_data.masks.append(None)