かたちづくり

つれづれに、だらだらと、おきらくに

某 C# コードを高速化してみた

aokomoriuta さんが面白いパフォーマンステストをやっていました。aokomoriuta.hateblo.jp

こちらが aokomoriuta さんによる計測結果。

SIMDとかは私は分からないので華麗にスルーして、C# もうちっと速くならんかなー、と思いまして。ゴニョゴニョやった結果、割と高速化出来ました。

結果

  • 高速化前:6秒くらい
  • 高速化後:1秒くらい 1.3秒くらい

※ 当初はX,Y,Zしか計算してなくてWを忘れていたので、修正しました><

やったこと

  • Vector4 を class から struct に変えた。
  • Vector4 の中身も配列をやめて X, Y, Z, W の4変数を持たせるようにした。
  • unsafe コードによって Vector4 の配列をポインタで扱うようにした。
  • ちなみに Vector4 に定義した演算子オーバーロードを利用するとパフォーマンスは劇的に落ちて6.5秒くらい。

感想

書いたコード

※ 当初は W の計算を忘れていたので修正しました。

  struct Vector4
  {
    public double X, Y, Z, W;

    public Vector4( double x, double y, double z, double w = 0.0 )
    {
      this.X = x;
      this.Y = y;
      this.Z = z;
      this.W = w;
    }

    public static Vector4 operator +( Vector4 v1, Vector4 v2 )
    {
      return new Vector4( v1.X + v2.X, v1.Y + v2.Y, v1.Z + v2.Z, v1.W + v2.W );
    }

    public static Vector4 operator *( double scalar, Vector4 v )
    {
      return new Vector4( scalar * v.X, scalar * v.Y, scalar * v.Z, scalar * v.W );
    }
  }

  static class PerformanceTest
  {
    // こっちは遅い
    static unsafe void RunWithOperatorOverload( Vector4* x, Vector4* v, Vector4* f, double m, double dt, int n )
    {
      double tmp = dt * dt / 2;
      double rm = 1.0 / m;
      for ( int i = 0; i < n; i++ ) {
        // a = f/m
        var a = rm * f[i];

        // x += v*dt + a*dt*dt/2
        x[i] += dt * v[i] + tmp * a;

        // v += a*dt
        v[i] += dt * a;
      }
    }

    // こっちは速い
    static unsafe void RunWithoutOperatorOverload( Vector4* x, Vector4* v, Vector4* f, double m, double dt, int n )
    {
      double tmp = dt * dt / 2;
      double rm = 1.0 / m;
      for ( int i = 0; i < n; i++ ) {
        // a = f/m
        var a = new Vector4( rm * f[i].X, rm * f[i].Y, rm * f[i].Z, rm * f[i].W );

        // x += v*dt + a*dt*dt/2
        x[i] = new Vector4(
          x[i].X + dt * v[i].X + tmp * a.X,
          x[i].Y + dt * v[i].Y + tmp * a.Y,
          x[i].Z + dt * v[i].Z + tmp * a.Z,
          x[i].W + dt * v[i].W + tmp * a.W );

        // v += a*dt
        v[i] = new Vector4(
          v[i].X + dt * a.X,
          v[i].Y + dt * a.Y,
          v[i].Z + dt * a.Z,
          v[i].W + dt * a.W );
    }

    static Vector4[] CreateRandomVectors( int n )
    {
      var rand = new System.Random();
      var v = new Vector4[n];
      for ( int i = 0; i < n; ++i ) v[i] = new Vector4( rand.NextDouble(), rand.NextDouble(), rand.NextDouble() );
      return v;
    }

    public static unsafe void Run()
    {
      const int n = 100000;
      const int loop = 1000;
      const double dt = 0.1;
      const double m = 2.5;

      {
        var f = CreateRandomVectors( n );
        var v = CreateRandomVectors( n );
        var x = CreateRandomVectors( n );
        fixed ( Vector4* pX = x, pV = v, pF = f ) {
          Console.Write( "RunWithOperatorOverload: " );
          var stopwatch = new System.Diagnostics.Stopwatch();
          stopwatch.Start();
          for ( int i = 0; i < loop; i++ ) {
            RunWithOperatorOverload( pX, pV, pF, m, dt, n );
          }
          Console.WriteLine( "{0} [ms]", stopwatch.ElapsedMilliseconds ); // 6.5秒くらい
        }
      }
      {
        var f = CreateRandomVectors( n );
        var v = CreateRandomVectors( n );
        var x = CreateRandomVectors( n );
        fixed ( Vector4* pX = x, pV = v, pF = f ) {
          Console.Write( "RunWithoutOperatorOverload: " );
          var stopwatch = new System.Diagnostics.Stopwatch();
          stopwatch.Start();
          for ( int i = 0; i < loop; i++ ) {
            RunWithoutOperatorOverload( pX, pV, pF, m, dt, n );
          }
          Console.WriteLine( "{0} [ms]", stopwatch.ElapsedMilliseconds ); // 1.3秒くらい
        }
      }

      System.Console.WriteLine( "Press Any key..." );
      System.Console.ReadKey();
    }
  }