2016-07-17 3 views
1

torchnet compatibility을 추가하기 위해 병렬 데이터 로더를 torch-dataframe에 추가하려고합니다. 나는 tnt.ParallelDatasetIteratorchanged it 그래서를 사용했습니다 :torch.serialize를 두 번 사용하면 토치의 메모리가 부족합니다.

  1. 기본 일괄 배치가 직렬화와 변환됩니다
  2. 배치가 직렬화 스레드 외부에서로드 스레드에서 스레드
  3. 로 전송됩니다 배치 데이터를 텐서로 변환
  4. tnt.Engine 설정과 일치 시키려면 텐서가 inputtarget 키가있는 테이블에 반환됩니다.

문제enque가 오류로 불리는 두 번째 발생 .../torch_distro/install/bin/luajit: not enough memory. 나는 현재 mnist과 함께 일하고 있으며, 적응 된 번호는 mnist-example입니다. enque 루프는 이제 (메모리 디버깅 출력을 포함)과 같습니다 : 나는 collectgarbage을 뿌려 또한 필요하지 않은 물건을 제거하기 위해 노력했습니다

-- `samplePlaceholder` stands in for samples which have been 
-- filtered out by the `filter` function 
local samplePlaceholder = {} 

-- The enque does the main loop 
local idx = 1 
local function enqueue() 
    while idx <= size and threads:acceptsjob() do 
    local batch, reset = self.dataset:get_batch(batch_size) 

    if (reset) then 
     idx = size + 1 
    else 
     idx = idx + 1 
    end 

    if (batch) then 
     local serialized_batch = torch.serialize(batch) 

     -- In the parallel section only the to_tensor is run in parallel 
     -- this should though be the computationally expensive operation 
     threads:addjob(
     function(argList) 
      io.stderr:write("\n Start"); 
      io.stderr:write("\n 1: " ..tostring(collectgarbage("count"))) 
      local origIdx, serialized_batch, samplePlaceholder = unpack(argList) 

      io.stderr:write("\n 2: " ..tostring(collectgarbage("count"))) 
      local batch = torch.deserialize(serialized_batch) 
      serialized_batch = nil 

      collectgarbage() 
      collectgarbage() 

      io.stderr:write("\n 3: " .. tostring(collectgarbage("count"))) 
      batch = transform(batch) 

      io.stderr:write("\n 4: " .. tostring(collectgarbage("count"))) 
      local sample = samplePlaceholder 
      if (filter(batch)) then 
      sample = {} 
      sample.input, sample.target = batch:to_tensor() 
      end 
      io.stderr:write("\n 5: " ..tostring(collectgarbage("count"))) 

      collectgarbage() 
      collectgarbage() 
      io.stderr:write("\n 6: " ..tostring(collectgarbage("count"))) 

      io.stderr:write("\n End \n"); 
      return { 
      sample, 
      origIdx 
      } 
     end, 
     function(argList) 
      sample, sampleOrigIdx = unpack(argList) 
     end, 
     {idx, serialized_batch, samplePlaceholder} 
    ) 
    end 
    end 
end 

. 메모리 출력이 아니라 정직 :

Start 
1: 374840.87695312 
2: 374840.94433594 
3: 372023.79101562 
4: 372023.85839844 
5: 372075.41308594 
6: 372023.73632812 
End 

enque 루프 기능은 사소한 비 순서화 기능 (메모리 오류 제 enque에서 발생하고있다)이다

iterFunction = function() 
    while threads:hasjob() do 
    enqueue() 
    threads:dojob() 
    if threads:haserror() then 
     threads:synchronize() 
    end 
    enqueue() 

    if table.exact_length(sample) > 0 then 
     return sample 
    end 
    end 
end 

답변

1

그래서 문제는 torch.serialize이었습니다. 설정에서 함수가 전체 데이터 집합을 함수에 결합 시켰습니다. 추가 할 때 :

serialized_batch = nil 
collectgarbage() 
collectgarbage() 

문제가 해결되었습니다. 더 많은 공간을 차지하고있는 것이 무엇인지 알기를 원했고, 함수와 얽혀 큰 데이터 세트가있는 환경에서 함수를 정의하고 크기를 대폭 늘린 것이 그 원인이었습니다. 여기

mnist = require 'mnist' 
local dataset = mnist[mode .. 'dataset']() 

-- PROBLEMATIC LINE BELOW -- 
local ext_resource = dataset.data:reshape(dataset.data:size(1), 
    dataset.data:size(2) * dataset.data:size(3)):double() 

-- Create a Dataframe with the label. The actual images will be loaded 
-- as an external resource 
local df = Dataframe(
    Df_Dict{ 
    label = dataset.label:totable(), 
    row_id = torch.range(1, dataset.data:size(1)):totable() 
    }) 

-- Since the mnist package already has taken care of the data 
-- splitting we create a single subsetter 
df:create_subsets{ 
    subsets = Df_Dict{core = 1}, 
    class_args = Df_Tbl({ 
    batch_args = Df_Tbl({ 
     label = Df_Array("label"), 
     data = function(row) 
     return ext_resource[row.row_id] 
     end 
    }) 
    }) 
} 

그것이 내가 강조 라인을 제거하는 0.0008 메가까지 358 메가에서 메모리 사용량을 줄일 수 있음을 밝혀 지역 데이터의 원래 정의! 나는 성능을 테스트하기 위해 사용되는 코드는했다 :

local mem = {} 
table.insert(mem, collectgarbage("count")) 

local ser_data = torch.serialize(batch.dataset) 
table.insert(mem, collectgarbage("count")) 

local ser_retriever = torch.serialize(batch.batchframe_defaults.data) 
table.insert(mem, collectgarbage("count")) 

local ser_raw_retriever = torch.serialize(function(row) 
    return ext_resource[row.row_id] 
end) 
table.insert(mem, collectgarbage("count")) 

local serialized_batch = torch.serialize(batch) 
table.insert(mem, collectgarbage("count")) 

for i=2,#mem do 
    print(i-1, (mem[i] - mem[i-1])/1024) 
end 

는 원래 출력 생산 어떤 :

1 0.0094480514526367 
2 0.00080204010009766 
3 0.00090408325195312 
4 0.010146141052246 

내가 대한 setfenv를 사용하여 시도 :

1 0.0082607269287109 
2 358.23344707489 
3 0.0017471313476562 
4 358.90182781219 

및 수정 후 함수를 호출했지만 문제가 해결되지 않았습니다. 직렬화 된 데이터를 쓰레드로 전송하는 데 여전히 성능상의 불이익이 있지만 주된 문제가 해결되고 값 비싼 데이터 검색자가 없으면 기능이 상당히 작아집니다.

관련 문제