import { roundNumber } from '../../shared/utils/number-utils';
import { isCoinsEquals } from '../currency/currency-service';
import { CoinsAmount } from '../currency/currency-types';
import { getConnectedPools } from './amm.service';
import { AmmParams, Pool } from './types';

/**
 * Estimates how much of the final output asset a swap of `tokenIn` yields
 * along the route returned by `getConnectedPools`.
 *
 * The taker fee is taken off the input once, before routing. Each pool's own
 * swap fee is applied per hop inside `calcOutAmtGivenIn`.
 *
 * @param pools - candidate pools for the route
 * @param params - AMM parameters; `takerFee` is read here, and the whole object is passed to `getConnectedPools`
 * @param tokenIn - coin and amount being sold
 * @param tokenOut - coin being bought; only forwarded to `getConnectedPools`
 * @returns the estimated amount of the last hop's output asset (the fee-adjusted input amount if the route is empty)
 */
export const estimateSwapAmountIn = (
    pools: Pool[],
    params: AmmParams,
    tokenIn: CoinsAmount,
    tokenOut: CoinsAmount,
): number => {
    const feeAdjustedIn: CoinsAmount = { ...tokenIn, amount: tokenIn.amount * (1 - params.takerFee) };
    const route = getConnectedPools(pools, params, feeAdjustedIn, tokenOut);

    // Walk the route forward: each hop's output becomes the next hop's input.
    let current = feeAdjustedIn;
    route.forEach((pool) => {
        // Orient the pool so the asset we currently hold is on the "in" side.
        const oriented = isCoinsEquals(pool.assets[0], current) ? pool.assets : [ ...pool.assets ].reverse();
        const [ poolIn, poolOut ] = oriented;
        current = { ...poolOut, amount: calcOutAmtGivenIn(current, poolIn, poolOut, pool.swapFee) };
    });
    return current.amount;
};

/**
 * Estimates how much `tokenIn` must be supplied to receive `tokenOut.amount`
 * of `tokenOut`. It walks the route returned by `getConnectedPools` backwards,
 * from the output pool to the input pool.
 *
 * Each hop is solved with `calcInAmtGivenOut`, which may cap the requested
 * output below the pool's reserve. If any hop does so, `isAmountOutFixed` is
 * `true` and the estimate is for less output than was requested at that hop.
 * The taker fee is grossed up once on the final input amount.
 *
 * @param pools - candidate pools for the route
 * @param params - AMM parameters; `takerFee` is read here, and the whole object is passed to `getConnectedPools`
 * @param tokenIn - coin being sold; only forwarded to `getConnectedPools`
 * @param tokenOut - coin and amount the caller wants to receive
 * @returns `amountIn`: estimated input including the taker fee; `isAmountOutFixed`: whether any hop capped its output
 */
export const estimateSwapAmountOut = (
    pools: Pool[],
    params: AmmParams,
    tokenIn: CoinsAmount,
    tokenOut: CoinsAmount,
): { amountIn: number, isAmountOutFixed: boolean } => {
    // Copy before reversing: Array.prototype.reverse works in place, and the
    // array returned by getConnectedPools may be shared with the caller
    // (NOTE(review): e.g. `pools` itself or a cached route).
    const route = [ ...getConnectedPools(pools, params, tokenIn, tokenOut) ].reverse();
    let isAmountOutFixed = false;
    let required = tokenOut;
    for (const pool of route) {
        // Orient the pool so the asset we still need is on the "out" side.
        const [ assetIn, assetOut ] = isCoinsEquals(pool.assets[1], required) ? pool.assets : [ ...pool.assets ].reverse();
        const { tokenInAmount, fixedTokenOutAmount } = calcInAmtGivenOut(required, assetIn, assetOut, pool.swapFee);
        // A different "fixed" amount means this hop could not deliver the full request.
        isAmountOutFixed = isAmountOutFixed || required.amount !== fixedTokenOutAmount;
        required = { ...assetIn, amount: tokenInAmount };
    }
    return { amountIn: required.amount / (1 - params.takerFee), isAmountOutFixed };
};

/**
 * Estimates the pool shares minted when joining with `tokensIn` without any
 * internal swap.
 *
 * The result is limited by the input that is smallest relative to its pool
 * reserve: `floor(min_i(tokensIn[i].amount / reserve_i) * totalShares)`.
 *
 * Returns 0 when `tokensIn` is empty, or when any input token has no (or a
 * zero) reserve in the pool. The ratio is undefined in those cases; this
 * matches `calcJoinPoolShares`, which also returns 0 for a token not in the pool.
 *
 * @param pool - the pool being joined
 * @param tokensIn - coins and amounts offered
 * @returns whole number of shares (floored)
 */
export const calcJoinPoolNoSwapShares = (pool: Pool, tokensIn: CoinsAmount[]): number => {
    // Without this guard the MAX_VALUE seed below would leak into the result.
    if (!tokensIn.length) {
        return 0;
    }
    let minShareRatio = Number.MAX_VALUE;
    for (const token of tokensIn) {
        const poolAmount = pool.assets.find((asset) => isCoinsEquals(asset, token))?.amount || 0;
        // A missing/empty reserve would give an Infinity or NaN ratio.
        if (!poolAmount) {
            return 0;
        }
        minShareRatio = Math.min(minShareRatio, token.amount / poolAmount);
    }
    return Math.floor(minShareRatio * Number(pool.totalShares));
};

/**
 * Estimates the pool shares minted by a single-asset join with `tokenIn`.
 *
 * The pool's swap fee is taken off the deposit, and half of what remains is
 * treated as the added liquidity. Through the invariant helper this works out
 * to `totalShares * effectiveDeposit / reserve`, floored and clamped at 0.
 *
 * NOTE(review): the result is linear in the deposit. Confirm it matches the
 * on-chain single-asset join formula.
 *
 * @param pool - the pool being joined
 * @param tokenIn - the single coin and amount deposited
 * @returns whole number of shares; 0 if the pool holds no reserve of `tokenIn`
 */
export const calcJoinPoolShares = (pool: Pool, tokenIn: CoinsAmount): number => {
    const matchingAsset = pool.assets.find((asset) => isCoinsEquals(asset, tokenIn));
    const reserve = matchingAsset?.amount || 0;
    if (reserve === 0) {
        return 0;
    }

    const effectiveDeposit = tokenIn.amount * (1 - pool.swapFee) / 2;
    const reserveAfterDeposit = reserve + effectiveDeposit;
    // The invariant reports the change with the opposite sign for a join, hence the negation.
    const sharesOut = -solveConstantFunctionInvariant(reserveAfterDeposit, reserve, Number(pool.totalShares));
    return Math.floor(Math.max(0, sharesOut));
};

/**
 * Output of a single equal-weight constant-product hop: how much of
 * `poolAssetOut` the pool pays for `tokenIn`.
 *
 * `swapFee` is deducted from the input first, as `amount * (1 - swapFee)`.
 * The net input is added to the in-reserve, and the invariant gives
 * `outReserve * netIn / (inReserve + netIn)`.
 *
 * @param tokenIn - amount offered (in `poolAssetIn` units)
 * @param poolAssetIn - pool reserve of the offered asset
 * @param poolAssetOut - pool reserve of the received asset; its decimals drive rounding
 * @param swapFee - pool fee as a fraction of the input
 * @returns output amount passed through `roundNumber(…, decimals, true)`
 *   (presumably rounding down; confirm against `roundNumber`). Returns 0 when
 *   the invariant yields a negative or non-finite value, e.g. for a negative
 *   input or empty reserves.
 */
const calcOutAmtGivenIn = (
    tokenIn: CoinsAmount,
    poolAssetIn: CoinsAmount,
    poolAssetOut: CoinsAmount,
    swapFee: number,
): number => {
    const tokenAmountInAfterFee = tokenIn.amount * (1 - swapFee);
    const poolPostSwapInBalance = poolAssetIn.amount + tokenAmountInAfterFee;
    const tokenAmountOut = solveConstantFunctionInvariant(poolAssetIn.amount, poolPostSwapInBalance, poolAssetOut.amount);
    // No valid output exists here. Report 0 rather than a value denominated in
    // another asset (e.g. the input reserve) or a NaN/Infinity from 0/0 or x/0.
    if (!Number.isFinite(tokenAmountOut) || tokenAmountOut < 0) {
        return 0;
    }
    return roundNumber(tokenAmountOut, poolAssetOut.currency.decimals, true);
};

/**
 * Input required by a single equal-weight constant-product hop to pay out
 * `tokenOut` from `poolAssetOut`.
 *
 * The requested output is capped just below the pool's out-reserve. The
 * capped value is returned as `fixedTokenOutAmount` so callers can detect
 * when the request was reduced. The net input from the invariant is then
 * grossed up by `1 / (1 - swapFee)`.
 *
 * @param tokenOut - amount wanted (in `poolAssetOut` units)
 * @param poolAssetIn - pool reserve of the asset to be paid in; its decimals drive rounding
 * @param poolAssetOut - pool reserve of the asset to be received
 * @param swapFee - pool fee as a fraction of the input
 * @returns `tokenInAmount` rounded via `roundNumber(…, decimals, undefined, true)`
 *   (presumably rounding up; confirm), and the possibly capped `fixedTokenOutAmount`
 */
const calcInAmtGivenOut = (
    tokenOut: CoinsAmount,
    poolAssetIn: CoinsAmount,
    poolAssetOut: CoinsAmount,
    swapFee: number,
): { tokenInAmount: number, fixedTokenOutAmount: number } => {
    // Keep the post-swap out-reserve above zero. Draining it fully would make
    // the invariant divide by zero.
    const maxWithdrawable = 0.9999999999 * poolAssetOut.amount;
    const fixedTokenOutAmount = Math.min(tokenOut.amount, maxWithdrawable);

    const outReserveAfter = poolAssetOut.amount - fixedTokenOutAmount;
    const netAmountIn = -solveConstantFunctionInvariant(poolAssetOut.amount, outReserveAfter, poolAssetIn.amount);
    const grossAmountIn = netAmountIn / (1 - swapFee);
    const tokenInAmount = roundNumber(grossAmountIn, poolAssetIn.currency.decimals, undefined, true);
    return { tokenInAmount, fixedTokenOutAmount };
};

/**
 * Equal-weight constant-product invariant solver.
 *
 * Given one reserve's balance before and after a change, returns how much
 * the other reserve (`tokenBalanceUnknownBefore`) must decrease to keep the
 * product of the two balances constant:
 * `unknownBefore * (1 - fixedBefore / fixedAfter)`.
 *
 * A negative result means the other reserve must increase by that magnitude.
 * There is no guard: a `tokenBalanceFixedAfter` of 0 yields ±Infinity or NaN.
 */
const solveConstantFunctionInvariant = (
    tokenBalanceFixedBefore: number,
    tokenBalanceFixedAfter: number,
    tokenBalanceUnknownBefore: number,
): number => {
    const fixedBalanceRatio = tokenBalanceFixedBefore / tokenBalanceFixedAfter;
    return tokenBalanceUnknownBefore * (1 - fixedBalanceRatio);
};
