~dieggsy/chicken-genann

chicken-genann/example4.scm -rwxr-xr-x 2.7 KiB
2e21e197dieggsy Update source url 2 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/csi -s
(import genann
        (only chicken.io read-line)
        (only chicken.string string-split)
        (only chicken.format printf)
        (only srfi-4 f64vector-ref list->f64vector)
        (only srfi-1 split-at))

;; Announce which example is running and what it does, one line per entry.
(for-each (lambda (banner-line) (printf banner-line))
          '("Genann example 4.~n"
            "Train an ANN on the IRIS dataset using backpropagation.~n"))

;; Path of the comma-separated data file, relative to the working directory.
(define iris-data "example/iris.data")

;; Association list mapping each class label found in the data file to its
;; one-hot target vector (one slot per class, 1.0 in the slot for that class).
(define class-names
  (map (lambda (label one-hot)
         (cons label (list->f64vector one-hot)))
       '("Iris-setosa" "Iris-versicolor" "Iris-virginica")
       '((1.0 0.0 0.0)
         (0.0 1.0 0.0)
         (0.0 0.0 1.0))))

;; Read the comma-separated file named by `iris-data` and return three values:
;;   1. the number of samples read,
;;   2. a vector of f64vectors holding the first four fields of each row,
;;   3. a vector of the one-hot target f64vectors looked up in `class-names`
;;      from the fifth field of each row.
;; The two vectors are index-aligned.  Rows are consed onto the front of the
;; accumulators, so both vectors are in reverse file order.
;;
;; Blank (or whitespace-only) lines are skipped.  An error is signalled for a
;; row with fewer than five fields, a non-numeric feature, or a class label
;; that is not present in `class-names`.
(define (load-data)
  (call-with-input-file iris-data
    (lambda (p)
      (let loop ((line (read-line p))
                 (samples 0)
                 (inputs '())
                 (class '()))
        (cond
         ((eof-object? line)
          (values samples (list->vector inputs) (list->vector class)))
         ;; Splitting on whitespace yields no tokens for an empty line, e.g.
         ;; a trailing newline at the end of the data file.
         ((null? (string-split line " \t\r\n"))
          (loop (read-line p) samples inputs class))
         (else
          (let ((split (string-split line ",")))
            ;; split-at would fail with an obscure message on a short row.
            (when (< (length split) 5)
              (error "load-data: expected 4 features and a class name" line))
            (let*-values (((ilist clist) (split-at split 4))
                          ((features) (map string->number ilist))
                          ((target) (alist-ref (car clist) class-names string=?)))
              ;; string->number returns #f for text that is not a number.
              (when (memq #f features)
                (error "load-data: non-numeric feature value" line))
              ;; alist-ref returns #f when the label is not in the table;
              ;; fail here instead of storing #f as a training target.
              (unless target
                (error "load-data: unknown class name" (car clist)))
              (loop (read-line p)
                    (add1 samples)
                    (cons (list->f64vector features) inputs)
                    (cons target class))))))))))

;; Read the dataset once: `samples` is the row count, while `input` and
;; `class` are index-aligned vectors of feature and target f64vectors.
(define-values (samples input class) (load-data))

(printf "Loaded ~a data points from ~a~n" samples iris-data)

;; Network shape: 4 inputs, 1 hidden layer of 4 neurons, 3 outputs (1 per class).
(define ann (make-genann 4 1 4 3))

;; Number of complete passes made over the dataset during training.
(define loops 5000)

(printf "Training for ~a loops over data.\n" loops)

;; Backpropagation training: on every pass, present each sample with its
;; target to the network.  The last argument (0.01) is presumably the
;; learning rate -- confirm against the genann egg documentation.
(let next-pass ((pass 0))
  (when (< pass loops)
    (let next-sample ((idx 0))
      (when (< idx samples)
        (genann-train ann (vector-ref input idx) (vector-ref class idx) 0.01)
        (next-sample (add1 idx))))
    (next-pass (add1 pass))))

;; Short alias for reading one element of an f64vector.
(define ~ f64vector-ref)

;; Return the index (0, 1 or 2) of the first slot of the one-hot vector
;; TARGET that equals 1.0, or #f when none of the three slots does.
(define (one-hot-index target)
  (let scan ((k 0))
    (cond ((= k 3) #f)
          ((= 1.0 (~ target k)) k)
          (else (scan (add1 k))))))

;; Return #t when slot K of GUESS is strictly greater than both of the other
;; two slots, i.e. the network unambiguously picked class K.  A tie does not
;; count as a win.
(define (strict-winner? guess k)
  (let scan ((other 0))
    (cond ((= other 3) #t)
          ((= other k) (scan (add1 other)))
          ((> (~ guess k) (~ guess other)) (scan (add1 other)))
          (else #f))))

;; Run every sample through the trained network, count how many times the
;; strongest output matches the sample's class, and print the tally with the
;; percentage rounded to a whole number.
(let loop ((j 0)
           (correct 0))
  (if (= j samples)
      (printf "~a/~a correct (~a%).~n" correct samples
              ;; Guard the division so an empty dataset reports 0% instead
              ;; of raising a division-by-zero error.
              (if (zero? samples)
                  0
                  (round (* 100 (/ correct samples 1.0)))))
      (let ((guess (genann-run ann (vector-ref input j)))
            (expected (one-hot-index (vector-ref class j))))
        ;; A target vector with no 1.0 entry means the loaded data is broken.
        (unless expected
          (printf "Logic error.~n")
          (exit 1))
        (loop (add1 j)
              (if (strict-winner? guess expected)
                  (add1 correct)
                  correct)))))