Set.rb patch
From:
"Christoph" <chr_news@...>
Date:
2002-09-03 21:17:06 UTC
List:
ruby-core #423
Hi,
the included patch set.rb.diff fixes bugs in initialize,
flatten/flatten!, and eql?, and adds subset/superset comparison (<=>) to the newly
imported Set class. The patch set.rb.more.diff additionally makes the
comparisons == and eql? more type-strict, similar to
class Array
  # Keep the original, type-agnostic Array#== reachable under another name.
  # The guard matters: without it, evaluating this snippet a second time
  # would alias the already-patched == onto itself and == would recurse
  # forever.
  alias_method :__type_weak_comp, :== unless method_defined?(:__type_weak_comp)

  # Type-strict equality: returns false unless self is an instance of
  # exactly r's class (so an Array never equals an instance of an Array
  # subclass, in either direction); otherwise defers to the original
  # element-wise Array#==.
  #
  # Uses Object#class; the older Object#type alias no longer exists
  # (removed in Ruby 1.9) and would raise NoMethodError.
  def ==(r)
    return false unless instance_of?(r.class)
    __type_weak_comp(r)
  end
end
(I had wrongly assumed that this was already Ruby's standard notion of equality)
and adds a block option to initialize and contain?.
Generally, I think the Set class is so close to the Hash class
that it could (should?) be defined in hash.c itself (maybe
together with a MultiSet class)?
/Christoph
Attachments (2)
Set.rb.diff
(9.03 KB, text/x-diff)
--- set.rb.orig 2002-09-02 00:36:49.000000000 +0200
+++ set.rb 2002-09-03 23:05:39.000000000 +0200
@@ -22,14 +22,14 @@
set1 = Set.new ["foo", "bar", "baz"]
- p set1 #=> #<Set: {"baz", "foo", "bar"}>
+ p set1 #=> #<Set: {"baz", "foo", "bar"}>
- p set1.include?("bar") #=> true
+ p set1.include?("bar") #=> true
set1.add("heh")
set1.delete("foo")
- p set1 #=> #<Set: {"heh", "baz", "bar"}>
+ p set1 #=> #<Set: {"heh", "baz", "bar"}>
== Set class
Set implements a collection of unordered values with no duplicates.
@@ -41,11 +41,13 @@
=== Included Modules
Enumerable
+ Comparable
=== Class Methods
--- Set::new(enum = nil)
Creates a new set containing the elements of the given enumerable
- object.
+ object. If a block is given the elements of enum are preprocessed
+ by the given block.
--- Set[*ary]
Creates a new set containing the given objects.
@@ -135,7 +137,11 @@
--- == set
Returns true if two sets are equal. The equality of each couple
- of elements is defined according to Object#eql?.
+ of elements is defined according to Object#==.
+
+--- <=> set
+ Returns -1, 0 or 1 if self is a proper subset of, is equal to, or is a
+ proper superset of set, respectively. If none of these holds, returns nil.
--- classify { |o| ... }
Classifies the set by the return value of the given block and
@@ -180,17 +186,17 @@
class Set
include Enumerable
+ include Comparable
def self.[](*ary)
new(ary)
end
def initialize(enum = nil)
- @hash = {}
-
- if enum
+ @hash = Hash.new
+ unless enum.nil?
enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
- enum.each { |o| @hash[o] = true }
+ enum.each { |o| @hash[o] = true }
end
end
@@ -225,34 +231,33 @@
@hash.keys
end
- def _flatten(set, ids = type.new, result = type.new)
- setid = set.id
-
- ids.include?(setid) and raise ArgumentError, "tried to flatten recursive #{type.name}"
-
- ids.add(setid)
-
- set.each { |o|
- if o.is_a?(type)
- _flatten(o, ids, result)
+ protected
+ def _flatten(set, seen)
+ set.each {|e|
+ if Set === e
+ if seen.include?(e_id = e.id)
+ raise ArgumentError, "tried to flatten recursive Set"
+ end
+ seen.add e_id
+ _flatten(e,seen)
+ seen.delete e_id
else
- result.add(o)
+ add(e)
end
}
-
- result
+ self
end
- private :_flatten
-
+
+ public
def flatten
- _flatten(self)
+ type.new._flatten(self, Set.new)
end
-
+
def flatten!
- ids = type.new
- replace(_flatten(self, ids))
-
- ids.size == 1 ? nil : self
+ if any? {|e| Set === e }
+ @hash.replace flatten.instance_eval {@hash }
+ self
+ end
end
def include?(o)
@@ -261,9 +266,10 @@
alias member? include?
def contain?(enum)
- enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
- enum.each { |o| include?(o) or return false }
- true
+ Enumerable === enum or raise ArgumentError, "value must be enumerable"
+ unless block_given?
+ enum.all? { |o| include?(o) }
+ end
end
def each
@@ -309,7 +315,7 @@
enum.each { |o| n.add(o) }
n
end
- alias | + ##
+ alias | + ##
def -(enum)
enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
@@ -334,20 +340,32 @@
def ==(set)
equal?(set) and return true
-
- set.is_a?(type) && size == set.size or return false
-
- set.each { |o| include?(o) or return false }
-
- true
+ unless Set === set and size == set.size
+ false
+ else
+ all? { |e| set.include?(e) }
+ end
end
def hash
@hash.hash
end
- def eql?(o)
- @hash == o.hash
+ alias eql? ==
+
+ def <=>(o)
+ Set === o or raise ArgumentError, "value #{o} must be a Set"
+ equal?(o) and return 0
+ if size < o.size
+ @hash.each_key {|e| o.include?(e) or return nil }
+ -1
+ elsif @hash.size == o.size
+ @hash.each_key {|e| o.include?(e) or return nil }
+ 0
+ else
+ o.each {|e| include?(e) or return nil }
+ 1
+ end
end
def classify
@@ -365,23 +383,21 @@
if func.arity == 2
require 'tsort'
- class << dig = {}
- include TSort
-
- alias tsort_each_node each_key
- def tsort_each_child(node, &block)
- fetch(node).each(&block)
- end
+ class << dig = Hash.new.extend(TSort)
+ alias tsort_each_node each_key
+ def tsort_each_child(node, &block)
+ fetch(node).each(&block)
+ end
end
each { |u|
- dig[u] = a = []
- each{ |v| func.call(u, v) and a << v }
+ dig[u] = a = []
+ each { |v| func.call(u, v) and a << v }
}
- set = type.new()
+ set = type.new
dig.each_strongly_connected_component { |css|
- set.add(Set.new(css))
+ set.add(Set.new(css))
}
set
else
@@ -411,13 +427,13 @@
pp.nest(1) {
first = true
each { |o|
- if first
- first = false
- else
- pp.text ","
- pp.breakable
- end
- pp.pp o
+ if first
+ first = false
+ else
+ pp.text ","
+ pp.breakable
+ end
+ pp.pp o
}
}
pp.text "}>"
@@ -436,9 +452,9 @@
class TC_Set < Test::Unit::TestCase
def test_aref
assert_nothing_raised {
- Set[]
- Set[nil]
- Set[1,2,3]
+ Set[]
+ Set[nil]
+ Set[1,2,3]
}
assert_equal(0, Set[].size)
@@ -452,18 +468,18 @@
def test_s_new
assert_nothing_raised {
- Set.new()
- Set.new(nil)
- Set.new([])
- Set.new([1,2])
- Set.new('a'..'c')
- Set.new('XYZ')
+ Set.new()
+ Set.new(nil)
+ Set.new([])
+ Set.new([1,2])
+ Set.new('a'..'c')
+ Set.new('XYZ')
}
assert_raises(ArgumentError) {
- Set.new(1)
+ Set.new(1)
}
assert_raises(ArgumentError) {
- Set.new(1,2)
+ Set.new(1,2)
}
assert_equal(0, Set.new().size)
@@ -527,17 +543,17 @@
def test_flatten
set1 = Set[
- 1,
- Set[
- 5,
- Set[7,
- Set[0]
- ],
- Set[6,2],
- 1
- ],
- 3,
- Set[3,4]
+ 1,
+ Set[
+ 5,
+ Set[7,
+ Set[0]
+ ],
+ Set[6,2],
+ 1
+ ],
+ 3,
+ Set[3,4]
]
set2 = set1.flatten
@@ -575,11 +591,11 @@
set = Set[1,2,3]
assert_raises(ArgumentError) {
- set.contain?()
+ set.contain?()
}
assert_raises(ArgumentError) {
- set.contain?(2)
+ set.contain?(2)
}
assert_equal(true, set.contain?([]))
@@ -594,15 +610,15 @@
set = Set.new(ary)
assert_raises(LocalJumpError) {
- set.each
+ set.each
}
assert_nothing_raised {
- set.each { |o|
- ary.delete(o) or raise "unexpected element: #{o}"
- }
+ set.each { |o|
+ ary.delete(o) or raise "unexpected element: #{o}"
+ }
- ary.empty? or raise "forgotten elements: #{ary.join(', ')}"
+ ary.empty? or raise "forgotten elements: #{ary.join(', ')}"
}
end
@@ -701,13 +717,32 @@
assert_equal(set1, set1)
assert_equal(set1, set2)
assert_not_equal(Set[1], [1])
+ aset = Class.new(Set)["a","b"]
+
+ _a = Class.new(Set)["a","b"]
+ a = Set["a","b",_a]
+ _a = _a.add(_a.clone)
+
+ assert_equal _a, a
+ assert_equal a,_a
+ assert_equal a, a.clone
+ assert_equal _a.clone, _a
end
# def test_hash
# end
- # def test_eql?
- # end
+ class EqlClass
+ def hash
+ super % 11
+ end
+ end
+
+ def test_eql?
+ a = EqlClass.new
+ b = EqlClass.new until b.hash == a.hash
+ assert !(Set[a].eql?(Set[b]))
+ end
def test_classify
set = Set.new(1..10)
@@ -740,17 +775,17 @@
assert_equal(set.size, n)
assert_equal(set, ret.flatten)
ret.each { |s|
- if s.include?(0)
- assert_equal(Set[0,1], s)
- elsif s.include?(3)
- assert_equal(Set[3,4,5], s)
- elsif s.include?(7)
- assert_equal(Set[7], s)
- elsif s.include?(9)
- assert_equal(Set[9,10,11], s)
- else
- raise "unexpected group: #{s.inspect}"
- end
+ if s.include?(0)
+ assert_equal(Set[0,1], s)
+ elsif s.include?(3)
+ assert_equal(Set[3,4,5], s)
+ elsif s.include?(7)
+ assert_equal(Set[7], s)
+ elsif s.include?(9)
+ assert_equal(Set[9,10,11], s)
+ else
+ raise "unexpected group: #{s.inspect}"
+ end
}
end
@@ -771,6 +806,34 @@
# def test_pretty_print_cycled
# end
+
+ def test_fail_false_initialize
+ assert_raises(ArgumentError) {
+ Set.new (false)
+ }
+ end
+
+ def test_more_flatten
+ empty = Set[]
+ set = Set[Set[empty,"a"],Set[empty,"b"]]
+ assert_nothing_raised { set.flatten }
+ set1 = empty.merge Set["no_more",set]
+ assert_nil Set.new(0..31).flatten!
+ x = Set[Set[],Set[1,2]].flatten!
+ y = Set[1,2]
+ assert_equal x,y
+ end
+
+ def test_compare
+ a= Set[1,2]
+ b = Set[1,2,3]
+ c = Set[1,3]
+
+ assert_equal (a <=> b), -1
+ assert_equal (b <=> a), 1
+ assert_equal (b <=> b.clone), 0
+ assert_nil a <=> c
+ end
end
Test::Unit::UI::Console::TestRunner.run(TC_Set)
Set.rb.more.diff
(9.5 KB, text/x-diff)
--- set.rb.orig 2002-09-02 00:36:49.000000000 +0200
+++ set.rb 2002-09-03 22:13:13.000000000 +0200
@@ -22,14 +22,14 @@
set1 = Set.new ["foo", "bar", "baz"]
- p set1 #=> #<Set: {"baz", "foo", "bar"}>
+ p set1 #=> #<Set: {"baz", "foo", "bar"}>
- p set1.include?("bar") #=> true
+ p set1.include?("bar") #=> true
set1.add("heh")
set1.delete("foo")
- p set1 #=> #<Set: {"heh", "baz", "bar"}>
+ p set1 #=> #<Set: {"heh", "baz", "bar"}>
== Set class
Set implements a collection of unordered values with no duplicates.
@@ -41,11 +41,13 @@
=== Included Modules
Enumerable
+ Comparable
=== Class Methods
--- Set::new(enum = nil)
Creates a new set containing the elements of the given enumerable
- object.
+ object. If a block is given the elements of enum are preprocessed
+ by the given block.
--- Set[*ary]
Creates a new set containing the given objects.
@@ -135,7 +137,11 @@
--- == set
Returns true if two sets are equal. The equality of each couple
- of elements is defined according to Object#eql?.
+ of elements is defined according to Object#==.
+
+--- <=> set
+ Returns -1, 0 or 1 if self is a proper subset of, is equal to, or is a
+ proper superset of set, respectively. If none of these holds, returns nil.
--- classify { |o| ... }
Classifies the set by the return value of the given block and
@@ -180,17 +186,21 @@
class Set
include Enumerable
+ include Comparable
def self.[](*ary)
new(ary)
end
def initialize(enum = nil)
- @hash = {}
-
- if enum
+ @hash = Hash.new
+ unless enum.nil?
enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
- enum.each { |o| @hash[o] = true }
+ unless block_given?
+ enum.each { |o| @hash[o] = true }
+ else
+ enum.each { |o| @hash[yield(o)] = true }
+ end
end
end
@@ -225,34 +235,33 @@
@hash.keys
end
- def _flatten(set, ids = type.new, result = type.new)
- setid = set.id
-
- ids.include?(setid) and raise ArgumentError, "tried to flatten recursive #{type.name}"
-
- ids.add(setid)
-
- set.each { |o|
- if o.is_a?(type)
- _flatten(o, ids, result)
+ protected
+ def _flatten(set, seen)
+ set.each {|e|
+ if Set === e
+ if seen.include?(e_id = e.id)
+ raise ArgumentError, "tried to flatten recursive Set"
+ end
+ seen.add e_id
+ _flatten(e,seen)
+ seen.delete e_id
else
- result.add(o)
+ add(e)
end
}
-
- result
+ self
end
- private :_flatten
-
+
+ public
def flatten
- _flatten(self)
+ type.new._flatten(self, Set.new)
end
-
+
def flatten!
- ids = type.new
- replace(_flatten(self, ids))
-
- ids.size == 1 ? nil : self
+ if any? {|e| Set === e }
+ @hash.replace flatten.instance_eval {@hash }
+ self
+ end
end
def include?(o)
@@ -261,9 +270,12 @@
alias member? include?
def contain?(enum)
- enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
- enum.each { |o| include?(o) or return false }
- true
+ Enumerable === enum or raise ArgumentError, "value must be enumerable"
+ unless block_given?
+ enum.all? { |o| include?(o) }
+ else
+ enum.all? { |o| include?(yield(o)) }
+ end
end
def each
@@ -309,7 +321,7 @@
enum.each { |o| n.add(o) }
n
end
- alias | + ##
+ alias | + ##
def -(enum)
enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
@@ -334,20 +346,32 @@
def ==(set)
equal?(set) and return true
-
- set.is_a?(type) && size == set.size or return false
-
- set.each { |o| include?(o) or return false }
-
- true
+ unless set.instance_of?(type) and size == set.size
+ false
+ else
+ all? { |e| set.include?(e) }
+ end
end
def hash
@hash.hash
end
- def eql?(o)
- @hash == o.hash
+ alias eql? ==
+
+ def <=>(o)
+ o.instance_of?(type) or raise ArgumentError, "value #{o} must be a #{type}"
+ equal?(o) and return 0
+ if size < o.size
+ @hash.each_key {|e| o.include?(e) or return nil }
+ -1
+ elsif @hash.size == o.size
+ @hash.each_key {|e| o.include?(e) or return nil }
+ 0
+ else
+ o.each {|e| include?(e) or return nil }
+ 1
+ end
 end
def classify
@@ -365,23 +389,21 @@
if func.arity == 2
require 'tsort'
- class << dig = {}
- include TSort
-
- alias tsort_each_node each_key
- def tsort_each_child(node, &block)
- fetch(node).each(&block)
- end
+ class << dig = Hash.new.extend(TSort)
+ alias tsort_each_node each_key
+ def tsort_each_child(node, &block)
+ fetch(node).each(&block)
+ end
end
each { |u|
- dig[u] = a = []
- each{ |v| func.call(u, v) and a << v }
+ dig[u] = a = []
+ each { |v| func.call(u, v) and a << v }
}
- set = type.new()
+ set = type.new
dig.each_strongly_connected_component { |css|
- set.add(Set.new(css))
+ set.add(Set.new(css))
}
set
else
@@ -411,13 +433,13 @@
pp.nest(1) {
first = true
each { |o|
- if first
- first = false
- else
- pp.text ","
- pp.breakable
- end
- pp.pp o
+ if first
+ first = false
+ else
+ pp.text ","
+ pp.breakable
+ end
+ pp.pp o
}
}
pp.text "}>"
@@ -436,9 +458,9 @@
class TC_Set < Test::Unit::TestCase
def test_aref
assert_nothing_raised {
- Set[]
- Set[nil]
- Set[1,2,3]
+ Set[]
+ Set[nil]
+ Set[1,2,3]
}
assert_equal(0, Set[].size)
@@ -452,18 +474,18 @@
def test_s_new
assert_nothing_raised {
- Set.new()
- Set.new(nil)
- Set.new([])
- Set.new([1,2])
- Set.new('a'..'c')
- Set.new('XYZ')
+ Set.new()
+ Set.new(nil)
+ Set.new([])
+ Set.new([1,2])
+ Set.new('a'..'c')
+ Set.new('XYZ')
}
assert_raises(ArgumentError) {
- Set.new(1)
+ Set.new(1)
}
assert_raises(ArgumentError) {
- Set.new(1,2)
+ Set.new(1,2)
}
assert_equal(0, Set.new().size)
@@ -527,17 +549,17 @@
def test_flatten
set1 = Set[
- 1,
- Set[
- 5,
- Set[7,
- Set[0]
- ],
- Set[6,2],
- 1
- ],
- 3,
- Set[3,4]
+ 1,
+ Set[
+ 5,
+ Set[7,
+ Set[0]
+ ],
+ Set[6,2],
+ 1
+ ],
+ 3,
+ Set[3,4]
]
set2 = set1.flatten
@@ -575,11 +597,11 @@
set = Set[1,2,3]
assert_raises(ArgumentError) {
- set.contain?()
+ set.contain?()
}
assert_raises(ArgumentError) {
- set.contain?(2)
+ set.contain?(2)
}
assert_equal(true, set.contain?([]))
@@ -594,15 +616,15 @@
set = Set.new(ary)
assert_raises(LocalJumpError) {
- set.each
+ set.each
}
assert_nothing_raised {
- set.each { |o|
- ary.delete(o) or raise "unexpected element: #{o}"
- }
+ set.each { |o|
+ ary.delete(o) or raise "unexpected element: #{o}"
+ }
- ary.empty? or raise "forgotten elements: #{ary.join(', ')}"
+ ary.empty? or raise "forgotten elements: #{ary.join(', ')}"
}
end
@@ -694,20 +716,38 @@
assert_equal(Set[2,4], ret)
end
- def test_eq
+ def test_eq_and_eql?
set1 = Set[2,3,1]
set2 = Set[1,2,3]
assert_equal(set1, set1)
assert_equal(set1, set2)
assert_not_equal(Set[1], [1])
+ aset = Class.new(Set)["a","b"]
+
+ _a = Class.new(Set)["a","b"]
+ a = Set["a","b",_a]
+ _a = _a.add(_a.clone)
+
+ assert_not_equal _a, a
+ assert_not_equal a,_a
+ assert_equal a, a.clone
+ assert_equal _a.clone, _a
end
# def test_hash
# end
- # def test_eql?
- # end
+ def test_eql?
+ _a = Class.new(Set)["a","b"]
+ a = Set["a","b",[_a]]
+ _a = _a.add([_a.clone])
+
+ assert !(a.eql? _a)
+ assert !(_a.eql? a)
+ assert (a.clone.eql? a)
+ assert (_a.eql? _a.clone)
+ end
def test_classify
set = Set.new(1..10)
@@ -740,17 +780,17 @@
assert_equal(set.size, n)
assert_equal(set, ret.flatten)
ret.each { |s|
- if s.include?(0)
- assert_equal(Set[0,1], s)
- elsif s.include?(3)
- assert_equal(Set[3,4,5], s)
- elsif s.include?(7)
- assert_equal(Set[7], s)
- elsif s.include?(9)
- assert_equal(Set[9,10,11], s)
- else
- raise "unexpected group: #{s.inspect}"
- end
+ if s.include?(0)
+ assert_equal(Set[0,1], s)
+ elsif s.include?(3)
+ assert_equal(Set[3,4,5], s)
+ elsif s.include?(7)
+ assert_equal(Set[7], s)
+ elsif s.include?(9)
+ assert_equal(Set[9,10,11], s)
+ else
+ raise "unexpected group: #{s.inspect}"
+ end
}
end
@@ -771,6 +811,40 @@
# def test_pretty_print_cycled
# end
+
+ def test_fail_false_initialize
+ assert_raises(ArgumentError) {
+ Set.new (false)
+ }
+ end
+
+ def test_block_initialize
+ a = Set[*(-4..0)]
+ b = Set.new(0..4) {|e| -e }
+ assert_equal a, b
+ end
+
+ def test_more_flatten
+ empty = Set[]
+ set = Set[Set[empty,"a"],Set[empty,"b"]]
+ assert_nothing_raised { set.flatten }
+ set1 = empty.merge Set["no_more",set]
+ assert_nil Set.new(0..31).flatten!
+ x = Set[Set[],Set[1,2]].flatten!
+ y = Set[1,2]
+ assert_equal x,y
+ end
+
+ def test_compare
+ a= Set[1,2]
+ b = Set[1,2,3]
+ c = Set[1,3]
+
+ assert_equal (a <=> b), -1
+ assert_equal (b <=> a), 1
+ assert_equal (b <=> b.clone), 0
+ assert_nil a <=> c
+ end
end
Test::Unit::UI::Console::TestRunner.run(TC_Set)