투표(Votes) 기능 구현하기 1 - Models
polls/models.py
from django.contrib.auth.models import User
class Vote(models.Model):
    """A vote tying together a question, one of its choices, and a voter.

    All three relations cascade on delete. A unique constraint over
    (question, voter) limits each user to one vote per question.
    """

    question = models.ForeignKey(Question, on_delete=models.CASCADE)
    choice = models.ForeignKey(Choice, on_delete=models.CASCADE)
    voter = models.ForeignKey(User, on_delete=models.CASCADE)

    class Meta:
        constraints = [
            # One vote per (question, voter) pair.
            models.UniqueConstraint(
                fields=['question', 'voter'],
                name='unique_voter_for_questions',
            ),
        ]
polls_api/serializers.py
class ChoiceSerializer(serializers.ModelSerializer):
    """Serialize a Choice as its text plus the number of votes it received."""

    # Computed field; its value comes from get_votes_count() below.
    votes_count = serializers.SerializerMethodField()

    class Meta:
        model = Choice
        fields = ['choice_text', 'votes_count']

    def get_votes_count(self, obj):
        """Return the number of Vote rows that reference ``obj``."""
        related_votes = obj.vote_set
        return related_votes.count()
>>> from polls.models import *
>>> question = Question.objects.first()
>>> choice = question.choices.first()
>>> from django.contrib.auth.models import User
>>> user= User.objects.get(username='luke')
>>> Vote.objects.create(voter=user,question=question,choice=choice)
<Vote: Vote object (1)>
>>> question.id
1
투표(Votes) 기능 구현하기 2 - Serializers & Views
polls_api/serializers.py
from polls.models import Question,Choice, Vote
class VoteSerializer(serializers.ModelSerializer):
    """Serialize a Vote; the voter is exposed read-only as a username."""

    # Read-only: clients cannot set the voter through this field.
    voter = serializers.ReadOnlyField(source='voter.username')

    class Meta:
        model = Vote
        fields = [
            'id',
            'question',
            'choice',
            'voter',
        ]
polls_api/views.py
from polls.models import Question,Choice, Vote
from polls_api.serializers import VoteSerializer
from .permissions import IsOwnerOrReadOnly , IsVoter
class VoteList(generics.ListCreateAPIView):
    """List the requesting user's votes and create new ones on their behalf."""

    permission_classes = [permissions.IsAuthenticated]
    serializer_class = VoteSerializer

    def get_queryset(self, *args, **kwargs):
        """Return only the votes cast by the requesting user."""
        current_user = self.request.user
        return Vote.objects.filter(voter=current_user)

    def perform_create(self, serializer):
        """Save the new vote with the requesting user as its voter."""
        current_user = self.request.user
        serializer.save(voter=current_user)
class VoteDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update, or delete a single vote."""

    serializer_class = VoteSerializer
    queryset = Vote.objects.all()
    # Both permission classes must pass for a request to be allowed.
    permission_classes = [
        permissions.IsAuthenticatedOrReadOnly,
        IsVoter,
    ]
polls_api/permissions.py
class IsVoter(permissions.BasePermission):
    """Object-level permission that admits only the vote's own voter.

    The views import and use this class as ``IsVoter``. The check applies
    to every HTTP method: there is no exemption for read-only requests,
    so the former "...OrReadOnly" name did not describe the behavior.
    """

    def has_object_permission(self, request, view, obj):
        """Return True only when ``obj.voter`` is the requesting user."""
        return obj.voter == request.user


# Backward-compatible alias for code that still imports the old name.
IsVoterOrReadOnly = IsVoter
polls_api/urls.py
from django.urls import path, include
from .views import VoteList, VoteDetail
# URL routes for the vote endpoints. The `...` entry stands for routes
# that this excerpt leaves out.
urlpatterns = [
...
# Collection endpoint served by VoteList.
path('vote/', VoteList.as_view()),
# Single-vote endpoint served by VoteDetail; `pk` is captured as an int.
path('vote/<int:pk>/', VoteDetail.as_view()),
]
Validation
"Validation"은 에러를 방지합니다.
polls_api/serializers.py
from rest_framework.validators import UniqueTogetherValidator
class VoteSerializer(serializers.ModelSerializer):
    """Serialize a Vote and validate question/choice/voter consistency.

    Two rules are enforced:
    * the chosen choice must belong to the submitted question;
    * a voter may vote only once per question (UniqueTogetherValidator).
    """

    def validate(self, attrs):
        """Reject a choice that does not belong to the given question.

        On partial updates ``attrs`` holds only the submitted fields, so a
        missing ``question`` or ``choice`` is taken from the instance being
        updated instead of raising ``KeyError``.

        Raises:
            serializers.ValidationError: if the choice's question differs
                from the question being voted on.
        """
        instance = getattr(self, 'instance', None)
        # `or` keeps the instance lookup lazy: it only runs when the field
        # was not submitted.
        choice = attrs.get('choice') or getattr(instance, 'choice', None)
        question = attrs.get('question') or getattr(instance, 'question', None)

        if choice is not None and question is not None:
            # question_id compares the stored FK value without an extra query.
            if choice.question_id != question.id:
                raise serializers.ValidationError("Question과 Choice가 조합이 맞지 않습니다.")
        return attrs

    class Meta:
        model = Vote
        fields = ['id', 'question', 'choice', 'voter']
        validators = [
            # Report a duplicate (question, voter) pair as a validation
            # error before the save is attempted.
            UniqueTogetherValidator(
                queryset=Vote.objects.all(),
                fields=['question', 'voter'],
            ),
        ]
polls_api/views.py
from rest_framework import status
from rest_framework.response import Response
class VoteList(generics.ListCreateAPIView):
    """List the requesting user's votes and create new ones on their behalf."""

    permission_classes = [permissions.IsAuthenticated]
    serializer_class = VoteSerializer

    def get_queryset(self, *args, **kwargs):
        """Return only the votes cast by the requesting user."""
        current_user = self.request.user
        return Vote.objects.filter(voter=current_user)

    def create(self, request, *args, **kwargs):
        """Create a vote, forcing ``voter`` to the requesting user's id.

        The voter is written into a copy of the request data before
        validation, so the serializer validates it with the other fields.
        """
        payload = request.data.copy()
        payload['voter'] = request.user.id

        serializer = self.get_serializer(data=payload)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)

        success_headers = self.get_success_headers(serializer.data)
        return Response(
            serializer.data,
            status=status.HTTP_201_CREATED,
            headers=success_headers,
        )
class VoteDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update, or delete a single vote."""

    serializer_class = VoteSerializer
    queryset = Vote.objects.all()
    # Both permission classes must pass for a request to be allowed.
    permission_classes = [
        permissions.IsAuthenticated,
        IsVoter,
    ]

    def perform_update(self, serializer):
        """Save the update with the requesting user as the voter."""
        current_user = self.request.user
        serializer.save(voter=current_user)
Testing
테스트를 작성하면 검증 과정을 자동화하여 개발 시간을 줄일 수 있습니다.
polls_api/tests.py
from django.test import TestCase
from polls_api.serializers import QuestionSerializer
class QuestionSerializerTestCase(TestCase):
    """Tests for QuestionSerializer with valid and invalid input."""

    def test_with_valid_data(self):
        # Non-empty text should validate and save as a row with an id.
        payload = {'question_text': 'abc'}
        serializer = QuestionSerializer(data=payload)
        self.assertTrue(serializer.is_valid())

        saved = serializer.save()
        self.assertIsNotNone(saved.id)

    def test_with_invalid_data(self):
        # An empty question_text should fail validation.
        payload = {'question_text': ''}
        serializer = QuestionSerializer(data=payload)
        self.assertFalse(serializer.is_valid())
python manage.py test
Testing Serializers
polls_api/tests.py
from django.test import TestCase
from polls_api.serializers import QuestionSerializer, VoteSerializer
from django.contrib.auth.models import User
from polls.models import Question, Choice, Vote
class VoteSerializerTest(TestCase):
    """Tests for VoteSerializer: valid votes, duplicates, and mismatches."""

    def setUp(self):
        """Create one user, one question owned by them, and one choice."""
        self.user = User.objects.create(username='testuser')
        self.question = Question.objects.create(
            question_text='abc',
            owner=self.user,
        )
        self.choice = Choice.objects.create(
            question=self.question,
            choice_text='1',
        )

    def test_vote_serializer(self):
        # A matching question/choice pair from a first-time voter is saved.
        self.assertEqual(User.objects.all().count(), 1)
        data = {
            'question': self.question.id,
            'choice': self.choice.id,
            'voter': self.user.id,
        }
        serializer = VoteSerializer(data=data)
        self.assertTrue(serializer.is_valid())

        vote = serializer.save()
        self.assertEqual(vote.question, self.question)
        self.assertEqual(vote.choice, self.choice)
        self.assertEqual(vote.voter, self.user)

    def test_vote_serializer_with_duplicate_vote(self):
        # A second vote by the same user on the same question must be
        # rejected, even when it targets a different choice.
        self.assertEqual(User.objects.all().count(), 1)
        other_choice = Choice.objects.create(
            question=self.question,
            choice_text='2',
        )
        Vote.objects.create(question=self.question, choice=self.choice, voter=self.user)

        data = {
            'question': self.question.id,
            'choice': other_choice.id,
            'voter': self.user.id,
        }
        serializer = VoteSerializer(data=data)
        self.assertFalse(serializer.is_valid())

    def test_vote_serializer_with_unmatched_question_and_choice(self):
        # A choice that belongs to a different question must be rejected.
        question2 = Question.objects.create(
            question_text='abc',
            owner=self.user,
        )
        choice2 = Choice.objects.create(
            question=question2,
            choice_text='1',
        )

        data = {
            'question': self.question.id,
            'choice': choice2.id,
            'voter': self.user.id,
        }
        serializer = VoteSerializer(data=data)
        self.assertFalse(serializer.is_valid())
Testing Views
polls_api/tests.py
from rest_framework.test import APITestCase
from django.urls import reverse
from rest_framework import status
from django.utils import timezone
class QuestionListTest(APITestCase):
    """API tests for the question list endpoint (create and list)."""

    def setUp(self):
        """Prepare a sample payload and resolve the endpoint URL."""
        self.question_data = {'question_text': 'some question'}
        # NOTE(review): assumes the route is registered with
        # name='question-list'; confirm against polls_api/urls.py.
        self.url = reverse('question-list')

    def test_create_question(self):
        # create_user() hashes the password; create() would store it raw.
        user = User.objects.create_user(username='testuser', password='testpass')
        self.client.force_authenticate(user=user)

        response = self.client.post(self.url, self.question_data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Question.objects.count(), 1)
        question = Question.objects.first()
        self.assertEqual(question.question_text, self.question_data['question_text'])
        # pub_date should be within one second of now; an exact equality
        # check on elapsed time can never pass reliably.
        elapsed = (timezone.now() - question.pub_date).total_seconds()
        self.assertLess(elapsed, 1)

    def test_create_question_without_authentication(self):
        # Without credentials the POST is expected to be refused.
        response = self.client.post(self.url, self.question_data)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_list_question(self):
        question = Question.objects.create(question_text='Question1')
        choice = Choice.objects.create(question=question, choice_text='Question1')
        Question.objects.create(question_text='Question2')

        # Listing is a GET; a POST here would attempt to create a question.
        response = self.client.get(self.url)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 2)
        self.assertEqual(response.data[0]['choices'][0]['choice_text'], choice.choice_text)
pip install coverage
coverage run manage.py test
coverage report
'데브코스 TIL > Django, API' 카테고리의 다른 글
Django REST Framework Part 2 (0) | 2023.11.06 |
---|---|
Django REST Framework Part 1 (0) | 2023.11.06 |
Django Tutorial Part 4 (0) | 2023.11.06 |
Django Tutorial Part 3 (0) | 2023.11.06 |
Django Tutorial Part 2 (0) | 2023.11.06 |