GeeeekExplorer 1 rok temu
rodzic
commit
fd011c11aa
1 zmieniony plik z 2 dodaniami i 3 usunięciami
  1. 2 3
      inference/model.py

+ 2 - 3
inference/model.py

@@ -140,13 +140,12 @@ class RowParallelLinear(Linear):
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
+        self.dim = dim
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: torch.Tensor):
-        x = x.float()
-        y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-        return y.type_as(self.weight) * self.weight
+        return F.rms_norm(x, (self.dim,), self.weight, self.eps)
 
 
 def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: