Trong video này, Đạt giới thiệu thuật toán QuickSort và cài đặt nó bằng Python và C, kèm hướng dẫn cơ bản về cách thuật toán hoạt động.
Code Python
def partition(arr, lo, hi):
    """Partition arr[lo..hi] (inclusive) in place around the pivot arr[hi].

    Lomuto scheme: every element that satisfies ``<= pivot`` is moved to the
    front of the range, then the pivot is swapped into the slot right after
    them.

    Returns:
        The final index of the pivot. arr[lo..result-1] hold the elements
        that were ``<= pivot``; arr[result+1..hi] hold the rest.
    """
    pivot_value = arr[hi]
    store = lo  # next slot that receives an element <= pivot
    # Scan lo..hi-1 (hi excluded): arr[hi] is the pivot itself.
    for scan in range(lo, hi):
        if arr[scan] <= pivot_value:
            arr[store], arr[scan] = arr[scan], arr[store]
            store += 1
    # Drop the pivot right after the last element <= pivot.
    arr[store], arr[hi] = arr[hi], arr[store]
    return store
def quicksort(arr, lo=0, hi=None):
    """Sort arr[lo..hi] (inclusive bounds) in place with QuickSort.

    Pending sub-ranges are kept on an explicit stack instead of the Python
    call stack. With the last-element pivot used by partition(), already
    sorted input produces maximally unbalanced splits; the recursive form
    would then need one stack frame per element and hit Python's recursion
    limit on inputs of roughly a thousand elements.

    Args:
        arr: mutable sequence to sort in place.
        lo: first index of the range to sort (default 0).
        hi: last index of the range to sort, inclusive
            (default: the last index of arr).
    """
    if hi is None:
        hi = len(arr) - 1
    # Ranges with fewer than two elements are already sorted; never push them.
    pending = [(lo, hi)] if lo < hi else []
    while pending:
        start, end = pending.pop()
        pivot = partition(arr, start, end)
        # Push the right part before the left part so the left part is
        # popped (and fully sorted) first, like the recursive call order.
        if pivot + 1 < end:
            pending.append((pivot + 1, end))
        if start < pivot - 1:
            pending.append((start, pivot - 1))
# Demo: sort a fixed sample array and show it before and after.
# Single-argument print(...) calls run unchanged on both Python 2 and 3
# (the original Python 2 print statements are a SyntaxError on Python 3).
input_arr = [10, 80, 30, 90, 40, 50, 70]
print("input array {}\n".format(input_arr))
quicksort(input_arr, 0, len(input_arr) - 1)
print("sorted array {}".format(input_arr))
Code C
#include "stdio.h"
/* Exchange the two ints pointed to by a and b (safe when a == b). */
void swap(int* a, int* b) {
    const int held = *b;
    *b = *a;
    *a = held;
}
/* Print the first `size` ints of arr on one line, each followed by a
 * space, then a newline. Prints just the newline when size <= 0. */
void print_array(int *arr, int size) {
    int k;
    for (k = 0; k < size; k++)
        printf("%d ", arr[k]);
    putchar('\n');
}
/* Lomuto partition of arr[lo..hi] (inclusive) around the pivot arr[hi].
 * Elements <= pivot are gathered at the front of the range, then the pivot
 * is swapped in right after them. Returns the pivot's final index. */
int partition(int *arr, int lo, int hi) {
    const int pivot_value = arr[hi];
    int store = lo; /* next slot that receives an element <= pivot */
    int scan;
    for (scan = lo; scan < hi; scan++) { /* arr[hi] is the pivot: skip it */
        if (arr[scan] <= pivot_value) {
            swap(&arr[store], &arr[scan]);
            store++;
        }
    }
    swap(&arr[store], &arr[hi]);
    return store;
}
/* Sort arr[lo..hi] (inclusive) in place with QuickSort.
 *
 * Recurses only into the smaller of the two partitions and loops over the
 * larger one. The smaller side holds at most half the range, so the
 * recursion depth is bounded by log2(n) even when the last-element pivot
 * produces maximally unbalanced splits (e.g. already-sorted input). The
 * plain two-call recursion would need depth n there and can overflow the
 * stack on large arrays. */
void quicksort(int *arr, int lo, int hi) {
    while (lo < hi) {
        int pivot = partition(arr, lo, hi);
        if (pivot - lo < hi - pivot) {
            quicksort(arr, lo, pivot - 1); /* left side is smaller */
            lo = pivot + 1;
        } else {
            quicksort(arr, pivot + 1, hi); /* right side is smaller */
            hi = pivot - 1;
        }
    }
}
/* Demo: sort a fixed sample array and print it before and after.
 * Uses a prototype-style `(void)` signature and an explicit `return 0`:
 * falling off the end of main only yields a defined exit status from C99 on. */
int main(void) {
    int arr[] = {10, 80, 30, 90, 40, 50, 70};
    int n = (int)(sizeof(arr) / sizeof(arr[0]));
    printf("before ");
    print_array(arr, n);
    quicksort(arr, 0, n - 1);
    printf("after ");
    print_array(arr, n);
    return 0;
}
Code Python with log
def partition(arr, lo, hi):
    """Lomuto partition of arr[lo..hi] (inclusive) around arr[hi], with logs.

    Prints the pivot and range first. Before each swap it prints a
    visualisation via print_quicksort(); after each swap it prints the
    array. Uses single-argument print(...) calls so the output is the same
    on Python 2 and 3.

    Returns:
        The final index of the pivot.
    """
    pivot = arr[hi]
    print("new pivot [{}], lo: {}, hi: {}".format(pivot, lo, hi))
    small_i = lo - 1  # last index of the "<= pivot" region (none yet)
    for i in range(lo, hi):  # scan lo..hi-1; arr[hi] is the pivot
        if arr[i] <= pivot:
            small_i += 1
            print_quicksort(arr, hi, small_i, i, lo, hi)
            arr[small_i], arr[i] = arr[i], arr[small_i]
            print("after swap: {}\n".format(arr))
    # Place the pivot right after the "<= pivot" region (this swap is not logged).
    small_i += 1
    arr[small_i], arr[hi] = arr[hi], arr[small_i]
    return small_i
def quicksort(arr, lo=0, hi=None):
    """Sort arr[lo..hi] (inclusive bounds) in place with QuickSort.

    Pending sub-ranges live on an explicit stack rather than the Python call
    stack, so unbalanced splits (e.g. already sorted input with the
    last-element pivot) cannot hit the recursion limit. Ranges are processed
    in the same order as the recursive form: partition, then the whole left
    part, then the whole right part. The printed log is therefore unchanged.

    Args:
        arr: mutable sequence to sort in place.
        lo: first index of the range to sort (default 0).
        hi: last index of the range to sort, inclusive
            (default: the last index of arr).
    """
    if hi is None:
        hi = len(arr) - 1
    # Ranges with fewer than two elements are already sorted; never push them.
    pending = [(lo, hi)] if lo < hi else []
    while pending:
        start, end = pending.pop()
        pivot = partition(arr, start, end)
        # Right is pushed first so the left part is popped first (LIFO).
        if pivot + 1 < end:
            pending.append((pivot + 1, end))
        if start < pivot - 1:
            pending.append((start, pivot - 1))
def print_quicksort(arr, pivot_index, small_i, index, lo, hi):
    """Print one step of the partition scan as two annotated lines.

    The first line lists lo, hi, small_i and index. The second line shows
    arr with these markers:
      "("   written just before arr[lo]
      ")"   written just before arr[hi] (partition() passes hi as
            pivot_index, so the scanned range lo..hi-1 sits inside the
            parentheses and the pivot follows them)
      "[v]" the element at pivot_index
      "v->" the element at small_i (end of the "<= pivot" region)
      "<-v" the element at index (the element being scanned)
    When indices coincide, pivot_index wins, then small_i, then index.
    """
    print("lo: {}, hi: {}, small index: {}, index: {}".format(lo, hi, small_i, index))
    parts = []
    for i, value in enumerate(arr):
        if i == lo:
            parts.append("(")
        if i == hi:
            parts.append(")")
        if i == pivot_index:
            parts.append("[{}] ".format(value))
        elif i == small_i:
            parts.append("{}-> ".format(value))
        elif i == index:
            parts.append("<-{} ".format(value))
        else:
            parts.append("{} ".format(value))
    print("".join(parts))
# Demo: sort a fixed sample array with the logging QuickSort, showing the
# array before and after. Single-argument print(...) calls run unchanged on
# Python 2 and 3 (the original print statements fail to parse on Python 3).
input_arr = [10, 80, 30, 90, 40, 50, 70]
print("input array {}\n".format(input_arr))
quicksort(input_arr, 0, len(input_arr) - 1)
print("sorted array {}".format(input_arr))