Hi. I'm trying to convert a PyTorch model to ONNX, then to .dlc, and quantize it using SNPE.
My model contains a warping function that uses torch.gather(), whose index argument must be of type long (int64).
The code below is the warping function I'm using:
```python
def warp_custom(img, flow):
    B, C, H, W = img.shape
    # Generate the sampling grid
    xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
    yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
    grid = torch.cat([xx, yy], 1)
    # Scale the flow to match grid dimensions
    # flow_ = f0.permute(0, 2, 3, 1)  # Change flow shape to match grid shape
    flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0),
                       flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)
    # Update the grid with flow
    grid_ = (grid + flow_)
    grid_ = grid_.permute(0, 2, 3, 1)
    # Extract coordinates for manual interpolation
    y_grid = grid_[:, :, :, 1].clamp(-1, 1)  # Clamp y-coordinates within [-1, 1]
    x_grid = grid_[:, :, :, 0].clamp(-1, 1)  # Clamp x-coordinates within [-1, 1]
    # Convert grid coordinates to pixel coordinates
    y_index = ((y_grid + 1) / 2) * (H - 1)
    x_index = ((x_grid + 1) / 2) * (W - 1)
    # Compute integer and fractional parts of pixel coordinates
    y_index_int = torch.floor(y_index).long()
    x_index_int = torch.floor(x_index).long()
    y_index_frac = y_index - y_index_int
    x_index_frac = x_index - x_index_int
    # Compute the flat indices of the four nearest pixels
    stride = W
    top_left_indices = y_index_int * stride + x_index_int
    top_right_indices = top_left_indices + 1
    bottom_left_indices = top_left_indices + stride
    bottom_right_indices = bottom_left_indices + 1
    # Clamp indices to be within the image bounds
    top_left_indices = torch.clamp(top_left_indices, 0, H * W - 1)
    top_right_indices = torch.clamp(top_right_indices, 0, H * W - 1)
    bottom_left_indices = torch.clamp(bottom_left_indices, 0, H * W - 1)
    bottom_right_indices = torch.clamp(bottom_right_indices, 0, H * W - 1)
    # Gather the four neighbouring pixels for every channel
    img_flat = img.reshape(B, C, -1)
    top_left = img_flat.gather(2, top_left_indices.reshape(B, 1, -1).expand(-1, C, -1))
    top_right = img_flat.gather(2, top_right_indices.reshape(B, 1, -1).expand(-1, C, -1))
    bottom_left = img_flat.gather(2, bottom_left_indices.reshape(B, 1, -1).expand(-1, C, -1))
    bottom_right = img_flat.gather(2, bottom_right_indices.reshape(B, 1, -1).expand(-1, C, -1))
    # Bilinear interpolation; weights reshaped to (B, 1, -1) so they
    # broadcast over the channel dimension for any batch size
    x_frac = x_index_frac.reshape(B, 1, -1)
    y_frac = y_index_frac.reshape(B, 1, -1)
    top = top_left * (1 - x_frac) + top_right * x_frac
    bottom = bottom_left * (1 - x_frac) + bottom_right * x_frac
    interpolated = top * (1 - y_frac) + bottom * y_frac
    return interpolated.reshape(B, C, H, W)
```
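As a sanity check on the flat-index gather scheme used above (a self-contained sketch with small made-up shapes, not part of my actual model): gathering with the identity grid, i.e. flat index y * W + x and zero fractional weights, must reproduce the image exactly. Note the indices come out as int64, which is exactly the dtype gather() demands.

```python
import torch

# Tiny identity-warp check of the flat-index gather used in warp_custom.
B, C, H, W = 1, 2, 4, 5
img = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

# Integer pixel grid -> flat indices y * W + x (dtype int64, as gather requires)
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
flat_idx = (yy * W + xx).reshape(1, 1, -1).expand(B, C, -1)

img_flat = img.reshape(B, C, -1)
gathered = img_flat.gather(2, flat_idx).reshape(B, C, H, W)
assert torch.equal(gathered, img)  # identity grid reproduces the image
```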
However, during the ONNX-to-.dlc conversion and quantization, the following warning occurs:
240 - WARNING - WARNING_CAST_TYPE: Only numerical type cast is supported. The op: /Cast_31 will be interpreted at conversion time.
The resulting .dlc model runs on the CPU but not on the GPU/DSP (the quantized model simply doesn't run).
Is there any possible solution for this?
PyTorch requires gather() indices to be of type long (int64), but that very cast is what triggers the warning above.