[SOLVED] Problem with nested network in PyTorch: "TypeError: forward() missing 1 required positional argument: 'x'"

Issue

This content is from Stack Overflow. Question asked by JuanMuñoz.

I am trying to create an architecture consisting of one convolutional layer followed by a block of three convolutional layers. I first build the inner block as a class named “MysmallNet(nn.Module)”, and then I build “MybigNet”, which calls the small network. This is my code:

import torch.nn as nn
class MysmallNet(nn.Module):
    def __init__(self):
        super(MysmallNet, self).__init__()
        # TODO Task 3: Design Your Network
        self.Convlayer_1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.Convlayer_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.Convlayer_3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
        
    def forward(self, x):
        # TODO Task 3: Design Your Network
        residual1 = x
        x = self.Convlayer_1(x)
        x = self.Convlayer_2(x)
        x = self.Convlayer_3(x)
        return x

MysmallNetV2 = MysmallNet()

class MybigNet(nn.Module):
    def __init__(self):
        super(MybigNet, self).__init__()

        self.Convlayer_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.smallNet= MysmallNetV2()

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.smallNet(x)
        return x

modelBig = MybigNet()

I get the error when I create the model and assign it to “modelBig” (modelBig = MybigNet()). The displayed error is:

TypeError: forward() missing 1 required positional argument: 'x'
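The same message can be reproduced without PyTorch. The sketch below is a minimal illustration under one assumption: TinyModule and SmallNet are hypothetical, heavily simplified stand-ins for torch.nn.Module and MysmallNet. The real nn.Module.__call__ does more work (for example, running registered hooks), but it likewise passes its arguments on to forward(), so calling an instance with no arguments reaches forward() without x.

```python
# Hypothetical, heavily simplified stand-in for torch.nn.Module.
# It models only one behaviour: calling an *instance* goes through
# __call__, which passes its arguments on to forward().
class TinyModule:
    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs hooks; omitted here.
        return self.forward(*args, **kwargs)


class SmallNet(TinyModule):
    def forward(self, x):
        return x * 2


small = SmallNet()   # an instance, like MysmallNetV2 = MysmallNet()
print(small(3))      # __call__ -> forward(3) -> prints 6

message = ""
try:
    small()          # like MysmallNetV2(): forward() receives no x
except TypeError as err:
    message = str(err)
print(message)       # contains "missing 1 required positional argument: 'x'"
```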



Solution

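Your definition of the big net is wrong. MysmallNetV2 is already an instance of MysmallNet, not a class, so MysmallNetV2() calls that instance. nn.Module.__call__ then runs forward() with no input, which is why 'x' is reported as missing. Instantiate the class itself inside MybigNet instead: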

class MybigNet(nn.Module):
    def __init__(self):
        super(MybigNet, self).__init__()

        self.Convlayer_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.smallNet= MysmallNet()

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.smallNet(x)
        return x

This should solve the issue.
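A quick way to confirm that the corrected model works is to pass a dummy batch through it. The sketch below is self-contained and makes several assumptions. The 1x3x32x32 input size is arbitrary. The optional small_net constructor argument is a hypothetical addition that is not part of the answer above. It shows a second fix: pass an already-built instance such as MysmallNetV2 and store it without parentheses instead of calling it.

```python
import torch
import torch.nn as nn


class MysmallNet(nn.Module):
    """Three 3x3 convolutions that keep 16 channels and the spatial size."""

    def __init__(self):
        super().__init__()
        self.Convlayer_1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.Convlayer_2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.Convlayer_3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.Convlayer_2(x)
        x = self.Convlayer_3(x)
        return x


class MybigNet(nn.Module):
    # small_net is a hypothetical optional argument for this sketch.
    def __init__(self, small_net=None):
        super().__init__()
        self.Convlayer_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # Store a pre-built instance as-is (no parentheses, so it is not
        # called here); otherwise construct a fresh MysmallNet.
        self.smallNet = small_net if small_net is not None else MysmallNet()

    def forward(self, x):
        x = self.Convlayer_1(x)
        x = self.smallNet(x)
        return x


modelBig = MybigNet()                     # fresh sub-network
shared = MysmallNet()
modelShared = MybigNet(small_net=shared)  # reuse an existing instance

dummy = torch.randn(1, 3, 32, 32)         # assumed input: one 32x32 RGB image
out = modelBig(dummy)
print(out.shape)                          # torch.Size([1, 16, 32, 32])
```

Every convolution uses kernel_size=3, stride=1 and padding=1, so the spatial size stays at 32x32. Only the channel count changes, from 3 to 16.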


This question was asked on Stack Overflow by JuanMuñoz and answered by Azhan Mohammed. It is licensed under the terms of CC BY-SA 2.5 / CC BY-SA 3.0 / CC BY-SA 4.0.
