Longest common substring using Rolling Hash and Binary Search

Xin chào mọi người, đây là đề bài:


Còn đây là hướng dẫn:

Em đã dùng rolling hash để tính binary search tìm max length. Nó chạy đúng kết quả nhưng lại bị time limit exceed. Tuy dùng Binary search và rolling hash đã giảm đáng kể thời gian. nhưng mà em thấy rất khó để được time complexity O(|s|+|t|). Đó là chưa kể em mới tính có 1 cái hash thôi là đã nhức đầu rồi :)). Có phải em đã tiếp cận sai hướng hay không. Mong các anh đọc và gợi ý giúp em, em xin cảm ơn nhiều.

// Assignment3Problem5.cpp : This file contains the 'main' function. Program execution begins and ends there.
#include "pch.h"
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
long long prime1 = 1000000007, prime2 = 1000000009;
using namespace std;

struct Answer {
	size_t i, j, len;
};
vector<long long> powFunc(string s, int x, long long prime)
{
	vector<long long> result(s.length() + 1);
	result[0] = 1;
	for (int i = 1; i <= s.length(); i++)
	{
		result[i] = ((result[i - 1] % prime)*(x%prime) + prime) % prime;
	}
	return result;
}
vector<long long> hashValues(string s, long long prime, int x)
{
	vector<long long> result(s.length() + 1);
	result[0] = 0;
	for (int i = 1; i <= s.length(); i++)
	{
		result[i] = (x * result[i - 1] + s[i - 1]) % prime;
 	}
	return result;
}
bool isSubstring(string s, string t, int length, vector<long long> table1, vector<long long> table2, vector<long long> powTable1, vector<long long> powTable2)
{
	for (int i = length; i <= s.length(); i++)
	{
		long long hash1 = ((table1[i] - (powTable1[length] * table1[i - length])) % prime1 + prime1) % prime1;
		for (int j = length; j <= t.length(); j++)
		{
			long long hash2 = ((table2[j] - (powTable2[length] * table2[j - length])) % prime1 + prime1)%prime1;
			if (hash1 == hash2) return true;
		}
	}
	return false;
}
int binarySearch(string s, string t, int left, int right, vector<long long> table1, vector<long long> table2, vector<long long> powerFun1, vector<long long> powerFun2)
{
	if (left > right) return left - 1;
	int mid = (left + right) / 2;
	if (isSubstring(s, t, mid, table1, table2, powerFun1, powerFun2) == true)
		return binarySearch(s, t, mid + 1, right, table1, table2, powerFun1, powerFun2);
	return binarySearch(s, t, left, mid - 1, table1, table2, powerFun1, powerFun2);
}
Answer solve(const string &s, const string &t) {
	Answer ans = { 0, 0, 0 };
	/*
	for (size_t i = 0; i < s.size(); i++)
		for (size_t j = 0; j < t.size(); j++)
			for (size_t len = 0; i + len <= s.size() && j + len <= t.size(); len++)
				if (len > ans.len && s.substr(i, len) == t.substr(j, len))
					ans = { i, j, len };
	*/
	int x = 3;
	vector<long long> hashTable1 = hashValues(s, prime1, x);
	vector<long long> hashTable2 = hashValues(t, prime1, x);
	vector<long long> powTable1 = powFunc(s, x, prime1);
	vector<long long> powTable2 = powFunc(t, x, prime1);
	int maxLength = binarySearch(s, t, 0, min(s.length(), t.length()), hashTable1, hashTable2, powTable1, powTable2);
	
	for (int i = maxLength; i <= s.length(); i++)
	{
		long long hash1 = ((hashTable1[i] - powTable1[maxLength] * hashTable1[i - maxLength]) % prime1 + prime1) % prime1;
		for (int j = maxLength; j <= t.length(); j++)
		{
			long long hash2 = ((hashTable2[j] - powTable2[maxLength] * hashTable2[j - maxLength]) % prime1 + prime1) % prime1;
			if (hash1 == hash2)
			{
				ans.i = i - maxLength;
				ans.j = j - maxLength;
				ans.len = maxLength;
				return ans;
			}
		}
	}
	ans.i = 0;
	ans.j = 0;
	ans.len = 0;
	return ans;
}
int main() {
	ios_base::sync_with_stdio(false), cin.tie(0);
	string s, t;
	while (cin >> s >> t) {
		auto ans = solve(s, t);
		cout << ans.i << " " << ans.j << " " << ans.len << "\n";
	}
}

Bài này QHĐ là O(|s||t|) rồi, tính hash chi nữa.

6 Likes


Ko được đâu ạ, nó cũng có nói nè :v

Có vẻ bạn hiểu sai cách tính rồi:

  • isSubString của bạn độ phức tạp là O(n^2)
  • ý tác giả là tính precompute hash cho chuỗi 1 rồi dùng hash table để search chuỗi 2 (tương đương hàm find_common ở dưới): độ phức tạp O(n)

mod = 10**9+7
base = 31

pw = [pow(base, i, mod) for i in range(10**5+5)]

def precompute_hash(s):
    h = [0]
    for i in range(len(s)):
        v = ((ord(s[i]) - ord('a') + 1) + h[-1] * base) % mod
        h.append(v)
    return h

def hash(h, i, n):
    return ((h[i+n] - h[i]*pw[n]) % mod + mod) % mod

def is_substring(hs, i, ht, j, n):
    return hash(hs,i,n) == hash(ht,j,n)

def find_common(hs, ht, k):
    ns = len(hs)
    nt = len(ht)

    d = {}
    for i in range(ns-k):
        h = hash(hs, i, k)

        d[h] = i

    for j in range(nt - k):
        h = hash(ht, j, k)
        if h in d:
            return (d[h],j)
    return (-1,-1)


def longest_common(ha, hb, minLen, maxLen):
    print(minLen, maxLen)
    if minLen == maxLen:
        return (find_common(ha, hb, minLen), minLen)

    segLen = (maxLen - minLen + 1)
    half = segLen // 2
    
    res = find_common(ha, hb, minLen + half)
    if res[0] == -1:
        return longest_common(ha, hb, minLen, minLen + half - 1)
    else:
        return longest_common(ha, hb, minLen + half, maxLen)

a = "banana"
b = "nan"

ha = precompute_hash(a)
hb = precompute_hash(b)

(i, j), k = longest_common(ha, hb, 0, min(len(a), len(b)))
print(i, j, k, a[i:i+k], b[j:j+k])

7 Likes

em cảm ơn về sự góp ý của anh. Nhưng không biết em có hiểu sai không. Trong hàm find_common, anh đã tạo một array có kích thước tương đương với giá trị của max(hashvalues_s). Điều này là rất tốn kém, vì nếu hashvalue này có max lên ~ 1 tỷ thì nó sẽ tạo array có kích thước 1 tỷ chỉ để lưu 1 vài hash value. Hơn nữa, quả thật hàm find_common có time complexity là O(n), nhưng sau khi chạy qua hàm binary search thì nó là O(n*log(n)) :v

1 Like
83% thành viên diễn đàn không hỏi bài tập, còn bạn thì sao?