Search before asking
Bug
I found this when trying to use the DeepFaune New England (DFNE) classifier for species classification. As the code is currently written, running the DFNE classifier (and presumably the DeepFaune classifier as well) causes memory usage to grow without bound. In a batch classification of approximately 150 images, memory usage (RAM or GPU RAM, whichever device is used) kept growing for the whole run. When I ran classification on roughly 3000 images, memory usage ballooned to over 300 GB before I got halfway through the set.
I believe the issue is in
PytorchWildlife/models/classification/timm_base/base_classifier.py
in the batch_image_classification function, specifically these lines of code:
with tqdm(total=len(dataloader)) as pbar:
    for batch in dataloader:
        imgs, paths = batch
        imgs = imgs.to(self.device)
        total_logits.append(self.predictor(imgs))
        total_paths.append(paths)
        pbar.update(1)
I believe the cause is that the model is still tracking gradient information, so every output appended to total_logits keeps its autograd graph alive. When I modified the code to be as follows:
with tqdm(total=len(dataloader)) as pbar:
    with torch.no_grad():
        for batch in dataloader:
            imgs, paths = batch
            imgs = imgs.to(self.device)
            total_logits.append(self.predictor(imgs))
            total_paths.append(paths)
            pbar.update(1)
memory usage was normal and I was able to run the classifier without issue.
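To illustrate what I think is happening, here is a minimal sketch that does not use the real classifier. The `nn.Linear` model is a hypothetical stand-in for `self.predictor`, and the tensor shapes are arbitrary. It only shows that an output produced outside `torch.no_grad()` carries a `grad_fn`, which is what would keep the graph alive when the output is appended to a list:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for self.predictor; NOT the real DFNE model.
model = nn.Linear(8, 3)
model.eval()  # eval mode alone does not disable gradient tracking

batch = torch.randn(4, 8)

# Gradient tracking on: the output references the autograd graph,
# so storing it in a list keeps that graph in memory.
tracked = model(batch)

# Gradient tracking off: the output is a plain tensor with no graph attached.
with torch.no_grad():
    untracked = model(batch)

print(tracked.requires_grad, tracked.grad_fn is not None)
print(untracked.requires_grad, untracked.grad_fn is None)
```

If this is indeed the cause, wrapping the inference loop in `torch.no_grad()` as above should be enough, since none of the stored logits would hold a reference to a graph.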
Environment
No response
Minimal Reproducible Example
No response
Additional
No response
Are you willing to submit a PR?