*** Wartungsfenster jeden ersten Mittwoch vormittag im Monat ***

Skip to content
Snippets Groups Projects
test_croco_layers.py 3 KiB
Newer Older
from src.extractor import CroCoExtractor, pad_and_resize, ViTExtractor
St333fan's avatar
St333fan committed
import torch
import numpy as np
import random
import matplotlib
# Set the backend for matplotlib
matplotlib.use('TkAgg')  # Use TkAgg backend
St333fan's avatar
St333fan committed
import matplotlib.pyplot as plt

# setting a seed so the model does not behave random
St333fan's avatar
St333fan committed
seed = 1
St333fan's avatar
St333fan committed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

with torch.no_grad():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
St333fan's avatar
St333fan committed
    extractor_croco = CroCoExtractor(model_type='crocov1', stride=16, device=device)
    # extractor_croco = ViTExtractor(device=device)
St333fan's avatar
St333fan committed
    image_batch_croco1, image_pil_croco = extractor_croco.preprocess(
St333fan's avatar
St333fan committed
        '/home/stefan/Desktop/000347.png', 224)
St333fan's avatar
St333fan committed
    image_batch_croco2, image_pil_croco2 = extractor_croco.preprocess(
St333fan's avatar
St333fan committed
        '/home/stefan/Desktop/bechertest1.png', 224)

    # Remove the batch dimension and move channels to the end
    image_array = image_batch_croco1.squeeze(0).permute(1, 2, 0).numpy()
St333fan's avatar
St333fan committed

    # Display the image
    plt.imshow(image_array)
    plt.axis('off')
    plt.show()
St333fan's avatar
St333fan committed

    strings = ["key", "value", "query", 'token', 'attn']
St333fan's avatar
St333fan committed

    # Loop over the list and assign each string to the respective variable
    for i, facet in enumerate(strings):
        cosine_similarity = []
        mean_cosine_similarity = []
        for layer in range(12):
            descriptors1 = extractor_croco.extract_descriptors(image_batch_croco1.to(device), layer=layer, facet=facet, bin=False,
                                                               include_cls=True)
            descriptors1_2d = descriptors1.cpu().squeeze(0).squeeze(0)
            descriptors2 = extractor_croco.extract_descriptors(image_batch_croco2.to(device), layer=layer, facet=facet, bin=False,
                                                               include_cls=True)
            descriptors2_2d = descriptors2.cpu().squeeze(0).squeeze(0)

            # dot product of descriptors
            dot_product = torch.sum(descriptors1_2d * descriptors2_2d, dim=1)

            # norm tensor
            norm_tensor1 = torch.norm(descriptors1_2d, dim=1)
            norm_tensor2 = torch.norm(descriptors2_2d, dim=1)

St333fan's avatar
St333fan committed
            # Save the descs as images
            plt.imsave('descriptor1.png', descriptors1_2d)
            plt.imsave('descriptor2.png', descriptors2_2d)

St333fan's avatar
St333fan committed
            # cosine similarity
            cosine_similarity.append(dot_product / (norm_tensor1 * norm_tensor2))

            # mean cosine similarity
            mean_cosine_similarity.append(torch.mean(dot_product / (norm_tensor1 * norm_tensor2)))

St333fan's avatar
St333fan committed
        # range x-axis (from 0 to 11)
St333fan's avatar
St333fan committed
        x_values = range(12)
        print(mean_cosine_similarity)
St333fan's avatar
St333fan committed
        plt.plot(x_values, mean_cosine_similarity, label=facet.capitalize())
        plt.xlabel('Index')
        plt.ylabel('Value')
        plt.title('Line Plot of Values')

St333fan's avatar
St333fan committed
    # Save the plot
St333fan's avatar
St333fan committed
    plt.legend()
St333fan's avatar
St333fan committed
    plt.savefig('line_plot.png')