*** Wartungsfenster jeden ersten Mittwoch vormittag im Monat ***

Skip to content
Snippets Groups Projects
Commit 6fbfb8cf authored by St333fan's avatar St333fan
Browse files

finalised working project

parent a393a531
No related branches found
No related tags found
1 merge request: !5 Feature
assets/croco_pipeline.png

4.16 MiB | W: | H:

assets/croco_pipeline.png

4.23 MiB | W: | H:

assets/croco_pipeline.png
assets/croco_pipeline.png
assets/croco_pipeline.png
assets/croco_pipeline.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -141,7 +141,7 @@ def calculate_advanced_similarity(img1, img2):
'warp_similarity': warp_sim*10
}
def process_image(model, image_path, ref_image, device, trfs, imagenet_mean_tensor, imagenet_std_tensor, mask_array):
def process_image(model, image_path, segmented_image, device, trfs, imagenet_mean_tensor, imagenet_std_tensor, mask_array):
"""
Process an image using a given model and compare it to a reference image.
......@@ -151,7 +151,7 @@ def process_image(model, image_path, ref_image, device, trfs, imagenet_mean_tens
Args:
model: torch.nn.Module, The neural network model to use for processing.
image_path: str, Path to the image file to be processed.
ref_image: torch.Tensor, The reference image tensor.
segmented_image: torch.Tensor, The reference image tensor.
device: torch.device, The device (CPU or GPU) to run the computations on.
trfs: torchvision.transforms.Compose, Composition of image transformations to apply.
imagenet_mean_tensor: torch.Tensor, Mean tensor for ImageNet normalization.
......@@ -161,7 +161,7 @@ def process_image(model, image_path, ref_image, device, trfs, imagenet_mean_tens
Returns:
torch.Tensor: The decoded image tensor.
"""
image1 = ref_image # segmented object
image1 = segmented_image # segmented object
image2 = trfs(Image.open(image_path).convert('RGB')).to(device, non_blocking=True).unsqueeze(0) # template image
custom_mask = create_custom_mask(mask_array)
......@@ -182,7 +182,7 @@ def process_image(model, image_path, ref_image, device, trfs, imagenet_mean_tens
return decoded_image
def process(ref_image_path=None, ref_image=None, ckpt_path=None, output_folder=None, assets_folder=None, mask_array=None):
def process(segmented_image_path=None, segmented_image=None, ckpt_path=None, output_folder=None, assets_folder=None, mask_array=None):
"""
Process a set of images using a reference image and a pre-trained model.
......@@ -190,7 +190,7 @@ def process(ref_image_path=None, ref_image=None, ckpt_path=None, output_folder=N
and then processes all images in a specified folder, saving the decoded results.
Args:
ref_image_path: str, Path to the reference image file.
segmented_image_path: str, Path to the reference image file.
ckpt_path: str, Path to the checkpoint file containing the pre-trained model.
output_folder: str, Path to the folder where decoded images will be saved.
assets_folder: str, Path to the folder containing images to be processed.
......@@ -207,10 +207,10 @@ def process(ref_image_path=None, ref_image=None, ckpt_path=None, output_folder=N
trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std),transforms.Resize((224, 224))])
# Load the reference image
if ref_image_path != None:
ref_image = trfs(Image.open(ref_image_path).convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
if segmented_image_path != None:
segmented_image = trfs(Image.open(segmented_image_path).convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
else:
ref_image = trfs(ref_image.convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
segmented_image = trfs(segmented_image.convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
# load model
ckpt = torch.load(ckpt_path, 'cpu')
......@@ -223,7 +223,7 @@ def process(ref_image_path=None, ref_image=None, ckpt_path=None, output_folder=N
for filename in os.listdir(assets_folder):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(assets_folder, filename)
decoded_image = process_image(model, image_path, ref_image, device, trfs, imagenet_mean_tensor,
decoded_image = process_image(model, image_path, segmented_image, device, trfs, imagenet_mean_tensor,
imagenet_std_tensor, mask_array)
# Save the decoded image
......@@ -231,12 +231,12 @@ def process(ref_image_path=None, ref_image=None, ckpt_path=None, output_folder=N
torchvision.utils.save_image(decoded_image, output_path)
print(f'Decoded image saved: {output_path}')
def find_match(ref_image_path=None, ref_image=None, decoded_images_dir=None, mask_array=None):
def find_match(segmented_image_path=None, segmented_image=None, decoded_images_dir=None, mask_array=None):
"""
Match a reference image with several decoded images
Args:
ref_image_path: str, path to the reference image file
segmented_image_path: str, path to the reference image file
decoded_images_dir: str, path to the decoded image files
mask_array: np.Array, used mask
......@@ -294,10 +294,10 @@ def find_match(ref_image_path=None, ref_image=None, decoded_images_dir=None, mas
return ssim_value, mse
# Load images
if ref_image_path!=None:
img1 = cv2.imread(ref_image_path)
if segmented_image_path!=None:
img1 = cv2.imread(segmented_image_path)
else:
img1 = ref_image
img1 = segmented_image
top_10 = []
......@@ -305,7 +305,7 @@ def find_match(ref_image_path=None, ref_image=None, decoded_images_dir=None, mas
for filename in os.listdir(decoded_images_dir):
if filename.endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(decoded_images_dir, filename)
img2 = cv2.imread(img_path)
img2 = cv2.imread(img_path) # decoded images
print(f"Processing: {filename}")
# Apply mask
......@@ -375,20 +375,20 @@ def main():
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
]
# ref_iamge_path: the segemented object
# segmented_image_path: the segmented object
# ckpt_path: ViT weights
# output_folder: use ./ZS6D/assets_match/decoded_images
# assets_folder: where all the dataset images are located
# mask_array: the mask
process(ref_image_path='./test/test_crocom/3.png',
process(segmented_image_path='./test/test_crocom/3.png',
ckpt_path='./pretrained_models/CroCo.pth', #_V2_ViTLarge_BaseDecoder
output_folder='./assets_match/decoded_images',
assets_folder='./templates/ycbv_desc/obj_15', mask_array=mask_array)
# ref_iamge_path: the segemented object
# segmented_image_path: the segmented object
# ref: use ./ZS6D/assets_match/decoded_images; in future these should be held in RAM
# mask_array: the mask
best_match = find_match(ref_image_path='/test/test_crocom/3.png',
best_match = find_match(segmented_image_path='/test/test_crocom/3.png',
decoded_images_dir='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
mask_array=mask_array)
......
......@@ -168,14 +168,14 @@ if __name__=="__main__":
# run whole process or manual checking and extracting
if True:
croco_match.process(ref_image=img_crop,
croco_match.process(segmented_image=img_crop,
ckpt_path='/home/stefan/PycharmProjects/ZS6D/pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth',
# _V2_ViTLarge_BaseDecoder
output_folder='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
assets_folder=assets_folder,
mask_array=mask_array)
best_match = croco_match.find_match(ref_image=img_crop,
best_match = croco_match.find_match(segmented_image=img_crop,
decoded_images_dir='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
mask_array=mask_array)
......
......@@ -152,14 +152,14 @@ if __name__=="__main__":
assets_folder = '/home/stefan/PycharmProjects/ZS6D/templates/ycbv_desc/'+'obj_'+ str(img_label['obj_id'])
croco_match.process(ref_image=img_crop,
croco_match.process(segmented_image=img_crop,
ckpt_path='/home/stefan/PycharmProjects/ZS6D/pretrained_models/CroCo.pth',
# _V2_ViTLarge_BaseDecoder
output_folder='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
assets_folder=assets_folder,
mask_array=mask_array)
best_match = croco_match.find_match(ref_image=img_crop,
best_match = croco_match.find_match(segmented_image=img_crop,
decoded_images_dir='/home/stefan/PycharmProjects/ZS6D/assets_match/decoded_images',
mask_array=mask_array)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment