현재 텐서 흐름을 학습 중입니다. 나는 softmax 모델을 사용하여 분류 모델을 만들려고 노력했다. 이 프로그램에서는 CSV 파일에서 두 개의 열 왼쪽에 교육 데이터 세트를 설정하고 두 열의 오른쪽에는 두 개의 레이블을 설정합니다. 예를 들면 :분류에 따른 텐서 흐름
데이터 1이, 데이터 2가, LABEL1는 라벨 2
234, 23, 1, 0 # 234 정도로 LABEL1 1로 태그되고, 23보다 큰, 및 라벨 2 0
156, 113, 1로 태그, 0
1, 4, 0, 1
위와 같이 설정된 학습 데이터에서 가장 큰 수 기준으로 테스트 데이터를 분류하고 비용 값을 거의 0으로 수렴합니다.
그러나 데이터 세트를 변경하여 테스트 데이터를 짝수로 분류하는 것을 목표로하는 짝수 번호로 레이블을 변경하면 모델이 실패하고 비용은 변동합니다. 데이터 세트는 다음과 같다 :
데이터 1, 데이터 2가, LABEL1이 라벨 2
24, 35, 1, 0 # 24은 짝수이므로 LABEL1 1로 태그되고, 라벨 2 0
(156), (553)로 태그, 1, 0
1, 4, 0, 1
나는 프로그램에 문제가 있습니까? 짝수인데도 불구하고 데이터 세트 중에서 가장 큰 숫자를 구별하는 이유는 무엇입니까? 모두에게 감사드립니다! 당신은 당신의 네트워크는 선형 분리 문제를 해결할 수 있음을 의미 뉴런의 단일 층을 가지고
import tensorflow as tf
import os
import numpy as np
def next_batch(num, data, labels):
idx = np.arange(0 , len(data))
np.random.shuffle(idx)
idx = idx[:num]
data_shuffle = [data[ i] for i in idx]
labels_shuffle = [labels[ i] for i in idx]
return np.asarray(data_shuffle), np.asarray(labels_shuffle)
dir_path = os.path.dirname(os.path.realpath(__file__))
filename = dir_path + "/classification.csv"
x = tf.placeholder(tf.float32, [None, 2])
y = tf.placeholder(tf.float32, [None, 2])
W = tf.Variable(tf.zeros([2, 2]))
b = tf.Variable(tf.zeros([2]))
pred =tf.add(tf.matmul(x, W),b)
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(cost)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
with open(filename) as inf:
# Skip header
next(inf)
result_array = np.shape(4)
for line in inf:
data1, data2,label1,label2= line.strip().split(",")
data1 = float(data1)
data2 = float(data2)
label1 = int(label1)
label2 = int(label2)
result_array = np.append(result_array, (data1,data2,label1,label2))
result_array=result_array.reshape(1000,4)
k=result_array[:,2:4]
gg=result_array[:,0:2]
for i in range(0,3000):
batch_xs, batch_ys = next_batch(200,gg,k)
h,cos=sess.run([optimizer, cost], feed_dict={x: batch_xs,y:batch_ys})
print(cos)
print(sess.run(pred,feed_dict={x:[[5,2],[4,9],[4,3],[5,2],[3,6],[30,21],[32,20],[3,4]]})) #testing data