2017-05-23 1 views

답변

5

잘만되면 너무 늦지 않았습니다. 난 당신이 유지하고자하는 부분 완전히 확실하지 않다,하지만 당신은 그것의 사전 교육을받은 모델 또는 재사용 부품을 조작하려면, 내 방법은 다음과 같습니다

  1. 가 사전 다운로드를 예 :

    import torch 
    import torch.nn as nn 
    from torchvision import models 
    
    model = models.resnet101(pretrained=True) 
    
  2. 관심있는 부분을 추출하고이 부분에서 새 모델을 만듭니다. 물론

    list(model.modules()) # to inspect the modules of your model 
    my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer 
    

, 당신은 추출 할 수 있습니다 그냥 따라 목록을 수정, 당신은 할 수 있습니다 모델의 부품을 재사용 할뿐만 아니라 새로운 모듈을 추가 할 수 있습니다.

+0

나는 그가 자신의 해당 CONV 층에 resnet 같은 네트워크의 FC 레이어를 변환하고자하는 생각을 – Curious

4

당신은 다음과 같이 (설명 주석 참조) 그렇게 할 수 있습니다

import torch 
import torch.nn as nn 
from torchvision import models 

# 1. LOAD PRE-TRAINED VGG16 
model = models.vgg16(pretrained=True) 

# 2. GET CONV LAYERS 
features = model.features 

# 3. GET FULLY CONNECTED LAYERS 
fcLayers = nn.Sequential(
    # stop at last layer 
    *list(model.classifier.children())[:-1] 
) 

# 4. CONVERT FULLY CONNECTED LAYERS TO CONVOLUTIONAL LAYERS 

### convert first fc layer to conv layer with 512x7x7 kernel 
fc = fcLayers[0].state_dict() 
in_ch = 512 
out_ch = fc["weight"].size(0) 

firstConv = nn.Conv2d(in_ch, out_ch, 7, 7) 

### get the weights from the fc layer 
firstConv.load_state_dict({"weight":fc["weight"].view(out_ch, in_ch, 7, 7), 
          "bias":fc["bias"]}) 

# CREATE A LIST OF CONVS 
convList = [firstConv] 

# Similarly convert the remaining linear layers to conv layers 
for layer in enumerate(fcLayers[1:]): 
    if isinstance(module, nn.Linear): 
     # Convert the nn.Linear to nn.Conv 
     fc = module.state_dict() 
     in_ch = fc["weight"].size(1) 
     out_ch = fc["weight"].size(0) 
     conv = nn.Conv2d(in_ch, out_ch, 1, 1) 

     conv.load_state_dict({"weight":fc["weight"].view(out_ch, in_ch, 1, 1), 
      "bias":fc["bias"]}) 

     convList += [conv] 
    else: 
     # Append other layers such as ReLU and Dropout 
     convList += [layer] 

# Set the conv layers as a nn.Sequential module 
convLayers = nn.Sequential(*convList) 
관련 문제