본문 바로가기
머신러닝&딥러닝/Tensorflow

(Tensorflow) XOR Neural Network 만들기

by J_Remind 2019. 1. 7.

이번 포스트는 Tensorflow로 XOR 기능을 수행하는 Neural Network를 만들 것입니다.


활성화 함수로는 Sigmoid를 사용하고 손실 함수로는 교차 엔트로피(Cross Entropy) 사용했습니다.


XOR Neural Network


XOR Neural Network 그림

XOR Neural Network 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import tensorflow as tf
import numpy as np
 
learning_rate = 0.1
 
x_data = [[00],
          [01],
          [10],
          [11]]
y_data = [[0],
          [1],
          [1],
          [0]]
x_data = np.array(x_data, dtype=np.float32)
y_data = np.array(y_data, dtype=np.float32)
 
= tf.placeholder(tf.float32, [None, 2])
= tf.placeholder(tf.float32, [None, 1])
 
W1 = tf.Variable(tf.random_normal([22]), name='weight1')
b1 = tf.Variable(tf.random_normal([2]), name='bias1')
layer1 = tf.sigmoid(tf.matmul(X, W1) + b1)
 
 
W2 = tf.Variable(tf.random_normal([21]), name='weight2')
b2 = tf.Variable(tf.random_normal([1]), name='bias2')
hypothesis = tf.sigmoid(tf.matmul(layer1, W2) + b2)
 
# cost/loss function
# 교차엔트로피
cost = -tf.reduce_mean(Y * tf.log(hypothesis) + (1 - Y) *
                       tf.log(1 - hypothesis))
train = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cost)
 
# hypothesis>0.5가 True=1, False=0
predicted = tf.cast(hypothesis > 0.5, dtype=tf.float32)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, Y), dtype=tf.float32))
 
 
 
with tf.Session() as sess:
    # 초기화
    sess.run(tf.global_variables_initializer())
 
    #학습
    for step in range(10001):
        sess.run(train, feed_dict={X: x_data, Y: y_data})
        if step % 100 == 0:
            print(step, sess.run(cost, feed_dict={
                  X: x_data, Y: y_data}), sess.run([W1, W2]))
 
    # Accuracy report
    hypo, c, acc = sess.run([hypothesis, predicted, accuracy],
                       feed_dict={X: x_data, Y: y_data})
    print("\nHypothesis: ", hypo, "\nCorrect: ", c, "\nAccuracy: ", acc)
cs