*** Wartungsfenster jeden ersten Mittwoch vormittag im Monat ***

Skip to content
Snippets Groups Projects
Commit 3e54385d authored by St333fan's avatar St333fan
Browse files

added CrocoPoseExtractorClass

parent 1626bca2
No related branches found
No related tags found
No related merge requests found
......@@ -354,9 +354,9 @@ class CroCoExtractor:
if model is not None:
self.model = model
else:
self.model = ViTExtractor.create_model(model_type)
self.model = CroCoExtractor.create_model(model_type)
# self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
#self.model = CroCoExtractor.patch_vit_resolution(self.model, stride=stride)
self.model.eval()
self.model.to(self.device)
self.p = self.model.patch_embed.patch_size
......@@ -458,7 +458,7 @@ class CroCoExtractor:
# fix the stride
model.patch_embed.proj.stride = stride
# fix the positional encoding code
model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
model.interpolate_pos_encoding = types.MethodType(CroCoExtractor._fix_pos_enc(patch_size, stride), model)
return model
def preprocess(self, image_path: Union[str, Path],
......@@ -547,10 +547,11 @@ class CroCoExtractor:
B, C, H, W = batch.shape
self._feats = []
self._register_hooks(layers, facet)
_ = self.model(batch)
_ = self.model(batch, batch)
self._unregister_hooks()
self.load_size = (H, W)
self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
self.num_patches = (1 + (H - self.p[0]) // self.stride[0], 1 + (W - self.p[0]) // self.stride[1])
#self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
return self._feats
def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment