SVM for iris in R

library(e1071)
data(iris)

set.seed(12345) 
sample_idx <- sample(nrow(iris), nrow(iris) * 0.7)
train <- iris[sample_idx, ]
test <- iris[-sample_idx, -5]
lable <- iris[-sample_idx, 5]

# Using default parameters.
svm.model <- svm(Species ~ ., data=train, probability=TRUE)

# Getting prob and label.
predicted.prob <- attr(predict(svm.model, test, probability=TRUE), "probabilities")
predicted.label <- colnames(predicted.prob)[max.col(predicted.prob)]

# To check that the code for finding the max is right, uncomment the below.
# table(predicted.label, predict(svm.model, test))

# Estimate accuracy.
confusion_matrix <- table(predicted.label, lable)
confusion_matrix
classAgreement(confusion_matrix)$diag

# Tune parameters.
svm.tuned.model <- tune.svm(Species ~ .,
                            data=train,
                            type='C-classification',
                            kernel='radial',
                            cost=2^(-5:5),
                            gamma=2^(-5:5),
                            tunecontrol=tune.control(cross=10))$best.model

predicted.tuned.lable <- predict(svm.tuned.model, test)

# Estimate accuracy.
confusion_matrix.tuned <- table(predicted.tuned.lable, lable)
confusion_matrix.tuned
classAgreement(confusion_matrix.tuned)$diag

Here’s output.

Loading required package: class
               lable
predicted.label setosa versicolor virginica
     setosa         13          0         0
     versicolor      0         15         3
     virginica       0          0        14
[1] 0.9333333
                     lable
predicted.tuned.lable setosa versicolor virginica
           setosa         13          0         0
           versicolor      0         15         2
           virginica       0          0        15
[1] 0.9555556

Similar Posts:

Comments 5

  1. lee tae woo wrote:

    안녕하십니까
    항상 유용한 정보를 제공해 주셔서 정말 감사드립니다. 본 피드 내용들을 자주 참고하고 있는데요, 질문을 하나 드려도 될런지요.
    타겟이 binary인 자료에 예측모델로 svm을 이용하고 싶은데요 님께서 올리신 SVM for iris in R이라고 되있는 포스트에서 iris 데이터를 가지고 fitting을 시키셨는데요 본 프로그램에서 커널모수인 시그마 값이나 C값의 추정은 어떻게 이루어지는지 궁금합니다. 아니면 따로 추정을 해야하는지 궁금하네요..

    시간이 괜찮으시다면 제 메일로 답변 부탁드리구요, 본 포스트에 댓글로 남겨주시면 감사하겠습니다.

    감사합니다.

    Posted 19 Jul 2012 at 6:27 am
  2. Minkoo Seo wrote:

    덕분에 저도 정리를 제대로 해두는 기회가 되었네요. 본문을 수정해서 모델 튜닝을 넣었습니다.

    e1071 패키지의 메뉴얼인 http://cran.r-project.org/web/packages/e1071/e1071.pdf 에 설명이 있습니다. R 메뉴얼들은 사실 대체로 부실합니다… 그러니 tune.svm 을 구글에서 검색해보시고 다른 예제를 찾아보시는 것도 도움이 되실것입니다. 저도 그런 방법으로 배우고 있습니다.

    참고로 R에서는 모델 튜닝을 도와주는 caret(http://cran.r-project.org/web/packages/caret/index.html)이라는 좋은 패키지가 있습니다. 저도 몇번 써보지 못했는데, 알아두면 편할 듯 합니다.

    Posted 19 Jul 2012 at 2:16 pm
  3. lee tae woo wrote:

    정말 감사드립니다
    덕분에 많이 배우고 가네요…^^

    아 죄송한데 하나만 더 여쭙고 싶습니다..^^;;
    R에서 대용량 데이터일 경우 svm이 훈련시간이 상당히 오래 걸리는것 같더군요…
    obs가 20만건 정도 되는 자료인데요, 변수는 binary 타겟을 포함한 9개 정도입니다만…이게 R에서 돌아 갈까요? 시간이 얼마큼 걸리던 수행이 가능하다면 되도록 R에서 하고 싶은데요, 아니면 matlab이나 다른 패키지를 알아봐야 할까요…?

    Posted 20 Jul 2012 at 4:51 am
  4. Minkoo Seo wrote:

    요즘 저는 16만개 샘플, 60만개 Feature, 97개 class을 분석해보고 있는데, 워낙 feature가 많아서 아직 제대로 16만개 샘플을 처리하는 단계까지는 해보지 못했습니다.

    찾아보니 방법적인 면에서는 http://www.csie.ntu.edu.tw/~cjlin/talks/msra2.pdf 에 잘 정리되어 있는데요. 첫번째는 데이터를 다 분석하는게 아니라 서브 샘플만 보는 방법. 두번째는 kernel을 쓰지 않고 linear svm 을 사용하는 방법(http://cran.r-project.org/web/packages/LiblineaR/)등이 있는 듯 합니다.

    가장 유명한 라이브러리들이 SVM의 경우엔
    http://www.csie.ntu.edu.tw/~cjlin/libsvm/
    http://www.csie.ntu.edu.tw/~cjlin/liblinear/
    두가지가 있습니다.

    한쪽은 일반적인 svm을 커널까지 지원하는 구현이고 liblinear는 linear classification(linear svm포함)한 구현입니다.

    R의 e1071의 경우엔 libsvm기반이고 가장 많이 쓰이는 패키지입니다. (그외 패키지는 http://www.jstatsoft.org/v15/i09/paper 에 정리되어있습니다.) 제 생각엔 R을 쓴다면 http://cran.r-project.org/web/packages/LiblineaR/ 또는 e1071을 그대로 쓰되 probability는 계산하지 않고, kernel은 그냥 linear로 하는게 속도가 빠를 것 같습니다.

    파라미터 튜닝하는 부분은 caret패키지 등을 사용해서 병렬화 시킬 수 있습니다. CPU Core갯수가 2개 이상이면 속도향상이 있을 것입니다.

    마지막으로 다른 언어를 써보시는 것도 좋겠지만, learning curve가 있다고 한다면 저라면 그냥 R을 쓰겠습니다. e1071의 SVM은 타겟 class가 여러개이면 모든 class pair간 모델을 만듧니다. 그래서 9개 class라고 하시니 81개의 모델이 만들어집니다. 여기서 training데이터 크기가 n일때 O(n^3)의 시간이 들어갑니다. 각 클래스가 2만개(=20만/9)씩 균일하게 데이터가 있다고 하면 81개 모델 각각에 대해 2만^3의 시간이 들어갑니다. 그러니까 전체시간의 order는 81 * 2만^3 의 시간이란 것이죠. 여기서 MATLAB이 R대비 100배 빠르다고 가정하면 0.81 * 20만^3가 될것입니다. 따라서 속도차가 크지는 않을 것 같습니다.

    Posted 20 Jul 2012 at 3:22 pm
  5. lee tae woo wrote:

    와..의외로 몰랐던 부분을 많이 배우고 가네요 ^^

    정말 감사드립니다.

    Posted 21 Jul 2012 at 4:07 am

Post a Comment

Your email is never published nor shared. Required fields are marked *