Ticket #1216: Index.hs

File Index.hs, 1.9 KB (added by claus, 7 years ago)
Line 
1{-# OPTIONS_GHC -fbang-patterns #-}
2
3import Data.Array.Base
4import GHC.Arr(unsafeIndex,index)
5import Data.Array.IArray
6import Control.Monad.ST
7import Data.Array.ST
8import System.Environment(getArgs)
9
10type Elem = Double
11type Vector = [Elem]
12type Matrix = [Vector]
13
14n :: Num a => a
15n = 40
16
17a :: Matrix
18a = [[if i==j then 1 else 0|i<-[1..n]]|j<-[1..n]]
19
20p :: Vector 
21p = [1..n]
22
23------------------------ array-based, update-in-place code
24
25type VectorA s = STUArray s Int Elem
26type MatrixA s = STUArray s (Int,Int) Elem
27
28{-# INLINE myreadArray #-}
29-- | Read an element from a mutable array
30myreadArray :: (MArray a e m, Ix i) => a i e -> i -> m e
31myreadArray marr i = do
32  (l,u) <- getBounds marr
33  unsafeRead marr (myindex (l,u) i)
34
35{-# INLINE mywriteArray #-}
36-- | Write an element in a mutable array
37mywriteArray :: (MArray a e m, Ix i) => a i e -> i -> e -> m ()
38mywriteArray marr i e = do
39  (l,u) <- getBounds marr
40  unsafeWrite marr (myindex (l,u) i) e
41
42myindex b i = index b i
43-- the following is supposed to be the default implementation of index,
44-- from GHC.Arr
45myindex b i | inRange b i = unsafeIndex b i     
46            | otherwise   = error "Error in array index"
47
48matA :: MatrixA s -> VectorA s -> VectorA s -> ST s (VectorA s)
49(`matA` v) tmp = m `seq` v `seq` tmp `seq` l 1 1 0 
50  where l !i !j !s | i>n = return tmp
51        l  i  j  s | j>n = mywriteArray tmp i s >> l (i+1) 1 0
52        l  i  j  s       = do a<-myreadArray m (i,j)
53                              b<-myreadArray v j
54                              l i (j+1) (s+a*b)
55
56loopA a p q n | n==0 = return q
57loopA a p q n        = do 
58  (a `matA` p) q
59  loopA a p q (n-1)
60
61testA c = runSTUArray (do 
62  aA <- newListArray ((1,1),(n,n)) (concat a)
63  pA <- newListArray (1,n) p
64  qA <- newArray (1,n) 0
65  loopA aA pA qA c
66  )
67
68-----------------------
69main = do
70  (count:_) <- getArgs
71  print $ testA  (read count)