[ruby-dev:49546] Integer#pow の提案

From: "KISHIMOTO, Makoto" <ksmakoto@...4u.or.jp>
Date: 2016-03-30 09:24:27 UTC
List: ruby-dev #49546
きしもとです

暗号の計算などで (a**b)%m を求めたいことがあります。この式のまま書くと、
Bignumがありますのでオーバーフローの問題はありませんが、随時 mod m を
取れば本来必要のない重い計算を省いて軽くできますし、どうせなら組込みで
計算できれば高速化もでき便利だと考え、Integer#pow として実装しました。
(メール末に差分を付けました)
コメント等ありましたらよろしくお願いします。修正の上でプルリクエスト化の
予定です。

実引数の型や値によって例外とする場合などについては Python の組込み関数
pow を参考に( https://docs.python.org/3/library/functions.html#pow )
しました。

可変個引数のメソッドとなっていて、1引数の場合は、そのまま ** を呼び出し
ます。2引数の場合は、Fixnum(long)で高速に計算できる場合は高速に、
そうでない場合は汎用の実装で計算します。

コメントでXXXを付けてある箇所は、プルリクエスト化前に改良したい箇所で、

・Fixnumが31ビットか63ビットか、の振り分け
・最近のコンパイラでは /*NOTREACHED*/ のおまじないが効かないようなので
 その代替はどのようにするのがベストプラクティスでしょうか?

という感じです。

diff --git a/numeric.c b/numeric.c
index 4bb2a32..c5c7afb 100644
--- a/numeric.c
+++ b/numeric.c
@@ -108,6 +108,8 @@ static VALUE fix_mul(VALUE x, VALUE y);
 static VALUE int_pow(long x, unsigned long y);
 static VALUE int_cmp(VALUE x, VALUE y);
 
+static VALUE int_pow2(int argc, VALUE* argv, VALUE x);
+
 static ID id_coerce, id_div, id_divmod;
 #define id_to_i idTo_i
 #define id_eq  idEq
@@ -3550,6 +3552,174 @@ int_cmp(VALUE x, VALUE y)
     }
 }
 
+/* Integer#pow */
+
+static VALUE
+int_pow_tmp1(VALUE x, VALUE y, long mm, int nega_flg)
+{
+    long xx = FIX2LONG(x);
+    long tmp = 1L;
+    long yy;
+
+    for (/*NOP*/; ! FIXNUM_P(y); y = rb_funcall(y, idGTGT, 1, LONG2FIX(1L))) {
+        if (RTEST(int_odd_p(y))) {
+            tmp = (tmp * xx) % mm;
+        }
+        xx = (xx * xx) % mm;
+    }
+    for (yy = FIX2LONG(y); yy; yy >>= 1L) {
+        if (yy & 1L) {
+            tmp = (tmp * xx) % mm;
+        }
+        xx = (xx * xx) % mm;
+    }
+
+    if (nega_flg && tmp) {
+        tmp -= mm;
+    }
+    return LONG2FIX(tmp);
+}
+
+static VALUE
+int_pow_tmp2(VALUE x, VALUE y, long mm, int nega_flg)
+{
+    long tmp = 1L;
+    long yy;
+#ifdef DLONG
+    DLONG const mmm = mm;
+    long xx = FIX2LONG(x);
+
+    for (/*NOP*/; ! FIXNUM_P(y); y = rb_funcall(y, idGTGT, 1, LONG2FIX(1L))) {
+        if (RTEST(int_odd_p(y))) {
+            tmp = ((DLONG)tmp * (DLONG)xx) % mmm;
+        }
+        xx = ((DLONG)xx * (DLONG)xx) % mmm;
+    }
+    for (yy = FIX2LONG(y); yy; yy >>= 1L) {
+        if (yy & 1L) {
+            tmp = ((DLONG)tmp * (DLONG)xx) % mmm;
+        }
+        xx = ((DLONG)xx * (DLONG)xx) % mmm;
+    }
+#else
+    VALUE const m = LONG2FIX(mm);
+    VALUE tmp2 = LONG2FIX(tmp);
+
+    for (/*NOP*/; ! FIXNUM_P(y); y = rb_funcall(y, idGTGT, 1, LONG2FIX(1L))) {
+        if (RTEST(int_odd_p(y))) {
+            tmp2 = rb_fix_mul_fix(tmp2, x);
+            tmp2 = rb_int_modulo(tmp2, m);
+        }
+        x = rb_fix_mul_fix(x, x);
+        x = rb_int_modulo(x, m);
+    }
+    for (yy = FIX2LONG(y); yy; yy >>= 1L) {
+        if (yy & 1L) {
+            tmp2 = rb_fix_mul_fix(tmp2, x);
+            tmp2 = rb_int_modulo(tmp2, m);
+        }
+        x = rb_fix_mul_fix(x, x);
+        x = rb_int_modulo(x, m);
+    }
+
+    tmp = FIX2LONG(tmp2);
+#endif
+    if (nega_flg && tmp) {
+        tmp -= mm;
+    }
+    return LONG2FIX(tmp);
+}
+
+static VALUE
+int_pow_tmp3(VALUE x, VALUE y, VALUE m, int nega_flg)
+{
+    VALUE tmp = LONG2FIX(1L);
+    long yy;
+
+    for (/*NOP*/; ! FIXNUM_P(y); y = rb_funcall(y, idGTGT, 1, LONG2FIX(1L))) {
+        if (RTEST(int_odd_p(y))) {
+            tmp = rb_funcall(tmp, '*', 1, x);
+            tmp = rb_int_modulo(tmp, m);
+        }
+        x = rb_funcall(x, '*', 1, x);
+        x = rb_int_modulo(x, m);
+    }
+    for (yy = FIX2LONG(y); yy; yy >>= 1L) {
+        if (yy & 1L) {
+            tmp = rb_funcall(tmp, '*', 1, x);
+            tmp = rb_int_modulo(tmp, m);
+        }
+        x = rb_funcall(x, '*', 1, x);
+        x = rb_int_modulo(x, m);
+    }
+
+    if (nega_flg && positive_int_p(tmp)) {
+        tmp = rb_funcall(tmp, '-', 1, m);
+    }
+    return tmp;
+}
+
+static VALUE
+int_pow3(VALUE x, VALUE y, VALUE m)
+{
+    int nega_flg = 0;
+    long mm;
+
+    if ( ! (rb_obj_is_kind_of(y, rb_cInteger) && rb_obj_is_kind_of(m, rb_cInteger))) {
+        rb_raise(rb_eTypeError, "Integer#pow() 2nd argument not allowed unless all arguments are integers");
+    }
+    if (negative_int_p(y)) {
+        rb_raise(rb_eRangeError, "Integer#pow() 1st argument cannot be negative when 2nd argument specified");
+    }
+
+    if (negative_int_p(m)) {
+        nega_flg = 1;
+        m = rb_funcall(m, idUMinus, 0);
+    }
+
+    if ( ! positive_int_p(m)) {
+        rb_num_zerodiv();
+    }
+
+    if (FIXNUM_P(m)) {
+        mm = FIX2LONG(m);
+        if (LONG_MAX > 0x7fffffffL) {  // XXX?
+            if (mm <= 0x80000000L) {
+                return int_pow_tmp1(rb_int_modulo(x, m), y, mm, nega_flg);
+            }
+        } else {
+            if (mm <= 0x8000L) {
+                return int_pow_tmp1(rb_int_modulo(x, m), y, mm, nega_flg);
+            }
+        }
+        return int_pow_tmp2(rb_int_modulo(x, m), y, mm, nega_flg);
+    }
+    return int_pow_tmp3(rb_int_modulo(x, m), y, m, nega_flg);
+}
+
+/*
+ * call-seq:
+ *   integer.pow(numeric)  ->  integer ** numeric
+ *   integer.pow(a, m)     ->  (integer `pow` a) mod m
+ */
+
+static VALUE
+int_pow2(int argc, VALUE* argv, VALUE x)
+{
+    VALUE y, m;
+
+    rb_scan_args(argc, argv, "11", &y, &m);
+
+    switch (argc) {
+    case 1:
+        return rb_funcall(x, rb_intern("**"), 1, y);
+    case 2:
+        return int_pow3(x, y, m);
+    }
+
+    return Qnil;  // not reached (XXX?)
+}
+
 /*
  * call-seq:
  *   fix > real  ->  true or false
@@ -4332,6 +4502,7 @@ Init_Numeric(void)
     rb_define_method(rb_cInteger, "truncate", int_to_i, 0);
     rb_define_method(rb_cInteger, "round", int_round, -1);
     rb_define_method(rb_cInteger, "<=>", int_cmp, 1);
+    rb_define_method(rb_cInteger, "pow", int_pow2, -1);
 
     rb_cFixnum = rb_define_class("Fixnum", rb_cInteger);
 

In This Thread

Prev Next