Friday 10 February 2012

ID3 Code With Missing Value Attributes Handling

The ID3 algorithm (Quinlan) uses top-down induction of decision trees. Given a set of classified examples, it induces a decision tree, guided by the information-gain measure, which heuristically leads to small trees. The examples are given in attribute-value representation, and the set of possible classes is finite. Only tests that split the set of instances according to the value of a single attribute are supported.

Follow these steps:
  1. Create a console project in Visual Studio (2010).
  2. Create a file named data.txt in your project's output folder (bin\Debug), where the program looks for it when run.
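  3. Put the following lines into data.txt. The first line declares each attribute with its possible values, separated by ':' (the last column is the class). Every following line is one training example; leave a field empty to mark a missing value. Every value used in the examples must appear in the header (e.g. Cool, not Cold), otherwise those rows are ignored when the tree is built.
    OutLook(Sunny:Overcast:Rain),Temperature(Hot:Mild:Cool),Humidity(High:Normal),Windy(True:False),Decision(true:false)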
    Sunny,Hot,High,False,false
    Sunny,Hot,High,True,false
    Overcast,Hot,High,False,true
    Rain,Mild,High,False,true
    Rain,Cool,Normal,False,true
    Rain,Cool,Normal,True,false
    Overcast,Cool,Normal,True,true
    Sunny,Mild,High,False,false
    Sunny,Cool,Normal,False,true
    Rain,Mild,Normal,False,true
    Sunny,Mild,Normal,False,true
    Overcast,Mild,High,True,true
    Overcast,Hot,Normal,False,true
  4. Now open Program.cs in your project, delete all the existing code, and paste in the following code.
  5. /********************************************************************************************
     *                       Id3 Implementation
     *                       Author: M.I.A & Noshina Tariq
     *                     
    */

    using System;
    using System.Collections;
    using System.Data;
    using System.IO;

    namespace ExemploID3
    {
        // this is a class which will used to hold all the attributes and their possible values.
        public class Attribute
        {
            ArrayList mValues;
            string mName;
            object mLabel;

            // constructor
            public Attribute(string name, string[] values)
            {
                mName = name;
                mValues = new ArrayList(values);
                mValues.Sort();
            }

            public Attribute(object Label)
            {
                mLabel = Label;
                mName = string.Empty;
                mValues = null;
            }

            // getter, returns the name of attribute.
            public string AttributeName
            {
                get
                {
                    return mName;
                }
            }
            // returns the string array of attribute values
            public string[] values
            {
                get
                {
                    if (mValues != null)
                        return (string[])mValues.ToArray(typeof(string));
                    else
                        return null;
                }
            }

            // this is will validate the value used in instance of an attribute
            public bool isValidValue(string value)
            {
                return indexValue(value) >= 0;
            }

            // actuall validating function
            public int indexValue(string value)
            {
                if (mValues != null)
                    return mValues.BinarySearch(value);
                else
                    return -1;
            }

            public override string ToString()
            {
                if (mName != string.Empty)
                {
                    return mName;
                }
                else
                {
                    return mLabel.ToString();
                }
            }
        }

        // to hold the node  and its childrens within a tree
        public class TreeNode
        {
            private ArrayList mChilds = null;
            private Attribute mAttribute;

            // constructor
            public TreeNode(Attribute attribute)
            {
                if (attribute.values != null)
                {
                    mChilds = new ArrayList(attribute.values.Length);
                    for (int i = 0; i < attribute.values.Length; i++)
                        mChilds.Add(null);
                }
                else
                {
                    mChilds = new ArrayList(1);
                    mChilds.Add(null);
                }
                mAttribute = attribute;
            }


            public void AddTreeNode(TreeNode treeNode, string ValueName)
            {
                int index = mAttribute.indexValue(ValueName);
                mChilds[index] = treeNode;
            }

            public int totalChilds
            {
                get
                {
                    return mChilds.Count;
                }
            }

            public TreeNode getChild(int index)
            {
                return (TreeNode)mChilds[index];
            }

            public Attribute attribute
            {
                get
                {
                    return mAttribute;
                }
            }

            public TreeNode getChildByBranchName(string branchName)
            {
                int index = mAttribute.indexValue(branchName);
                return (TreeNode)mChilds[index];
            }
        }

        // actual Id3 Implemetion
        public class DecisionTreeID3
        {
            private DataTable mSamples;
            private int mTotalPositives = 0;
            private int mTotal = 0;
            private string mTargetAttribute = "result";
            private double mEntropySet = 0.0;

            // calculates total positive instance in the data
            private int countTotalPositives(DataTable samples)
            {
                int result = 0;

                foreach (DataRow aRow in samples.Rows)
                {
                    if (aRow[mTargetAttribute].ToString() == "true")
                        result++;
                }

                return result;
            }

            // calculates entropy
            private double calcEntropy(int positives, int negatives)
            {
                int total = positives + negatives;
                double ratioPositive = (double)positives / total;
                double ratioNegative = (double)negatives / total;

                if (ratioPositive != 0)
                    ratioPositive = -(ratioPositive) * System.Math.Log(ratioPositive, 2);
                if (ratioNegative != 0)
                    ratioNegative = -(ratioNegative) * System.Math.Log(ratioNegative, 2);

                double result = ratioPositive + ratioNegative;

                return result;
            }

            // this will calculates positive and negitive instance after root is selected.
            private void getValuesToAttribute(DataTable samples, Attribute attribute, string value, out int positives, out int negatives)
            {
                positives = 0;
                negatives = 0;

                foreach (DataRow aRow in samples.Rows)
                {
                    if (((string)aRow[attribute.AttributeName] == value))
                        if (aRow[mTargetAttribute].ToString() == "true")
                            positives++;
                        else
                            negatives++;
                }
            }

            // calculates gain
            private double gain(DataTable samples, Attribute attribute)
            {
                string[] values = attribute.values;
                double sum = 0.0;

                for (int i = 0; i < values.Length; i++)
                {
                    int positives, negatives;

                    positives = negatives = 0;

                    getValuesToAttribute(samples, attribute, values[i], out positives, out negatives);

                    double entropy = calcEntropy(positives, negatives);
                    sum += -(double)(positives + negatives) / mTotal * entropy;
                }
                return mEntropySet + sum;
            }

            // calculates gain for all available attributs and then return attribut having maximum info.gain
            private Attribute getBestAttribute(DataTable samples, Attribute[] attributes)
            {
                double maxGain = 0.0;
                Attribute result = null;

                foreach (Attribute attribute in attributes)
                {
                    double aux = gain(samples, attribute);
                    if (aux > maxGain)
                    {
                        maxGain = aux;
                        result = attribute;
                    }
                }
                return result;
            }

            // checks that data has all the positive samples.
            private bool allSamplesPositives(DataTable samples, string targetAttribute)
            {
                foreach (DataRow row in samples.Rows)
                {
                    String x=row[targetAttribute].ToString();
                    if (x == "false")
                        return false;
                }

                return true;
            }

            // checks that data has all the negitive samples.
            private bool allSamplesNegatives(DataTable samples, string targetAttribute)
            {
                foreach (DataRow row in samples.Rows)
                {
                    if (row[targetAttribute].ToString() == "true")
                        return false;
                }

                return true;
            }

            // used to get distinct values of attributs in the data.
            private ArrayList getDistinctValues(DataTable samples, string targetAttribute)
            {
                ArrayList distinctValues = new ArrayList(samples.Rows.Count);

                foreach (DataRow row in samples.Rows)
                {
                    if (distinctValues.IndexOf(row[targetAttribute]) == -1)
                        distinctValues.Add(row[targetAttribute]);
                }

                return distinctValues;
            }

            // returns most common value of attributs from given instance
            private object getMostCommonValue(DataTable samples, string targetAttribute)
            {
                ArrayList distinctValues = getDistinctValues(samples, targetAttribute);
                int[] count = new int[distinctValues.Count];

                foreach (DataRow row in samples.Rows)
                {
                    int index = distinctValues.IndexOf(row[targetAttribute]);
                    count[index]++;
                }

                int MaxIndex = 0;
                int MaxCount = 0;

                for (int i = 0; i < count.Length; i++)
                {
                    if (count[i] > MaxCount)
                    {
                        MaxCount = count[i];
                        MaxIndex = i;
                    }
                }

                return distinctValues[MaxIndex];
            }

            // creates tree after calculations
            private TreeNode internalMountTree(DataTable samples, string targetAttribute, Attribute[] attributes)
            {
                if (allSamplesPositives(samples, targetAttribute) == true)
                    return new TreeNode(new Attribute(true));

                if (allSamplesNegatives(samples, targetAttribute) == true)
                    return new TreeNode(new Attribute(false));

                if (attributes.Length == 0)
                    return new TreeNode(new Attribute(getMostCommonValue(samples, targetAttribute)));

                mTotal = samples.Rows.Count;
                mTargetAttribute = targetAttribute;
                mTotalPositives = countTotalPositives(samples);

                mEntropySet = calcEntropy(mTotalPositives, mTotal - mTotalPositives);

                Attribute bestAttribute = getBestAttribute(samples, attributes);

                TreeNode root = new TreeNode(bestAttribute);

                DataTable aSample = samples.Clone();

                foreach (string value in bestAttribute.values)
                {
                    aSample.Rows.Clear();

                    DataRow[] rows = samples.Select(bestAttribute.AttributeName + " = " + "'" + value + "'");

                    foreach (DataRow row in rows)
                    {
                        aSample.Rows.Add(row.ItemArray);
                    }
                  
                    ArrayList aAttributes = new ArrayList(attributes.Length - 1);
                    for (int i = 0; i < attributes.Length; i++)
                    {
                        if (attributes[i].AttributeName != bestAttribute.AttributeName)
                            aAttributes.Add(attributes[i]);
                    }
                  
                    if (aSample.Rows.Count == 0)
                    {
                        return new TreeNode(new Attribute(getMostCommonValue(aSample, targetAttribute)));
                    }
                    else
                    {
                        DecisionTreeID3 dc3 = new DecisionTreeID3();
                        TreeNode ChildNode = dc3.mountTree(aSample, targetAttribute, (Attribute[])aAttributes.ToArray(typeof(Attribute)));
                        root.AddTreeNode(ChildNode, value);
                    }
                }

                return root;
            }

            public TreeNode mountTree(DataTable samples, string targetAttribute, Attribute[] attributes)
            {
                mSamples = samples;
                return internalMountTree(mSamples, targetAttribute, attributes);
            }
        }

        class ID3Sample
        {
            public static DataTable dt;

            //prints tree on screen
            public static void printNode(TreeNode root, string tabs)
            {
                Console.WriteLine(tabs + '|' + root.attribute + '|');

                if (root.attribute.values != null)
                {
                    for (int i = 0; i < root.attribute.values.Length; i++)
                    {
                        Console.WriteLine(tabs + "\t" + "<" + root.attribute.values[i] + ">");
                        TreeNode childNode = root.getChildByBranchName(root.attribute.values[i]);
                        printNode(childNode, "\t" + tabs);
                    }
                }
            }

            // handels missing values handel=1 for method one and handel=2 for second method
            static void missingValue(int handel)
            {
                //Boolean isMissing = false;
                //int row = -1, col = -1;

                for (int i = 0; i < dt.Rows.Count; i++)
                {
                    for (int j = 0; j < dt.Columns.Count; j++)
                    {
                        if (dt.Rows[i][j].ToString()=="")
                        {
                            if (handel==1)
                            {
                                dt.Rows.RemoveAt(i);
                                i--;
                                j = dt.Columns.Count;
                            }
                            else if (handel == 2)
                            {
                                ArrayList attV = new ArrayList();
                                ArrayList rep = new ArrayList();
                                int max_ind=-1;

                                for (int k = 0; k < dt.Rows.Count; k++)
                                {
                                    if (!attV.Contains(dt.Rows[k][j]) && dt.Rows[k][j] != "")
                                    {
                                        attV.Add(dt.Rows[k][j]);
                                        rep.Add(1);

                                        if (max_ind == -1)
                                        {
                                            max_ind = 0;
                                        }

                                    }
                                    else if(attV.Contains(dt.Rows[k][j]))
                                    {
                                        int c_ind=attV.IndexOf(dt.Rows[k][j]);

                                        rep[c_ind] = ((int)rep[c_ind]) + 1;

                                        if ((int)rep[max_ind] < (int)rep[c_ind])
                                            max_ind = c_ind;
                                    }
                                }
                                if (max_ind != -1)
                                {
                                   dt.Rows[i][j] = attV[max_ind].ToString();
                                }
                            }
                        }
                    }
                }
            }

            [STAThread]
            static void Main(string[] args)
            {

                StreamReader sr = new StreamReader("data.txt");
                String line;
                dt = new DataTable("Data");
                int lineNum = 1;
                char[] sep1 = { ',' }, sep2 = { ':' };
                //DataTable dt = new DataTable("data");
              
                Attribute[] attributes=null;
              
                String ClassLabel="";

                while (!sr.EndOfStream)
                {
                    line = sr.ReadLine();
                    String[] parts = line.Split(sep1);

                    if (lineNum == 1)
                    {

                        int at_num = parts.Length;
                        attributes = new Attribute[at_num - 1];

                        for (int x = 0; x < at_num; x++)
                        {
                            String atr_name = parts[x].Substring(0, parts[x].IndexOf('('));
                            String[] atr_possi = parts[x].Substring(parts[x].IndexOf('(') + 1, parts[x].Length - parts[x].IndexOf('(') - 2).Split(sep2);

                            DataColumn column = dt.Columns.Add(atr_name);
                            column.DataType = typeof(string);

                            if (x + 1 == at_num)
                                ClassLabel = atr_name;
                            else
                                attributes[x] = new Attribute(atr_name, atr_possi);


                        }

                    }
                    else
                    {
                        dt.Rows.Add(parts);
                    }

                    lineNum++;
                }
              
                missingValue(2);

                DecisionTreeID3 id3 = new DecisionTreeID3();
                TreeNode root = id3.mountTree(dt, ClassLabel, attributes);

                printNode(root, "  ");

              
            }
        }
    }
  6. That's it: build and run the project, and the induced tree is printed to the console. :-)

No comments:

Post a Comment