Description
This blog post is meant to revive my blogging habit and share some bits of the TransReID code that I found interesting. TransReID is a popular paper on object re-identification using a vision transformer. In this post, we'll walk through a code review.
CHECK OUT THE PAPER - Arxiv Link
I have also created this documented TransReID repo.
If you look into the repository, you will see a structure something like this:
.
.
.
.
├── loss
│ ├── arcface.py
│ ├── center_loss.py
│ ├── __init__.py
│ ├── make_loss.py
│ ├── metric_learning.py
│ ├── softmax_loss.py
│ └── triplet_loss.py
├── model
│ ├── backbones
│ │ ├── __init__.py
│ │ ├── resnet.py
│ │ └── vit_pytorch.py
│ ├── __init__.py
│ └── make_model.py
├── processor
│ ├── __init__.py
│ └── processor.py
├── solver
│ ├── cosine_lr.py
│ ├── __init__.py
│ ├── lr_scheduler.py
│ ├── make_optimizer.py
│ ├── scheduler_factory.py
│ └── scheduler.py
├── utils
│ ├── __init__.py
│ ├── iotools.py
│ ├── logger.py
│ ├── meter.py
│ ├── metrics.py
│ └── reranking.py
├── dist_train.sh
├── LICENSE
├── README.md
├── requirements.txt
├── test.py
└── train.py
There's a variety of datasets to pick from; since Market1501 was the smallest, I picked it.
If you are interested, run the following command:
```shell
python train.py --config_file configs/Market/vit_jpm.yml MODEL.DEVICE_ID "('0')"
```
Code exploration
Let's understand how train.py works.
The code is divided into five main parts:
1) make_dataloader
2) make_model
3) make_loss
4) make_optimizer
5) do_train: the main training loop
I will focus on make_model and make_loss, which I think are the core of their code structure.
make_model
Before anything else, this function does an important step: it creates the TransReID baseline model.
Overall, make_model goes like this:
```python
def func():
    # initializes the TransReID baseline
    # adds 5 classifier layers for scores (1 global + 4 local)
    # adds 2 copies of the last transformer block (b1, b2) to get the global feature
    # and the local features
    # later, returns scores and features
    ...
```
TransReID
We will look at it in two parts: first the init, then inference.
Setting up
```python
import torch.nn as nn

class TransReID(nn.Module):
    """ Transformer-based Object Re-Identification
    """
    def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., camera=0, view=0,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, local_feature=False, sie_xishu=1.0):
        ...  # body discussed piece by piece below
```
According to the paper, the image is split into OVERLAPPING patches for the patch embedding. The paper argues that overlapping patches preserve the local neighborhood structure around each patch, which helps feature extraction, an essential step for re-identification.
A simple Conv2d layer is used for this, with the kernel size set to the patch size. So basically:
- Stride = Patch Size: no overlap
- Stride < Patch Size: overlap between patches
For a 224x224 image with 16x16 patches, stride = 8 (50% overlap) gives (224 - 16) / 8 + 1 = 27 patches per side, i.e. 27x27 = 729 patches.
```python
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)
```
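To check the patch counts, here is a small hypothetical helper (not from the repo) that applies the standard Conv2d output-size formula, floor((size - kernel) / stride) + 1, to each side of the image, assuming no padding as in the nn.Conv2d call above.

```python
def num_patches(img_h, img_w, patch_size, stride):
    """Patches produced by a Conv2d patch embedding with no padding (hypothetical helper)."""
    h = (img_h - patch_size) // stride + 1  # patches along the height
    w = (img_w - patch_size) // stride + 1  # patches along the width
    return h, w, h * w


# 224x224, 16x16 patches, stride 8 (50% overlap)
print(num_patches(224, 224, 16, 8))   # (27, 27, 729)
# 256x128 input, stride 16 (no overlap)
print(num_patches(256, 128, 16, 16))  # (16, 8, 128)
# 256x128 input, stride 12 (overlapping)
print(num_patches(256, 128, 16, 12))  # (21, 10, 210)
```

Note how overlap grows the sequence length quickly, which is the cost paid for the better local structure.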
Next, the SIE (Side Information Embedding) parameters are initialized based on the number of cameras and viewpoints.
Now, most importantly, the transformer blocks are stacked:
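A minimal sketch of how such a side-information embedding can work, assuming one learnable vector per (camera, view) pair that is scaled by sie_xishu and added to every token. The class name SIESketch and the index formula camera_id * num_views + view_id are assumptions for illustration; check vit_pytorch.py for the exact code.

```python
import torch
import torch.nn as nn


class SIESketch(nn.Module):
    """Hypothetical side-information embedding: one learnable vector per (camera, view) pair."""

    def __init__(self, num_cameras, num_views, embed_dim, sie_xishu=1.0):
        super().__init__()
        self.num_views = num_views
        self.sie_xishu = sie_xishu  # scaling factor for the side information
        # Shape [num_cameras * num_views, 1, embed_dim]; the middle 1 broadcasts over tokens
        self.sie_embed = nn.Parameter(torch.zeros(num_cameras * num_views, 1, embed_dim))
        nn.init.trunc_normal_(self.sie_embed, std=0.02)

    def forward(self, tokens, camera_id, view_id):
        # tokens: [B, N, D]; camera_id, view_id: LongTensors of shape [B]
        idx = camera_id * self.num_views + view_id            # flatten (camera, view) into one index
        return tokens + self.sie_xishu * self.sie_embed[idx]  # [B, 1, D] broadcasts to [B, N, D]


sie = SIESketch(num_cameras=6, num_views=1, embed_dim=768, sie_xishu=1.0)
tokens = torch.randn(16, 129, 768)
cams = torch.randint(0, 6, (16,))
views = torch.zeros(16, dtype=torch.long)
print(sie(tokens, cams, views).shape)  # torch.Size([16, 129, 768])
```

Because the embedding is added rather than concatenated, the sequence length stays the same.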
```python
self.blocks = nn.ModuleList([
    Block(
        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
    for i in range(depth)])
```
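The snippet uses dpr[i] without showing where it comes from. In timm-style ViT code, dpr is usually a stochastic-depth (drop path) schedule that grows linearly from 0 at the first block to drop_path_rate at the last one; treat that as an assumption here. A plain-Python sketch of such a schedule (mathematically the same as torch.linspace(0, drop_path_rate, depth)):

```python
def drop_path_schedule(drop_path_rate, depth):
    """Assumed schedule: linearly increasing drop-path rate, one value per block."""
    if depth == 1:
        return [drop_path_rate]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


dpr = drop_path_schedule(0.1, 12)
print([round(r, 3) for r in dpr])  # starts at 0.0, ends at 0.1
```

Deeper blocks get dropped more often, so early blocks, which build the basic patch representations, stay almost always active.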
The weights are initialized with a truncated normal distribution (std 0.02) for Linear layers and zero biases, while LayerNorm gets weight 1 and bias 0:
```python
def _init_weights(self, m):
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
```
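_init_weights handles a single module, so it is meant to be passed to nn.Module.apply, which calls it on every submodule recursively. A standalone sketch of the same rules; the tiny Sequential model is hypothetical, and torch.nn.init.trunc_normal_ stands in for the repo's trunc_normal_:

```python
import torch
import torch.nn as nn


def init_weights(m):
    # Same rules as _init_weights above, written as a free function
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)


tiny = nn.Sequential(nn.Linear(768, 768), nn.LayerNorm(768), nn.Linear(768, 10, bias=False))
tiny.apply(init_weights)  # visits the Sequential itself and then each child module
print(tiny[0].weight.std().item())  # roughly 0.02
```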
Inference
For inference, let's look at forward_features:
```python
def forward_features(self, x, camera_id, view_id):
    # x: image batch
    # x = patch_embedding(x)
    # prepend the class token to x
    # add camera and view (SIE) embeddings to x
    # dropout
    # run through the transformer blocks
    ...
```
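Consider x to be an image batch of size torch.Size([16, 3, 256, 128]): 16 images with 3 channels at 256x128 pixels.
After the patch embedding, the output is torch.Size([16, 128, 768]): with 16x16 patches and stride 16, a 256x128 image gives 16 x 8 = 128 patches, each embedded into 768 dimensions.
After prepending the class token, the size becomes torch.Size([16, 129, 768]).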
The overall token sequence looks something like this:
```
[cls_token, patch_1, patch_2, ..., patch_n]
```
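The shape bookkeeping above can be reproduced with a few lines of PyTorch. This is a simplified sketch, assuming patch size 16, stride 16 and embed_dim 768; it skips the position and SIE embeddings and dropout, and the zero-initialized cls_token is just a placeholder parameter.

```python
import torch
import torch.nn as nn

B, C, H, W, D = 16, 3, 256, 128, 768
x = torch.randn(B, C, H, W)                                        # [16, 3, 256, 128]

proj = nn.Conv2d(C, D, kernel_size=16, stride=16)                  # patch embedding
patches = proj(x)                                                  # [16, 768, 16, 8]
patches = patches.flatten(2).transpose(1, 2)                       # [16, 128, 768]

cls_token = nn.Parameter(torch.zeros(1, 1, D))                     # placeholder class token
tokens = torch.cat([cls_token.expand(B, -1, -1), patches], dim=1)  # [16, 129, 768]
print(patches.shape, tokens.shape)
```

The class token sits at index 0, which is why the global branch later reads x[:, 0].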
The final step is the transformer blocks. If local_feature is enabled in TransReID, the last transformer block is skipped and the whole token sequence is returned; otherwise all blocks run and only the normalized class token is returned.
```python
if self.local_feature:
    # For tasks requiring local feature context,
    # all blocks except the last one are applied
    for blk in self.blocks[:-1]:
        x = blk(x)  # torch.Size([16, 129, 768])
    return x
else:
    # For tasks requiring global context,
    # the entire sequence of blocks (including the last one) is applied
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)
    return x[:, 0]  # class token only
```
TransReID extension
With TransReID, the baseline is defined. Additional layers are added in build_transformer_local.
The last block of the baseline (TransReID), which would otherwise provide the global feature, is taken out and copied twice:
```python
block = self.base.blocks[-1]
layer_norm = self.base.norm
self.b1 = nn.Sequential(
    copy.deepcopy(block),
    copy.deepcopy(layer_norm)
)
self.b2 = nn.Sequential(
    copy.deepcopy(block),
    copy.deepcopy(layer_norm)
)
```
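copy.deepcopy matters here: b1 and b2 start from the same weights as the last block, but each copy owns its own parameters, so the global and local branches can diverge during training. A quick check, using an nn.Linear as a stand-in for the transformer block:

```python
import copy
import torch
import torch.nn as nn

block = nn.Linear(8, 8)  # stand-in for self.base.blocks[-1]
b1 = copy.deepcopy(block)
b2 = copy.deepcopy(block)

print(torch.equal(b1.weight, b2.weight))             # True: same starting values
print(b1.weight.data_ptr() == b2.weight.data_ptr())  # False: separate storage

with torch.no_grad():
    b1.weight.add_(1.0)                              # update only the "global" copy
print(torch.equal(b1.weight, b2.weight))             # False: the copies now differ
```

Using the same module object twice instead of deepcopy would make b1 and b2 share weights.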
Further, bottleneck_1, bottleneck_2 … bottleneck_4 are defined, which are basically nn.BatchNorm1d(self.in_planes) layers.
To get the scores, several classifier layers are initialized: one for the global feature and one for each of the four local features.
```python
self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
self.classifier_1 = nn.Linear(self.in_planes, self.num_classes, bias=False)
self.classifier_1.apply(weights_init_classifier)
self.classifier_2 = nn.Linear(self.in_planes, self.num_classes, bias=False)
self.classifier_2.apply(weights_init_classifier)
self.classifier_3 = nn.Linear(self.in_planes, self.num_classes, bias=False)
self.classifier_3.apply(weights_init_classifier)
self.classifier_4 = nn.Linear(self.in_planes, self.num_classes, bias=False)
self.classifier_4.apply(weights_init_classifier)
```
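Each feature goes through its BatchNorm1d bottleneck before its classifier, while the un-normalized feature is what gets returned for the triplet loss. Below is a compact sketch of the same five heads using nn.ModuleList, as a design alternative rather than the repo's code. Freezing the BatchNorm bias and the std=0.001 classifier init are assumptions borrowed from common ReID baselines. A ModuleList would also change the state_dict key names, so existing checkpoints would not load directly.

```python
import torch
import torch.nn as nn


class BNNeckHeads(nn.Module):
    """Hypothetical: 1 global + 4 local heads, each BatchNorm1d followed by a bias-free Linear."""

    def __init__(self, in_planes=768, num_classes=751, num_heads=5):
        super().__init__()
        self.bottlenecks = nn.ModuleList([nn.BatchNorm1d(in_planes) for _ in range(num_heads)])
        self.classifiers = nn.ModuleList(
            [nn.Linear(in_planes, num_classes, bias=False) for _ in range(num_heads)]
        )
        for bn in self.bottlenecks:
            bn.bias.requires_grad_(False)          # assumed: shift-free BN bottleneck
        for fc in self.classifiers:
            nn.init.normal_(fc.weight, std=0.001)  # assumed classifier init

    def forward(self, feats):
        # feats: list of [B, in_planes] tensors, global first, then the local ones
        feats_bn = [bn(f) for bn, f in zip(self.bottlenecks, feats)]
        scores = [fc(f) for fc, f in zip(self.classifiers, feats_bn)]
        return scores, feats  # raw features are kept for the triplet loss


heads = BNNeckHeads()
feats = [torch.randn(16, 768) for _ in range(5)]
scores, raw = heads(feats)
print([s.shape for s in scores])  # five tensors of shape [16, 751]
```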
Inference
The forward function of build_transformer_local is as follows:
```python
def forward(self, x, label=None, cam_label=None, view_label=None):
    # inference through the TransReID baseline
    # get the global branch feature using `self.b1`
    # apply the JPM branch and shuffling
    # infer local features from `self.b2`, like `b1_local_feat`, `b2_local_feat`...
    # run these through `self.bottleneck_1`, 2, 3, 4 respectively
    # now get the classifier scores
    ...
```
The image x is passed through self.base (the TransReID model) to get features.
These features are then passed through the JPM (Jigsaw Patch Module) branch. According to the paper, JPM shifts and shuffles the patch tokens, so each local group covers patches from different parts of the image, which makes the local features more robust for re-identification.
Four different sets of local features are created, each with a patch length of feature_length // self.divide_length.
Finally, the classification scores are computed:
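The shift, shuffle and split steps can be sketched as follows. This is my own simplified version, not the repo's implementation: it assumes the shift is a cyclic roll of the patch tokens, the shuffle is a reshape-and-transpose over a number of groups, and the patch count divides evenly. The values shift=5, groups=2 and num_parts=4 (playing the role of divide_length) are illustrative assumptions.

```python
import torch


def shift_and_shuffle(tokens, shift, groups):
    """tokens: [B, 1 + N, D] with the class token first; returns patch tokens [B, N, D]."""
    patches = tokens[:, 1:]                               # drop the class token
    patches = torch.roll(patches, shifts=-shift, dims=1)  # shift: first `shift` patches move to the end
    B, N, D = patches.shape
    assert N % groups == 0, "sketch assumes N is divisible by groups"
    # shuffle: [B, groups, N/groups, D] -> transpose -> interleave the groups
    return patches.reshape(B, groups, N // groups, D).transpose(1, 2).reshape(B, N, D)


def jpm_local_inputs(tokens, shift=5, groups=2, num_parts=4):
    """Split shuffled patches into num_parts groups, each prefixed with the class token."""
    cls = tokens[:, :1]                                   # [B, 1, D]
    patches = shift_and_shuffle(tokens, shift, groups)
    part_len = patches.size(1) // num_parts               # feature_length // divide_length
    return [
        torch.cat([cls, patches[:, i * part_len:(i + 1) * part_len]], dim=1)
        for i in range(num_parts)
    ]


tokens = torch.randn(16, 129, 768)  # output of self.base with local_feature=True
parts = jpm_local_inputs(tokens)
print([p.shape for p in parts])     # four tensors of shape [16, 33, 768]
```

In this reading, each of the four sequences then goes through self.b2, and its class-token output becomes one of local_feat_1 … local_feat_4.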
```python
# Classification score for the global feature
# global_feat - torch.Size([16, 768])
# cls_score - torch.Size([16, 751])  (751 = number of Market1501 training identities)
cls_score = self.classifier(feat)
# Classification scores for the local features
# cls_score_1,2,3,4 - torch.Size([16, 751])
cls_score_1 = self.classifier_1(local_feat_1_bn)
cls_score_2 = self.classifier_2(local_feat_2_bn)
cls_score_3 = self.classifier_3(local_feat_3_bn)
cls_score_4 = self.classifier_4(local_feat_4_bn)
```
It then returns the classifier scores and the features as two lists:
```python
return (
    [cls_score, cls_score_1, cls_score_2, cls_score_3, cls_score_4],         # ID (classification) scores
    [global_feat, local_feat_1, local_feat_2, local_feat_3, local_feat_4],  # features for the triplet loss
)
```
make_loss
Now, let's understand how the make_loss function works. If you look at the loss/make_loss.py file, you will quickly notice it is a bit of a mess.
The Market1501 yml config file specifies a loss that combines cross entropy and triplet loss.
Their make_loss function goes like this:
```python
def make_loss(cfg, num_classes):
    # if config mentions triplet loss, initialize triplet loss
    # if config mentions cross entropy, initialize cross entropy
    # if config mentions `cross_entropy`:
    #     create and return inline function def loss_func(...)
    # if sampler config mentions `softmax_triplet`:
    #     create inline function def loss_func(...)
    #     lots of things going on in def loss_func(...)
    ...
```
What is going on in def loss_func(…) when the sampler is "softmax_triplet"?
- It checks whether the config asks for label smoothing on the cross entropy.
- It applies the relevant cross-entropy loss, also called ID_LOSS.
- It applies the triplet loss (TRI_LOSS).
- It combines ID_LOSS and TRI_LOSS with configurable weights.
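A minimal sketch of that flow for a single score/feature pair. It assumes the common batch-hard triplet loss (hardest positive and hardest negative per anchor, Euclidean distance, margin 0.3) and uses PyTorch's built-in label_smoothing argument as a stand-in for the repo's own label-smoothing loss. The names batch_hard_triplet and loss_func_sketch and the id_weight/tri_weight arguments are hypothetical.

```python
import torch
import torch.nn.functional as F


def batch_hard_triplet(feat, target, margin=0.3):
    """Batch-hard triplet loss: hardest positive and hardest negative for each anchor."""
    dist = torch.cdist(feat, feat)                                    # [B, B] Euclidean distances
    same = target.unsqueeze(0) == target.unsqueeze(1)                 # [B, B] same-identity mask
    dist_ap = (dist * same).max(dim=1).values                         # farthest positive
    dist_an = dist.masked_fill(same, float("inf")).min(dim=1).values  # closest negative
    return F.relu(dist_ap - dist_an + margin).mean()


def loss_func_sketch(score, feat, target, label_smooth=0.0, id_weight=1.0, tri_weight=1.0):
    # ID loss: cross entropy, optionally with label smoothing
    id_loss = F.cross_entropy(score, target, label_smoothing=label_smooth)
    tri_loss = batch_hard_triplet(feat, target)
    return id_weight * id_loss + tri_weight * tri_loss


# 4 identities x 4 images each, the kind of batch a P x K identity sampler provides
target = torch.arange(4).repeat_interleave(4)
score, feat = torch.randn(16, 751), torch.randn(16, 768)
print(loss_func_sketch(score, feat, target, label_smooth=0.1))
```

The batch-hard triplet loss only works if every identity in the batch has at least one other image, which is why the triplet setup goes together with an identity-aware sampler.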
Let's try to understand this part of the code:
```python
if isinstance(score, list):
    print("Score is a list. Calculating ID loss without label smoothing.")
    ID_LOSS = [F.cross_entropy(scor, target) for scor in score[1:]]
    ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
    ID_LOSS = 0.5 * ID_LOSS + 0.5 * F.cross_entropy(score[0], target)
else:
    print("Score is not a list. Calculating ID loss without label smoothing.")
    ID_LOSS = F.cross_entropy(score, target)
if isinstance(feat, list):
    print("Feat is a list. Calculating triplet loss.")
    TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
    TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
    TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
else:
    print("Feat is not a list. Calculating triplet loss.")
    TRI_LOSS = triplet(feat, target)[0]
```
If you look closely, you might wonder what they are doing with F.cross_entropy(scor, target) for scor in score[1:] and F.cross_entropy(score[0], target). The triplet loss does the same thing.
Recall what our model returns:
```python
return (
    [cls_score, cls_score_1, cls_score_2, cls_score_3, cls_score_4],         # ID (classification) scores
    [global_feat, local_feat_1, local_feat_2, local_feat_3, local_feat_4],  # features for the triplet loss
)
```
So the loss function computes the loss for score[0] (the global feature) separately from score[1:] (the four local features). It first averages the four local losses, then averages that result with the global loss.
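A quick worked example of that averaging, in plain Python with made-up loss values, shows the effective weight each branch gets: 0.5 for the global branch and 0.5 / 4 = 0.125 for each local branch.

```python
def combine(losses):
    """losses[0] is the global-branch loss, losses[1:] are the local-branch losses."""
    local = sum(losses[1:]) / len(losses[1:])
    return 0.5 * local + 0.5 * losses[0]


print(combine([2.0, 1.0, 1.0, 1.0, 1.0]))  # 0.5 * 1.0 + 0.5 * 2.0 = 1.5

# Effective weight of each branch: set one loss to 1, the rest to 0, and read the total
base = [0.0] * 5
weights = [combine(base[:i] + [1.0] + base[i + 1:]) for i in range(5)]
print(weights)  # [0.5, 0.125, 0.125, 0.125, 0.125]
```

Without this split, a plain mean over all five losses would give the global branch only 0.2 of the total.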
That’s all.
Why do they do it? Probably to give the global feature and the local features equal influence during training, so that the four local branches together do not outweigh the single global branch.