{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}

module IR.Constraint.UnionFind (
  Point,
  fresh,
  union,
  equivalent,
  redundant,
  get,
  set,
  modify,
) where

{- This is based on the following implementations:

  - https://hackage.haskell.org/package/union-find-0.2/docs/src/Data-UnionFind-IO.html
  - http://yann.regis-gianas.org/public/mini/code_UnionFind.html

It seems like the OCaml one came first, but I am not sure.

Compared to the Haskell implementation, the major changes here include:

  1. No more reallocating PointInfo when changing the weight
  2. Using the strict modifyIORef

-}

import Control.Monad (when)
import Control.Monad.Trans (
  MonadIO,
  liftIO,
 )
import Data.IORef (
  IORef,
  modifyIORef',
  newIORef,
  readIORef,
  writeIORef,
 )
import Data.Word (Word32)


-- POINT

-- | A handle to one element of the union-find structure.
--
-- A 'Point' wraps a single mutable cell holding a 'PointInfo'. The derived
-- 'Eq' instance delegates to the 'Eq' instance of 'IORef', so two points are
-- equal exactly when they wrap the same cell; descriptors are never compared.
newtype Point a
  = Pt (IORef (PointInfo a))
  deriving (Eq)


-- | What the mutable cell behind a 'Point' holds.
--
--   * 'Info': a 'Word32' weight cell and a descriptor cell. 'fresh' creates
--     points in this state with weight 1.
--   * 'Link': a pointer to another 'Point'. 'union' overwrites the cell of
--     the lighter representative with this constructor.
data PointInfo a
  = Info {-# UNPACK #-} !(IORef Word32) {-# UNPACK #-} !(IORef a)
  | Link {-# UNPACK #-} !(Point a)


-- HELPERS

-- | Allocate a new point that forms its own equivalence class.
--
-- The new point starts as an 'Info' cell with weight 1 and the given value
-- as its descriptor.
fresh :: MonadIO m => a -> m (Point a)
fresh value = do
  weight <- liftIO $ newIORef 1
  desc <- liftIO $ newIORef value
  link <- liftIO $ newIORef (Info weight desc)
  return (Pt link)


-- | Find the representative of the class a point belongs to.
--
-- Follows 'Link' cells until an 'Info' cell is reached and returns the point
-- owning that cell. While unwinding, any visited point whose parent turned
-- out not to be the representative has its cell overwritten with the
-- parent's (already updated) cell contents, shortening later lookups.
repr :: MonadIO m => Point a -> m (Point a)
repr point@(Pt ref) = do
  pInfo <- liftIO $ readIORef ref
  case pInfo of
    Info _ _ -> return point
    Link point1@(Pt ref1) -> do
      point2 <- repr point1
      -- Only rewrite when the parent is not itself the representative;
      -- otherwise this cell already points straight at it.
      when (point2 /= point1) $ do
        -- Reuse the parent's cell contents instead of allocating a new Link.
        pInfo1 <- liftIO $ readIORef ref1
        liftIO $ writeIORef ref pInfo1
      return point2


-- | Read the descriptor of the class a point belongs to.
--
-- Two shortcuts avoid calling 'repr': the point is itself an 'Info' cell, or
-- its direct parent is. Only for longer chains is the representative looked
-- up (compressing the path) before reading.
get :: MonadIO m => Point a -> m a
get point@(Pt ref) = do
  pInfo <- liftIO $ readIORef ref
  case pInfo of
    Info _ descRef -> liftIO $ readIORef descRef
    Link (Pt ref1) -> do
      link' <- liftIO $ readIORef ref1
      case link' of
        Info _ descRef -> liftIO $ readIORef descRef
        Link _ -> get =<< repr point


-- | Overwrite the descriptor of the class a point belongs to.
--
-- Uses the same shortcuts as 'get': the descriptor cell is written directly
-- when the point or its direct parent is an 'Info' cell; otherwise the
-- representative is found via 'repr' and the write is retried on it.
set :: MonadIO m => Point a -> a -> m ()
set point@(Pt ref) newDesc = do
  pInfo <- liftIO $ readIORef ref
  case pInfo of
    Info _ descRef -> liftIO $ writeIORef descRef newDesc
    Link (Pt ref1) -> do
      link' <- liftIO $ readIORef ref1
      case link' of
        Info _ descRef -> liftIO $ writeIORef descRef newDesc
        Link _ -> do
          newPoint <- repr point
          set newPoint newDesc


-- | Apply a function to the descriptor of the class a point belongs to.
--
-- The update goes through 'modifyIORef'', so the new descriptor is forced to
-- weak head normal form before it is stored. The descriptor cell is updated
-- directly when the point or its direct parent is an 'Info' cell; otherwise
-- the representative is found via 'repr' and the update is retried on it.
modify :: MonadIO m => Point a -> (a -> a) -> m ()
modify point@(Pt ref) func = do
  pInfo <- liftIO $ readIORef ref
  case pInfo of
    Info _ descRef -> liftIO $ modifyIORef' descRef func
    Link (Pt ref1) -> do
      link' <- liftIO $ readIORef ref1
      case link' of
        Info _ descRef -> liftIO $ modifyIORef' descRef func
        Link _ -> do
          newPoint <- repr point
          modify newPoint func


-- | Merge the classes of two points and give the result a new descriptor.
--
-- Both representatives are looked up first. If they are the same point, only
-- the descriptor is overwritten. Otherwise the representative with the
-- smaller weight is turned into a 'Link' to the other one (ties keep the
-- representative of the first argument as root), and the surviving root
-- receives the summed weight and the new descriptor.
--
-- The @Info … <-@ bindings are failable patterns: should a representative's
-- cell ever hold a 'Link', the action ends via 'fail' of the 'MonadFail'
-- instance rather than continuing.
union :: (MonadIO m, MonadFail m) => Point a -> Point a -> a -> m ()
union p1 p2 newDesc = do
  point1@(Pt ref1) <- repr p1
  point2@(Pt ref2) <- repr p2

  Info w1 d1 <- liftIO $ readIORef ref1
  Info w2 d2 <- liftIO $ readIORef ref2

  if point1 == point2
    then liftIO $ writeIORef d1 newDesc
    else do
      weight1 <- liftIO $ readIORef w1
      weight2 <- liftIO $ readIORef w2

      -- Forced eagerly so the weight cell never stores an unevaluated sum.
      let !newWeight = weight1 + weight2

      if weight1 >= weight2
        then do
          liftIO $ writeIORef ref2 (Link point1)
          liftIO $ writeIORef w1 newWeight
          liftIO $ writeIORef d1 newDesc
        else do
          liftIO $ writeIORef ref1 (Link point2)
          liftIO $ writeIORef w2 newWeight
          liftIO $ writeIORef d2 newDesc


-- | Test whether two points belong to the same class.
--
-- Looks up both representatives with 'repr' (which may compress paths as a
-- side effect) and compares them for cell identity.
equivalent :: MonadIO m => Point a -> Point a -> m Bool
equivalent p1 p2 = do
  v1 <- repr p1
  v2 <- repr p2
  return (v1 == v2)


-- | Test whether a point has been linked to another point.
--
-- Inspects only the point's own cell, without following links: the result is
-- 'True' when the cell holds a 'Link' and 'False' when it holds an 'Info'.
redundant :: MonadIO m => Point a -> m Bool
redundant (Pt ref) = do
  pInfo <- liftIO $ readIORef ref
  case pInfo of
    Info _ _ -> return False
    Link _ -> return True