*** Wartungsfenster jeden ersten Mittwoch vormittag im Monat ***

Skip to content
Snippets Groups Projects
Commit 3e54385d authored by St333fan's avatar St333fan
Browse files

added CrocoPoseExtractorClass

parent 1626bca2
No related branches found
No related tags found
No related merge requests found
...@@ -354,9 +354,9 @@ class CroCoExtractor: ...@@ -354,9 +354,9 @@ class CroCoExtractor:
if model is not None: if model is not None:
self.model = model self.model = model
else: else:
self.model = ViTExtractor.create_model(model_type) self.model = CroCoExtractor.create_model(model_type)
# self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride) #self.model = CroCoExtractor.patch_vit_resolution(self.model, stride=stride)
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
self.p = self.model.patch_embed.patch_size self.p = self.model.patch_embed.patch_size
...@@ -458,7 +458,7 @@ class CroCoExtractor: ...@@ -458,7 +458,7 @@ class CroCoExtractor:
# fix the stride # fix the stride
model.patch_embed.proj.stride = stride model.patch_embed.proj.stride = stride
# fix the positional encoding code # fix the positional encoding code
model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model) model.interpolate_pos_encoding = types.MethodType(CroCoExtractor._fix_pos_enc(patch_size, stride), model)
return model return model
def preprocess(self, image_path: Union[str, Path], def preprocess(self, image_path: Union[str, Path],
...@@ -547,10 +547,11 @@ class CroCoExtractor: ...@@ -547,10 +547,11 @@ class CroCoExtractor:
B, C, H, W = batch.shape B, C, H, W = batch.shape
self._feats = [] self._feats = []
self._register_hooks(layers, facet) self._register_hooks(layers, facet)
_ = self.model(batch) _ = self.model(batch, batch)
self._unregister_hooks() self._unregister_hooks()
self.load_size = (H, W) self.load_size = (H, W)
self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) self.num_patches = (1 + (H - self.p[0]) // self.stride[0], 1 + (W - self.p[0]) // self.stride[1])
#self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
return self._feats return self._feats
def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor: def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment