Flatten

CLASS Mytorch.nn.Flatten(module_name)，前向调用 forward(inputs)：保留第 0 维（batch 维），将其余所有维度展平为一维


参数:

  • module_name: str，模块名称（构造参数）
  • inputs: Tensor，需要展平的输入（forward 的参数）

输出 shape: (batch_size, -1)，其中 -1 表示由其余各维自动推断出的展平长度


import numpy as np
from .module import Module
from ..tensor import *


class Flatten(Module):
    """Layer that flattens every dimension after the first (batch) one.

    ``forward`` reshapes its input to ``(batch_size, -1)``: the leading axis
    is kept as-is and all remaining axes are merged into a single axis.
    """

    def __init__(self, module_name: str):
        """Create the layer.

        Args:
            module_name: name forwarded unchanged to ``Module.__init__``.
        """
        super(Flatten, self).__init__(module_name)

    def forward(self, inputs: Tensor):
        """Reshape ``inputs`` to ``(batch_size, -1)``.

        Args:
            inputs: the Tensor to flatten. Its ``datas`` attribute must expose
                ``.shape``; axis 0 is treated as the batch axis (presumably a
                numpy ndarray -- confirm against ``Tensor``).

        Returns:
            Whatever ``Tensor.reshape(inputs, (batch_size, -1))`` returns.

        Raises:
            TypeError: if ``inputs`` is not a ``Tensor``.
        """
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if not isinstance(inputs, Tensor):
            raise TypeError(
                "Flatten expects a Tensor, got {}".format(type(inputs).__name__)
            )
        batch_size = inputs.datas.shape[0]
        # -1 asks reshape to infer the merged length (numpy convention;
        # presumably Tensor.reshape follows it -- confirm).
        return Tensor.reshape(inputs, (batch_size, -1))