
Commit 58f775bf authored by St333fan

First real testable version, and it does not work yet

parent 047feae6
1 merge request: !1 Feature
@@ -18,7 +18,7 @@ import random
 if __name__ == "__main__":
     # setting a seed so the model does not behave randomly
-    seed = 33  # found by checking the saliency map
+    seed = 1  # found by checking the saliency map 33
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
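For reference, this seeding block recurs in several files touched by this commit; consolidated into one helper it reads as below. This is a sketch assembled from lines already in the diff (the numpy seeding is an addition for completeness and does not appear in the hunks):

import random
import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Seed every RNG the pipeline touches so runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all GPUs, not just the current device
    np.random.seed(seed)              # not in the diff; added for completeness
    random.seed(seed)
    # cuDNN: choose deterministic kernels and disable autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False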
@@ -127,7 +127,7 @@ if __name__ == "__main__":
     img_crop, crop_x, crop_y = img_utils.make_quadratic_crop(img, [x, y, w, h])
     img_prep, img_crop, _ = extractor.preprocess(Image.fromarray(img_crop), load_size=224)
-    desc = extractor.extract_descriptors(img_prep.to(device), layer=11, facet='key', bin=False, include_cls=True)
+    desc = extractor.extract_descriptors(img_prep.to(device), layer=11, facet='attn', bin=False, include_cls=True)
     desc = desc.squeeze(0).squeeze(0).detach().cpu().numpy()
     R = obj_poses[i][:3, :3]
@@ -346,7 +346,7 @@ class CroCoExtractor:
    d - the embedding dimension in the ViT.
    """
-    def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'):
+    def __init__(self, model_type: str = 'dino_vits8', stride: int = 16, model: nn.Module = None, device: str = 'cuda'):
        """
        :param model_type: A string specifying the type of model to extract from.
                           [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
@@ -361,6 +361,7 @@ class CroCoExtractor:
            self.model = model
        else:
            self.model = CroCoExtractor.create_model(model_type)
+        '''
        seed = 33  # found by checking the saliency map
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
@@ -369,9 +370,11 @@ class CroCoExtractor:
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
-        #self.model = CroCoExtractor.patch_vit_resolution(self.model, stride=stride)
+        '''
+        if model_type == 'crocov2':
+            self.model = CroCoExtractor.patch_vit_resolution(self.model, stride=stride)
+        elif stride != 16:
+            print('patch_vit_resolution only for crocov2')
        self.model.eval()
        self.model.to(self.device)
        self.p = self.model.patch_embed.patch_size
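The gating above restricts patch_vit_resolution to crocov2. In dino-vit-features-style extractors, this function densifies the token grid by shrinking the stride of the patch-embedding convolution and resampling the positional embeddings. A minimal sketch of that idea, assuming a timm-style ViT with model.patch_embed.proj and model.pos_embed (these attribute names are assumptions, not taken from the CroCo code):

import torch
import torch.nn as nn
import torch.nn.functional as F

def patch_vit_resolution_sketch(model: nn.Module, stride: int, load_size: int = 224) -> nn.Module:
    """Shrink the patch-embedding stride so the ViT emits a denser token grid."""
    patch = model.patch_embed.proj.kernel_size[0]
    if stride == patch:
        return model  # stride already matches the patch size; nothing to do
    # Same kernel, smaller stride -> overlapping patches, more tokens.
    model.patch_embed.proj.stride = (stride, stride)
    # The denser grid needs matching positional embeddings: resample the
    # original (sqrt(N) x sqrt(N)) grid, keeping the cls token untouched.
    n_new = (load_size - patch) // stride + 1
    cls_pos, grid_pos = model.pos_embed[:, :1], model.pos_embed[:, 1:]
    n_old = int(grid_pos.shape[1] ** 0.5)
    grid_pos = grid_pos.reshape(1, n_old, n_old, -1).permute(0, 3, 1, 2)
    grid_pos = F.interpolate(grid_pos, size=(n_new, n_new), mode='bicubic', align_corners=False)
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, n_new * n_new, -1)
    model.pos_embed = nn.Parameter(torch.cat([cls_pos, grid_pos], dim=1))
    return model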
@@ -500,13 +503,27 @@ class CroCoExtractor:
                 (2) the pil image in relevant dimensions
        """
        pil_image = Image.open(image_path).convert('RGB')
        if load_size is not None:
            pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
-        prep = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize(mean=self.mean, std=self.std)
-        ])
-        prep_img = prep(pil_image)[None, ...]
+        #prep = transforms.Compose([
+        #    transforms.ToTensor(),
+        #    transforms.Normalize(mean=self.mean, std=self.std)
+        #])
+        #prep_img = prep(pil_image)[None, ...]
+
+        # Convert the image to a tensor
+        to_tensor = transforms.ToTensor()
+        image_tensor = to_tensor(pil_image)[None, ...]  # add batch dimension
+
+        # Pad and resize the image tensor
+        padded_and_resized_tensor = pad_and_resize(image_tensor)
+
+        # Normalize the image tensor (normalization currently disabled)
+        normalize = transforms.Normalize(mean=self.mean, std=self.std)
+        prep_img = padded_and_resized_tensor  # normalize(padded_and_resized_tensor)
+
        return prep_img, pil_image

    def _get_hook(self, facet: str):
@@ -532,7 +549,7 @@ class CroCoExtractor:
            input = input[0]
            B, N, C = input.shape
            qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
-            # qkv = module.projk(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
+            #qkv = module.proj(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
            self._feats.append(qkv[facet_idx])  # B x h x t x d
        return _inner_hook
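For context, facet_idx above selects one of the three fused projections, and the newly allowed 'attn' facet cannot come out of qkv at all: it would need the post-softmax attention map. A sketch of the mapping this hook style relies on (the helper and its names are illustrative, not from this repository):

import torch

# After the permute above, qkv has shape (3, B, heads, tokens, head_dim).
FACET_TO_QKV_INDEX = {'query': 0, 'key': 1, 'value': 2}

def pick_facet(qkv: torch.Tensor, facet: str) -> torch.Tensor:
    if facet in FACET_TO_QKV_INDEX:
        return qkv[FACET_TO_QKV_INDEX[facet]]  # B x heads x tokens x head_dim
    if facet == 'attn':
        # The attention facet is not a slice of qkv: it is the post-softmax
        # map softmax(q @ k^T / sqrt(head_dim)), so it needs a separate hook.
        q, k = qkv[0], qkv[1]
        attn = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
        return attn.softmax(dim=-1)  # B x heads x tokens x tokens
    raise ValueError(f'unknown facet: {facet}')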
@@ -648,12 +665,13 @@ class CroCoExtractor:
        :param bin: apply log binning to the descriptor. default is False.
        :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
        """
-        assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
+        assert facet in ['key', 'query', 'value', 'token', 'attn'], f"""{facet} is not a supported facet for descriptors.
                                                            choose from ['key' | 'query' | 'value' | 'token'] """
        self._extract_features(batch, batch2=None, layers=[layer], facet=facet)
        x = self._feats[0]  # 1 x 16 x 196 x 31; index 0 for mono, 1 for binocular
        if facet == 'token':
-            x.unsqueeze_(dim=1)  # Bx1xtxd
+            #x.unsqueeze_(dim=1)  # Bx1xtxd
+            x = x.clone().unsqueeze(dim=1)  # Bx1xtxd
        if not include_cls:
            x = x[:, :, 1:, :]  # remove cls token
        '''
@@ -727,7 +745,26 @@ def pad_and_resize(image_batch):
    # Resize the padded image to 224x224 using bilinear interpolation
    resized_image_batch = torch.nn.functional.interpolate(padded_image_batch, size=(224, 224), mode='bilinear', align_corners=False)

    # assume 'tensor' is your image tensor with shape (1, 3, 224, 224)
+    '''
+    # Convert the PyTorch tensor to a NumPy array
+    # (make sure the tensor is on the CPU before converting)
+    image_numpy = resized_image_batch.cpu().numpy()
+
+    # Remove the batch dimension
+    image_numpy = image_numpy[0]
+
+    # Transpose the dimensions to (224, 224, 3)
+    image_numpy = np.transpose(image_numpy, (1, 2, 0))
+
+    # Clip the values to the [0, 1] range, if necessary
+    image_numpy = np.clip(image_numpy, 0, 1)
+
+    # Convert to 8 bits per channel (0-255) and save as PNG
+    image_numpy = (image_numpy * 255).astype(np.uint8)
+    image = Image.fromarray(image_numpy)
+    image.save('image_float.png')
+    '''
    return resized_image_batch

if __name__ == "__main__":
@@ -768,11 +805,6 @@ if __name__ == "__main__":
    image_batch_croco1, image_pil_croco = extractor_croco.preprocess('/home/stefan/PycharmProjects/ZS6D/test/000248.png', args.load_size)
    image_batch_croco2, image_pil_croco2 = extractor_croco.preprocess('/home/stefan/PycharmProjects/ZS6D/test/000392.png', args.load_size)

-    # visualize
-    #channel = image_batch_croco[0, 0, :, :]
-    #plt.imsave('channel_0.png', channel, cmap='gray')
-
-    # Resize the tensor to (1, 3, 224, 224) using bilinear interpolation
-    image_batch_croco1 = pad_and_resize(image_batch_croco1)
-    image_batch_croco2 = pad_and_resize(image_batch_croco2)
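The explicit pad_and_resize calls are dropped here because preprocess now pads and resizes internally (see the preprocess hunk above). Since only the tail of pad_and_resize appears in this diff, here is a minimal sketch of the pad-to-square-then-resize idea; the centered zero-padding is an assumption, only the bilinear resize to 224x224 is confirmed by the hunk:

import torch
import torch.nn.functional as F

def pad_and_resize_sketch(image_batch: torch.Tensor, size: int = 224) -> torch.Tensor:
    """Pad a (B, C, H, W) batch to a square canvas, then resize to size x size."""
    _, _, h, w = image_batch.shape
    side = max(h, w)
    # Zero-pad symmetrically so the content stays centered on a square canvas
    # (assumption: the real function may pad differently).
    pad_w, pad_h = side - w, side - h
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    padded = F.pad(image_batch, padding, mode='constant', value=0)
    # Bilinear resize to the target resolution, as in the diff above.
    return F.interpolate(padded, size=(size, size), mode='bilinear', align_corners=False)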
@@ -3,11 +3,12 @@ import torch
 import numpy as np
 import random
 import matplotlib
-matplotlib.use('Agg')
+# Set the backend for matplotlib
+matplotlib.use('TkAgg')  # use the interactive TkAgg backend
 import matplotlib.pyplot as plt

 # setting a seed so the model does not behave randomly
-seed = 33  # found by checking the saliency map 33
+seed = 3  # found by checking the saliency map 33
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
@@ -21,14 +22,20 @@ with torch.no_grad():
    extractor_croco = CroCoExtractor(model_type='crocov1', stride=16, device=device)  # stride 16
    #extractor_croco = ViTExtractor(device=device)
    image_batch_croco1, image_pil_croco = extractor_croco.preprocess(
-        '/home/stefan/PycharmProjects/ZS6D/test/000392.png', 224)  # 000248.png
+        '/home/stefan/Desktop/worst_dose.png', 224)  # , mask=False) # 000248.png
+    #image_batch_croco1, image_pil_croco = extractor_croco.preprocess(
+    #    '/home/stefan/PycharmProjects/ZS6D/test/000392.png', 224, mask=False)  # 000248.png
    image_batch_croco2, image_pil_croco2 = extractor_croco.preprocess(
-        '/home/stefan/PycharmProjects/ZS6D/test/maskcutbetter.png', 224)  # 000392.png
+        '/home/stefan/Desktop/cut_dose.png', 224)  # , mask=False) # 000392.png
+
+    # Remove the batch dimension and move the channels to the end
+    image_array = image_batch_croco1.squeeze(0).permute(1, 2, 0).numpy()

    image_batch_croco1 = pad_and_resize(image_batch_croco1)
    image_batch_croco2 = pad_and_resize(image_batch_croco2)
+
+    # Display the image
+    plt.imshow(image_array)
+    plt.axis('off')
+    plt.show()

-    strings = ["key", "value", "query"]
+    strings = ["key", "value", "query", 'token', 'attn']

    # Loop over the facets and run the comparison for each
    for i, facet in enumerate(strings):
@@ -49,15 +56,20 @@ with torch.no_grad():
        norm_tensor1 = torch.norm(descriptors1_2d, dim=1)
        norm_tensor2 = torch.norm(descriptors2_2d, dim=1)

+        # Save the 2-D descriptor maps as images
+        plt.imsave('descriptor1.png', descriptors1_2d)
+        plt.imsave('descriptor2.png', descriptors2_2d)
+
        # cosine similarity
        cosine_similarity.append(dot_product / (norm_tensor1 * norm_tensor2))
        import torch.nn.functional as F
        # mean cosine similarity
        mean_cosine_similarity.append(torch.mean(dot_product / (norm_tensor1 * norm_tensor2)))

    # Create a range for the x-axis (one value per layer, 0 to 11)
    x_values = range(12)
+    print(mean_cosine_similarity)

    # Create the line plot
    plt.plot(x_values, mean_cosine_similarity, label=facet.capitalize())
@@ -69,4 +81,3 @@ with torch.no_grad():
    # Save the plot as a PNG image
    plt.legend()
    plt.savefig('line_plot3.png')
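The per-layer curves above reduce to a mean cosine similarity between two (tokens x dim) descriptor maps; the manual dot-product-over-norms computation is equivalent to torch.nn.functional.cosine_similarity, as this small sketch shows (the tensors here are random stand-ins, not real descriptors):

import torch
import torch.nn.functional as F

def mean_cosine_similarity(desc1: torch.Tensor, desc2: torch.Tensor) -> float:
    """Mean cosine similarity between two (tokens x dim) descriptor maps."""
    # Equivalent to mean(dot(d1_i, d2_i) / (|d1_i| * |d2_i|)) over tokens i,
    # which is what the loop above computes by hand.
    return F.cosine_similarity(desc1, desc2, dim=1).mean().item()

# Random stand-ins with the 196-token grid a 224x224, patch-16 ViT produces:
d1, d2 = torch.randn(196, 768), torch.randn(196, 768)
print(mean_cosine_similarity(d1, d2))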
@@ -18,7 +18,7 @@ sys.path.append("croco")
 #ckpt = torch.load('/home/imw-mmi/PycharmProjects/ZS6D/pretrained_models/CroCo.pth')
 #model = CroCoNet(**ckpt.get('croco_kwargs', {}))

 # setting a seed so the model does not behave randomly
-seed = 33  # found by checking the saliency map
+seed = 1  # found by checking the saliency map
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
@@ -56,7 +56,6 @@ for i in range(len(data_gt[img_id])):
    mask = data_gt[img_id][obj_number]['mask_sam']
    mask = img_utils.rle_to_mask(mask)
    mask = mask.astype(np.uint8)
-    cv2.imwrite('mask.png', mask)

    start_time = time.time()
@@ -66,30 +66,25 @@ class ZS6D:
        img_crop, y_offset, x_offset = img_utils.make_quadratic_crop(np.array(img), bbox)
        mask_crop, _, _ = img_utils.make_quadratic_crop(mask, bbox)
        plt.imshow(mask_crop)
        plt.show()  # note: blocks until the window is closed
        img_crop = cv2.bitwise_and(img_crop, img_crop, mask=mask_crop)
        import random
        #filename = f"img_crop_{random.randint(0, 1000)}.png"
        #img_crop = cv2.cvtColor(img_crop, cv2.COLOR_BGR2RGB)  # convert to RGB
        #cv2.imwrite(filename, img_crop)
        img_crop = Image.fromarray(img_crop)
        img_prep, _, _ = self.extractor.preprocess(img_crop, load_size=224)
-        channel = img_prep[0, 0, :, :]
-        # Save the channel as a grayscale image
-        plt.imsave('channel_0.png', channel, cmap='gray')
        if self.model_type != 'crocov1':
            with torch.no_grad():
-                desc = self.extractor.extract_descriptors(img_prep.to(self.device), layer=11, facet='key',
+                desc = self.extractor.extract_descriptors(img_prep.to(self.device), layer=11, facet='attn',
                                                          bin=False, include_cls=True)
                desc = desc.squeeze(0).squeeze(0).detach().cpu()
        else:
            #img_prep = torch.nn.functional.interpolate(img_prep, size=(224, 224), mode='bilinear',
            #                                           align_corners=False)
            channel = img_prep[0, 0, :, :]
            # Save the channel as a grayscale image
            plt.imsave('channel_1.png', channel, cmap='gray')
            with torch.no_grad():
-                desc = self.extractor.extract_descriptors(img_prep.to(self.device), layer=11, facet='key',
+                desc = self.extractor.extract_descriptors(img_prep.to(self.device), layer=11, facet='attn',
                                                          bin=False, include_cls=True)
                desc = desc.squeeze(0).squeeze(0).detach().cpu()
@@ -99,7 +94,11 @@ class ZS6D:
            raise ValueError("No matched templates found for the object.")

        template = Image.open(self.templates_gt[obj_id][matched_templates[0][1]]['img_crop'])
-        template.save('/home/stefan/PycharmProjects/ZS6D/img.jpg')
+        template.save(f'/home/stefan/PycharmProjects/ZS6D/template_{random.randint(0, 1000)}.jpg')
+        #filename = f"template_{random.randint(0, 1000)}.png"
+        #cv2.imwrite(filename, template)
+        plt.imshow(template)
+        plt.show()  # note: blocks until the window is closed

        with torch.no_grad():
            if img_crop.size[0] < self.max_crop_size: