// Lighting / shadowing parameters read by PShaderEntry.
// Position of the light, in the same space as VS_OUTPUT.WorldPosition;
// PShaderEntry builds the light-to-pixel vector as normalize(WorldPosition - LightPos).
uniform float3 LightPos;
// One direction per shadow map (paired with Light<N>Map / Light<N>Sampler).
// A map can only contribute where dot(light-to-pixel vector, LightDir<N>) < -0.5
// (see ComputeLightFrustrumValue).
uniform float3 LightDir0;
uniform float3 LightDir1;
uniform float3 LightDir2;
uniform float3 LightDir3;
uniform float3 LightDir4;
uniform float3 LightDir5;
// Lower bound applied to the N.L diffuse term in PShaderEntry.
uniform float LightTermMinValue;

// Interpolated vertex-shader outputs consumed by PShaderEntry.
// NOTE(review): nine TEXCOORD interpolators are used, more than the eight
// texture-coordinate registers of ps_2_x, so this presumably targets ps_3_0 - confirm.
struct VS_OUTPUT
{
    float4 Position         : POSITION;   // vertex position output (not read by PShaderEntry)
    float4 TexCoord         : TEXCOORD0;  // base texture coordinates; only .xy is used by the 2D lookup
    float3 Normal           : TEXCOORD1;  // surface normal (renormalized in PShaderEntry)
    float3 WorldPosition    : TEXCOORD2;  // vertex world position, compared against LightPos
    float3 Light0Map        : TEXCOORD3;  // shadow map 0 projection: xy = map texcoord, z = receiver depth
    float3 Light1Map        : TEXCOORD4;  // shadow map 1 projection: xy = map texcoord, z = receiver depth
    float3 Light2Map        : TEXCOORD5;  // shadow map 2 projection: xy = map texcoord, z = receiver depth
    float3 Light3Map        : TEXCOORD6;  // shadow map 3 projection: xy = map texcoord, z = receiver depth
    float3 Light4Map        : TEXCOORD7;  // shadow map 4 projection: xy = map texcoord, z = receiver depth
    float3 Light5Map        : TEXCOORD8;  // shadow map 5 projection: xy = map texcoord, z = receiver depth
};

// Pixel shader result: one colour written to render target 0.
struct PS_OUTPUT
{
    float4 Color    : COLOR0;   // final shaded colour (texture * lighting)
};

// Base colour texture, sampled with Input.TexCoord in PShaderEntry.
// NOTE(review): the original author states the filter states below have no
// effect - presumably because the host binds sampler state directly instead of
// applying it through an effect technique; confirm before relying on them.
sampler TextureSampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Shadow map samplers, one per Light<N>Map interpolator. Each map is read by
// ComputeLightCoverageVariance, which uses .x and .y as the first two depth
// moments. Sampling only happens when the map's blend weight is > 0, i.e. when
// the projected xy lies strictly inside (0, 1).
// NOTE(review): the original author states these filter states have no effect -
// presumably sampler state is set by the host; confirm before relying on them.

// Shadow map 0 (Light0Map / LightDir0).
sampler Light0Sampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Shadow map 1 (Light1Map / LightDir1).
sampler Light1Sampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Shadow map 2 (Light2Map / LightDir2).
sampler Light2Sampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Shadow map 3 (Light3Map / LightDir3).
sampler Light3Sampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Shadow map 4 (Light4Map / LightDir4).
sampler Light4Sampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Shadow map 5 (Light5Map / LightDir5).
sampler Light5Sampler = sampler_state
{ 
    MinFilter = LINEAR; //these do nothing (see note above)
    MagFilter = LINEAR;
    MipFilter = LINEAR;
};

// Classic binary shadow-map test.
// Reads one stored depth from the map's .x channel at TexBase and returns
// 0.0 (shadowed) when LightCompare lies beyond it, 1.0 (lit) otherwise.
// ComputeShadowTerm uses the variance variant instead of this one.
float ComputeLightCoverageClassic(sampler2D Sampler, float2 TexBase, float LightCompare)
{
    float StoredDepth = tex2D(Sampler, TexBase).x;
    bool IsOccluded = (LightCompare - StoredDepth) > 0.0;
    return IsOccluded ? 0.0 : 1.0;
}

// Variance-shadow-map coverage estimate.
// The map is sampled at TexBase and its .x/.y channels are treated as the
// first and second depth moments (mean and mean of squares).
// Returns 1.0 when LightCompare does not exceed the mean; otherwise returns the
// one-tailed Chebyshev upper bound Variance / (Variance + Diff^2), in (0, 1].
float ComputeLightCoverageVariance(sampler2D Sampler, float2 TexBase, float LightCompare)
{
    // Floor for the reconstructed variance. E[x^2] - E[x]^2 suffers from
    // floating-point cancellation and can evaluate to zero or a negative value,
    // which would make the bound negative or divide by zero (inf/NaN).
    // Tuning value - scale with the depth range stored in the shadow maps.
    const float MinVariance = 0.00002;

    float2 Moments = tex2D(Sampler, TexBase).xy;
    float Variance = max(Moments.y - Moments.x * Moments.x, MinVariance);
    float Diff = LightCompare - Moments.x;
    if(Diff > 0.0)
    {
        return Variance / (Variance + Diff * Diff);
    }
    return 1.0;
}

// Shadow contribution of one map, pre-multiplied by its blend weight.
// When Weight <= 0 (pixel not covered by this map) the map is not sampled and
// 0.0 is returned. Otherwise the variance shadow map is sampled at LightMap.xy
// against LightMap.z minus a fixed depth bias of 0.03. The coverage is floored
// at 0.4, so a fully occluded pixel still keeps 40% of the weight.
float ComputeShadowTerm(float Weight, sampler2D Sampler, float3 LightMap)
{
    if(Weight <= 0.0)
        return 0.0;

    float BiasedDepth = LightMap.z - 0.03;
    float Coverage = ComputeLightCoverageVariance(Sampler, LightMap.xy, BiasedDepth);
    return max(Coverage, 0.4) * Weight;
}

// Blend weight of one shadow map for the current pixel.
// Returns 0.0 when the projected depth LightMap.z is at or below 0.02, or when
// DotProduct >= -0.5. Otherwise it returns 0.5 minus the larger of the x/y
// distances of LightMap.xy from the map centre (0.5, 0.5), clamped at 0.
// The weight is 0.5 at the centre and falls linearly to 0 at the map border.
float ComputeLightFrustrumValue(float3 LightMap, float DotProduct)
{
    bool Rejected = (LightMap.z <= 0.02) || (DotProduct >= -0.5);
    if(Rejected)
        return 0.0;

    float2 CenterOffset = abs(LightMap.xy - 0.5);
    float EdgeDistance = 0.5 - max(CenterOffset.x, CenterOffset.y);
    return max(0.0, EdgeDistance);
}

// Pixel shader entry point.
// Lights the base texture with a single light at LightPos. A diffuse term
// (normal dot light-to-pixel vector, floored at LightTermMinValue) is capped by
// a shadow term. The shadow term is blended from six variance shadow maps, each
// paired with LightDir<N>, Light<N>Map and Light<N>Sampler.
PS_OUTPUT PShaderEntry( VS_OUTPUT Input )
{
    PS_OUTPUT Output;

    float3 NormalizedNormal = normalize(Input.Normal);
    // Unit vector from the light towards this pixel.
    // NOTE(review): a conventional Lambert term uses the pixel-to-light vector;
    // with light-to-pixel, LightDot is positive on surfaces facing away from the
    // light unless normals are stored flipped - confirm the intended convention.
    float3 LightVector = normalize(Input.WorldPosition - LightPos);

    float LightDot = dot(NormalizedNormal, LightVector);
    float BaseLightTerm = max(LightTermMinValue, LightDot);

    // Blend weight of each shadow map; zero where the pixel is outside that map.
    float LightWeights[6];
    LightWeights[0] = ComputeLightFrustrumValue(Input.Light0Map, dot(LightVector, LightDir0));
    LightWeights[1] = ComputeLightFrustrumValue(Input.Light1Map, dot(LightVector, LightDir1));
    LightWeights[2] = ComputeLightFrustrumValue(Input.Light2Map, dot(LightVector, LightDir2));
    LightWeights[3] = ComputeLightFrustrumValue(Input.Light3Map, dot(LightVector, LightDir3));
    LightWeights[4] = ComputeLightFrustrumValue(Input.Light4Map, dot(LightVector, LightDir4));
    LightWeights[5] = ComputeLightFrustrumValue(Input.Light5Map, dot(LightVector, LightDir5));

    // Weighted sum of per-map shadow coverage.
    float WeightedShadow = 0.0;
    WeightedShadow += ComputeShadowTerm(LightWeights[0], Light0Sampler, Input.Light0Map);
    WeightedShadow += ComputeShadowTerm(LightWeights[1], Light1Sampler, Input.Light1Map);
    WeightedShadow += ComputeShadowTerm(LightWeights[2], Light2Sampler, Input.Light2Map);
    WeightedShadow += ComputeShadowTerm(LightWeights[3], Light3Sampler, Input.Light3Map);
    WeightedShadow += ComputeShadowTerm(LightWeights[4], Light4Sampler, Input.Light4Map);
    WeightedShadow += ComputeShadowTerm(LightWeights[5], Light5Sampler, Input.Light5Map);

    float TotalWeight = LightWeights[0] + LightWeights[1] + LightWeights[2]
                      + LightWeights[3] + LightWeights[4] + LightWeights[5];

    // Normalise the blend. Every weight is zero when the pixel is covered by no
    // map - e.g. exactly on a map border (weight 0.5 - 0.5 = 0), or when each map
    // rejects it by depth (<= 0.02) or direction. Dividing by that zero would
    // produce inf/NaN colour, so such pixels are treated as unshadowed.
    float ShadowLightTerm = 1.0;
    if(TotalWeight > 0.0)
    {
        ShadowLightTerm = WeightedShadow / TotalWeight;
    }

    // Shadowing can only darken the diffuse term, never brighten it.
    float LightTerm = min(ShadowLightTerm, BaseLightTerm);
    // Only .xy of the interpolated coordinate is needed for a 2D lookup.
    float4 TextureColor = tex2D(TextureSampler, Input.TexCoord.xy);
    // NOTE(review): alpha is scaled by the unshadowed BaseLightTerm while RGB uses
    // the shadowed LightTerm - confirm this is intended.
    float4 LightColor = float4(LightTerm, LightTerm, LightTerm, BaseLightTerm);
    Output.Color = TextureColor * LightColor;
    return Output;
}