Processing math: 100%

piątek, 23 sierpnia 2013

14783. Dobra i zła energia [AL_06_05]

Zadanie:
https://pl.spoj.com/problems/AL_06_05

Skrócony opis problemu:
Dla macierzy nxn należy znaleźć największą sumę dowolnej podmacierzy.

Uwaga: Warto zapoznać się najpierw z wersją jednowymiarową opisaną tutaj. Będę na niej bazował. Problem dla wersji jednowymiarowej to znalezienie podłańcucha o największej sumie.



Rozwiązanie naiwne:
Możemy, tak samo jak w wersji jednowymiarowej, dla każdej czwórki reprezentującej podmacierz obliczyć jej sumę i wypisać największą.

Złożoność: O(n^6), bo takich czwórek jest dokładnie \left(\frac{n(n+1)}{2}\right)^2 = \frac{n^4+2n^3+n^2}{4}, a dla każdej z nich zliczamy sumę macierzy, która pesymistycznie ma rozmiar n^2.

Algorytm:
int main()
{
    int n, a, b, c, d, e, f, s, max = -1000000, x, in[n+1][n+1];

    get(n);
    for(a = 0; a < n; ++a)
        for(b = 0; b < n; ++b)
            get(in[a][b]);
    for(a = 0; a < n; ++a)
        for(b = 0; b < n; ++b)
            for(c = 0; c < n; ++c)
                for(d = 0; d < n; ++d)
                {
                    s = x = 0;
                    /* x jest wartownikiem; jeśli po zakończeniu pętli x=0 to znaczy, że mieliśmy macierz z kolumnami np. od 5 do 3 zamiast od 3 do 5 i nie policzyliśmy jej wcale;
                    bez niego dla takiej macierzy mielibyśmy sumę 0, a dla macierzy ze wszystkimi wartościami ujemnymi nie możemy wpisać zera;
                    poza tym dzięki temu, że pomijamy te macierze o odwrotnych współrzędnych, to przechodzimy "tylko" (n^4+2n^3+n^2)/4 macierzy, a nie n^4 */
                    for(e = a; e <= b; ++e)
                        for(f = c; f <= d; x = ++f)
                            s += in[e][f];
                    if(x && s > max) // jeśli przeglądnęliśmy macierz
                        max = s;
                }
    print(max);
}

Rozwiązanie sprytniejsze:
Możemy sprawdzić wszystkie podmacierze opierając się na wcześniejszych wynikach. Czyli dla każdej komórki tab[i][j] idziemy na maksa w prawo i w dół, i liczymy kolejne sumy, biorąc sumę wszystkich elementów s zaczynając od w danym rzędzie k od i idąc w prawo, i każdej komórce tab[i][l] przypisujemy wartość s+tab[i-1][l]. Dzięki temu będziemy sumę całej podmacierzy od tab[i][j] do tab[k][l] dla każdych wartości i, j, k, l.
Macierzy tab nie musimy nawet zerować, gdyż zawsze przypisujemy jej wartość bazując na s i poprzednim wierszu. Tylko pierwszy (o indeksie 0) wiersz i kolumna muszą być wyzerowane. Musimy zerować też s.
Jeśli ktoś będzie miał trudności ze zrozumieniem poniższego kodu, można zastąpić tab[h][w] - tab[i][j][h][w]. Możemy zrezygnować z tych 2 wymiarów, gdyż jeśli dla jakiejś pary (i;j) obliczymy już wszystkie podmacierze, to nie będziemy potrzebować tych wartości, a zerować macierzy nie trzeba (jak wspomniałem wyżej), więc możemy zastępować je nowymi parami. Można się posunąć jeszcze dalej w optymalizacji złożoności pamięciowej i przechowywać tylko 2 ostatnie wiersze macierzy tab i korzystać z nich na przemian.

Złożoność: O(n^4), gdyż dla każdej z n^2 komórek przechodzimy przez prostokąty i wszystkich możliwych szerokościach i długościach, których jest pesymistycznie n^2. Warto też zauważyć, że złożoność pamięciowa wynosi tutaj O(n^2), w przeciwieństwie, to O(1) w powyższym algorytmie.

Algorytm:
int main()
{
    int t, n, i, j, w, h, s, max = -1000000, in[n][n], tab[n+1][n+1];

    get(n);
    for(i = 0; i < n; ++i)
        for(j = 0; j < n; ++j)
            get(in[i][j]);
    for(i = 0; i < n; ++i)
        for(j = 0; j < n; ++j)
            for(h = 1; h <= n - i; ++h) // h to długość podmacierzy
                for(s = 0, w = 1; w <= n - j; ++w) // w to szerokość podmacierzy
                {
                    s += in[i+h-1][j+w-1];
                    tab[h][w] = s+tab[h-1][w];
                    if(tab[h][w] > max)
                        max = tab[h][w];
                }
    print(max);
}

Rozwiązanie najsprytniejsze:
Możemy skorzystać z sum prefiksowych. Obliczamy dla każdej pary (i;j) \wedge i \le j obliczamy dla każdej kolumny sumę elementów od wiersza i do wiersza j. Takich par będzie nawiasem mówiąc maksymalnie \frac{n(n+1)}{2}.
Np. dla n=3 i macierzy:
11 -2  3
-7  0 -6
 2  8  3
Będziemy mieli takie sumy:
S(0,0) = [11, -2, 3] \\ S(0,1) = [4, -2, -3] \\ S(0,2) = [6, 6, 0] \\ S(1,1) = [-7, 0, -6] \\ S(1,2) = [-5, 8, -3] \\ S(2,2) = [2, 8, 3]
Co z nimi teraz? Zauważmy, że jeśli weźmiemy sobie kolumny k i l i zliczymy wszystkie liczby od k do l w S(i,j) to będziemy mieli sumę w prostokącie ograniczonym z góry przez i, z dołu przez j, z lewej przez k i z prawej przez l. No tak, ale to byśmy musieli sprawdzać wszystkie pary (k;l) \wedge k \le l dla każdego z \frac{n(n+1)}{2} par (i;j), ergo znów będziemy mieli złożoność O(n^4). Tak, ale nie musimy sprawdzać każdej pary (k;l). Możemy na każdym wierszu pary (i;j) zapuścić omawiany tutaj algorytm liniowy dla wariantu jednowymiarowego. Czemu? Ano temu, że te wiersze już są sumami po wierszach macierzy wejściowej. Pozostanie nam tylko znaleźć dla każdego optymalne k i l, a to możemy zrobić w czasie liniowym. Tak naprawdę, to nawet nas nie obchodzą k i l, bo od razu dostaniemy optymalną sumę między nimi.

Złożoność: O(n^3), gdyż tworzymy \frac{n(n+1)}{2} wierszy liniowo (każdy element wiersza w czasie stałym za pomocą sum prefiksowych, czyli bazując na wartościach poprzednich elementów), a następnie dla każdego wiersza również liniowo znajdujemy największą sumę w nim. Jako że każdy wiersz jest reprezentantem ciągu wierszy macierzy z wejścia, to przeszukamy wszystkie możliwe podmacierze.
Złożoność pamięciowa wzrosła tu nam do O(n^3).

Ostateczny algorytm:
int main()
{
    int n, i, j, k, s, s1, s2 = -1000000, tab[n+1][n+1][n+1], in[n+1][n+1];

    get(n);
    for(i = 1; i <= n; ++i)
        for(j = 1; j <= n; ++j)
            get(in[i][j]);
    for(i = 1; i <= n; ++i) // dla każdego i
    {
        for(j = i; j <= n; ++j) // i każdego j
            for(k = 1; k <= n; ++k)
                tab[i][j][k] = tab[i][j-1][k]+in[i][k]; // tworzymy sumę bazując na poprzednim wierszu
        for(j = i-1; j > 0; --j) // musimy uwzględnić j również w poprzednich sumach
            for(k = 1; k <= n; ++k)
                tab[j][i][k] = tab[j][i-1][k]+in[i][k];
    }
    for(i = 1; i <= n; ++i) // dla każdej pary (i;j)
        for(j = i; j <= n; ++j)
        {
            s = 0; s1 = -1000000; // szukamy w wierszu S(i,j) największej sumy
            for(k = 1; k <= n; ++k)
            {
                if(s > 0)
                    s += tab[i][j][k];
                else
                    s = tab[i][j][k];
                if(s > s1)
                    s1 = s;
            }
            if(s1 > s2)
                s2 = s1;
        }
    print(s2);
}

Optymalizacja:
Możemy zmniejszyć złożoność pamięciową do kwadratowej oraz przyspieszyć program czy zachowaniu takiej samej złożoności czasowej, lepiej wykorzystując sumy prefiksowe.
Mianowicie, jeśli mamy sobie wiersz: S(4,5) to wystarczy, że obliczymy tylko (dla n=6) wiersze: S(0,0), S(0,1), S(0,2), S(0,3), S(0,4), S(0,5). Wiersz S(4,5) otrzymamy odejmując S(0,3) od S(0,5). Musimy więc obliczyć jedynie n wierszy S(0;i) w czasie liniowym. Pierwszy etap zajmie nam zatem jedynie O(n^2) czasu zamiast O(n^3). Zaoszczędzimy również na pamięci, gdyż ilość wierszy zmniejszyła nam się z kwadratowej na liniową, a musieliśmy wszystkie pamiętać.
W głównej części programu korzystamy po prostu zamiast tab[i][j][k] z tab[j][k]-tab[i-1][k] i tyle.

Możemy też optymalizować przypadki szczególne, jak macierz złożona z samych liczb ujemnych, z samych liczb nieujemnych lub z jedną liczbą nieujemną. Złożoność dla każdego z tych przypadków to O(n^2), bo musimy tylko wczytać dane. Nie będę jednak ich uwzględniał w poniższym kodzie.

Ostateczny, ulepszony algorytm:
int main()
{
    int n, i, j, k, s, s1, s2 = -1000000, tab[n+1][n+1];

    get(n);
    for(i = 1; i <= n; ++i)
        for(j = 1; j <= n; ++j)
        {
            get(k); // możemy na bieżąco obliczać sumy prefiksowe; nie musimy zapamiętywać macierzy z wejścia, gdyż nie będzie nam potem już do niczego potrzebna
            tab[i][j] = k+tab[i-1][j];
        }
    for(i = 1; i <= n; ++i)
        for(j = i; j <= n; ++j)
        {
            s = 0; s1 = -1000000;
            for(k = 1; k <= n; ++k)
            {
                if(s > 0)
                    s += tab[j][k]-tab[i-1][k];
                else
                    s = tab[j][k]-tab[i-1][k];
                if(s > s1)
                    s1 = s;
            }
            if(s1 > s2)
                s2 = s1;
        }
    print(s2);
}

Brak komentarzy:

Prześlij komentarz