下面的实现提供了一个类似 Python 的生成器。
请注意,在下面的代码中有一个名为_yield 的函数,因为yield 在Scala 中已经是一个关键字,顺便说一下,它与您从Python 中知道的yield 没有任何关系。
import scala.annotation.tailrec
import scala.collection.immutable.Stream
import scala.util.continuations._
object Generators {
sealed trait Trampoline[+T]
case object Done extends Trampoline[Nothing]
case class Continue[T](result: T, next: Unit => Trampoline[T]) extends Trampoline[T]
class Generator[T](var cont: Unit => Trampoline[T]) extends Iterator[T] {
def next: T = {
cont() match {
case Continue(r, nextCont) => cont = nextCont; r
case _ => sys.error("Generator exhausted")
}
}
def hasNext = cont() != Done
}
type Gen[T] = cps[Trampoline[T]]
def generator[T](body: => Unit @Gen[T]): Generator[T] = {
new Generator((Unit) => reset { body; Done })
}
def _yield[T](t: T): Unit @Gen[T] =
shift { (cont: Unit => Trampoline[T]) => Continue(t, cont) }
}
object TestCase {
import Generators._
def sectors = generator {
def tailrec(seq: Seq[String]): Unit @Gen[String] = {
if (!seq.isEmpty) {
_yield(seq.head)
tailrec(seq.tail)
}
}
val list: Seq[String] = List("Financials", "Materials", "Technology", "Utilities")
tailrec(list)
}
def main(args: Array[String]): Unit = {
for (s <- sectors) { println(s) }
}
}
它工作得很好,包括 for 循环的典型用法。
警告:我们需要记住 Python 和 Scala 在实现延续的方式上有所不同。下面我们将看到生成器通常如何在 Python 中使用,并与我们在 Scala 中使用它们的方式进行比较。然后,我们将了解为什么它需要在 Scala 中如此。
如果你习惯用 Python 编写代码,你可能使用过这样的生成器:
// This is Scala code that does not compile :(
// This code naively tries to mimic the way generators are used in Python
def myGenerator = generator {
val list: Seq[String] = List("Financials", "Materials", "Technology", "Utilities")
list foreach {s => _yield(s)}
}
上面的代码无法编译。跳过所有复杂的理论方面,解释是:它无法编译,因为 “for 循环的类型” 与作为延续的一部分所涉及的类型不匹配。恐怕这种解释是完全失败的。让我再试一次:
如果您编写了如下所示的代码,则可以正常编译:
def myGenerator = generator {
_yield("Financials")
_yield("Materials")
_yield("Technology")
_yield("Utilities")
}
此代码可以编译,因为生成器可以在yields 的序列中分解,在这种情况下,yield 匹配延续中涉及的类型。更准确地说,可以将代码分解为链式块,其中每个块以yield 结尾。为了清楚起见,我们可以认为yields的序列可以这样表示:
{ some code here; _yield("Financials")
{ some other code here; _yield("Materials")
{ eventually even some more code here; _yield("Technology")
{ ok, fine, youve got the idea, right?; _yield("Utilities") }}}}
再次,无需深入复杂的理论,重点是,在yield 之后,您需要提供另一个以yield 结尾的块,否则关闭链。这就是我们在上面的伪代码中所做的:在yield 之后,我们打开另一个块,它依次以yield 结尾,然后是另一个yield,又以另一个yield 结尾,等等在。显然,这件事必须在某个时候结束。那么我们唯一可以做的就是关闭整个链条。
好的。但是...我们如何才能yield 多条信息?答案有点晦涩,但知道答案后就很有意义了:我们需要使用尾递归,并且块的最后一条语句必须是yield。
def myGenerator = generator {
def tailrec(seq: Seq[String]): Unit @Gen[String] = {
if (!seq.isEmpty) {
_yield(seq.head)
tailrec(seq.tail)
}
}
val list = List("Financials", "Materials", "Technology", "Utilities")
tailrec(list)
}
让我们分析一下这里发生了什么:
我们的生成器函数myGenerator 包含一些获取生成信息的逻辑。在这个例子中,我们简单地使用了一个字符串序列。
我们的生成器函数myGenerator 调用一个递归函数,该函数负责yield-ing 从我们的字符串序列中获取的多条信息。
递归函数必须在使用前声明,否则编译器崩溃。
递归函数tailrec提供了我们需要的尾递归。
这里的经验法则很简单:用递归函数替换 for 循环,如上所示。
请注意,tailrec 只是我们找到的一个方便的名称,为了澄清起见。特别是,tailrec 不需要是我们的生成器函数的最后一条语句;不必要。唯一的限制是您必须提供与yield 类型匹配的块序列,如下所示:
def myGenerator = generator {
def tailrec(seq: Seq[String]): Unit @Gen[String] = {
if (!seq.isEmpty) {
_yield(seq.head)
tailrec(seq.tail)
}
}
_yield("Before the first call")
_yield("OK... not yet...")
_yield("Ready... steady... go")
val list = List("Financials", "Materials", "Technology", "Utilities")
tailrec(list)
_yield("done")
_yield("long life and prosperity")
}
更进一步,您必须想象现实生活中的应用程序是什么样子,尤其是在您使用多个生成器的情况下。如果您找到一种方法来标准化您的生成器,这将是一个好主意,该方法可以证明在大多数情况下都很方便。
让我们看看下面的例子。我们有三个生成器:sectors、industries 和 companies。为简洁起见,仅完整显示了sectors。该生成器使用tailrec 函数,如上所示。这里的技巧是相同的tailrec 函数也被其他生成器使用。我们所要做的就是提供一个不同的body 函数。
type GenP = (NodeSeq, NodeSeq, NodeSeq)
type GenR = immutable.Map[String, String]
def tailrec(p: GenP)(body: GenP => GenR): Unit @Gen[GenR] = {
val (stats, rows, header) = p
if (!stats.isEmpty && !rows.isEmpty) {
val heads: GenP = (stats.head, rows.head, header)
val tails: GenP = (stats.tail, rows.tail, header)
_yield(body(heads))
// tail recursion
tailrec(tails)(body)
}
}
def sectors = generator[GenR] {
def body(p: GenP): GenR = {
// unpack arguments
val stat, row, header = p
// obtain name and url
val name = (row \ "a").text
val url = (row \ "a" \ "@href").text
// create map and populate fields: name and url
var m = new scala.collection.mutable.HashMap[String, String]
m.put("name", name)
m.put("url", url)
// populate other fields
(header, stat).zipped.foreach { (k, v) => m.put(k.text, v.text) }
// returns a map
m
}
val root : scala.xml.NodeSeq = cache.loadHTML5(urlSectors) // obtain entire page
val header: scala.xml.NodeSeq = ... // code is omitted
val stats : scala.xml.NodeSeq = ... // code is omitted
val rows : scala.xml.NodeSeq = ... // code is omitted
// tail recursion
tailrec((stats, rows, header))(body)
}
def industries(sector: String) = generator[GenR] {
def body(p: GenP): GenR = {
//++ similar to 'body' demonstrated in "sectors"
// returns a map
m
}
//++ obtain NodeSeq variables, like demonstrated in "sectors"
// tail recursion
tailrec((stats, rows, header))(body)
}
def companies(sector: String) = generator[GenR] {
def body(p: GenP): GenR = {
//++ similar to 'body' demonstrated in "sectors"
// returns a map
m
}
//++ obtain NodeSeq variables, like demonstrated in "sectors"
// tail recursion
tailrec((stats, rows, header))(body)
}