piątek, 23 sierpnia 2013

14783. Dobra i zła energia [AL_06_05]

Zadanie:
https://pl.spoj.com/problems/AL_06_05

Skrócony opis problemu:
Dla macierzy $n$x$n$ należy znaleźć największą sumę dowolnej podmacierzy.

Uwaga: Warto zapoznać się najpierw z wersją jednowymiarową opisaną tutaj. Będę na niej bazował. Problem dla wersji jednowymiarowej to znalezienie podłańcucha o największej sumie.



Rozwiązanie naiwne:
Możemy, tak samo jak w wersji jednowymiarowej, dla każdej czwórki reprezentującej podmacierz obliczyć jej sumę i wypisać największą.

Złożoność: $O(n^6)$, bo takich czwórek jest dokładnie $\left(\frac{n(n+1)}{2}\right)^2 = \frac{n^4+2n^3+n^2}{4}$, a dla każdej z nich zliczamy sumę macierzy, która pesymistycznie ma rozmiar $n^2$.

Algorytm:
int main()
{
    int n, a, b, c, d, e, f, s, max = -1000000, x, in[n+1][n+1];

    get(n);
    for(a = 0; a < n; ++a)
        for(b = 0; b < n; ++b)
            get(in[a][b]);
    for(a = 0; a < n; ++a)
        for(b = 0; b < n; ++b)
            for(c = 0; c < n; ++c)
                for(d = 0; d < n; ++d)
                {
                    s = x = 0;
                    /* x jest wartownikiem; jeśli po zakończeniu pętli x=0 to znaczy, że mieliśmy macierz z kolumnami np. od 5 do 3 zamiast od 3 do 5 i nie policzyliśmy jej wcale;
                    bez niego dla takiej macierzy mielibyśmy sumę 0, a dla macierzy ze wszystkimi wartościami ujemnymi nie możemy wpisać zera;
                    poza tym dzięki temu, że pomijamy te macierze o odwrotnych współrzędnych, to przechodzimy "tylko" (n^4+2n^3+n^2)/4 macierzy, a nie n^4 */
                    for(e = a; e <= b; ++e)
                        for(f = c; f <= d; x = ++f)
                            s += in[e][f];
                    if(x && s > max) // jeśli przeglądnęliśmy macierz
                        max = s;
                }
    print(max);
}

Rozwiązanie sprytniejsze:
Możemy sprawdzić wszystkie podmacierze opierając się na wcześniejszych wynikach. Czyli dla każdej komórki $tab[i][j]$ idziemy na maksa w prawo i w dół, i liczymy kolejne sumy, biorąc sumę wszystkich elementów $s$ zaczynając od w danym rzędzie $k$ od $i$ idąc w prawo, i każdej komórce $tab[i][l]$ przypisujemy wartość $s+tab[i-1][l]$. Dzięki temu będziemy sumę całej podmacierzy od $tab[i][j]$ do $tab[k][l]$ dla każdych wartości $i, j, k, l$.
Macierzy $tab$ nie musimy nawet zerować, gdyż zawsze przypisujemy jej wartość bazując na $s$ i poprzednim wierszu. Tylko pierwszy (o indeksie 0) wiersz i kolumna muszą być wyzerowane. Musimy zerować też $s$.
Jeśli ktoś będzie miał trudności ze zrozumieniem poniższego kodu, można zastąpić $tab[h][w]$ - $tab[i][j][h][w]$. Możemy zrezygnować z tych 2 wymiarów, gdyż jeśli dla jakiejś pary $(i;j)$ obliczymy już wszystkie podmacierze, to nie będziemy potrzebować tych wartości, a zerować macierzy nie trzeba (jak wspomniałem wyżej), więc możemy zastępować je nowymi parami. Można się posunąć jeszcze dalej w optymalizacji złożoności pamięciowej i przechowywać tylko 2 ostatnie wiersze macierzy $tab$ i korzystać z nich na przemian.

Złożoność: $O(n^4)$, gdyż dla każdej z $n^2$ komórek przechodzimy przez prostokąty i wszystkich możliwych szerokościach i długościach, których jest pesymistycznie $n^2$. Warto też zauważyć, że złożoność pamięciowa wynosi tutaj $O(n^2)$, w przeciwieństwie, to $O(1)$ w powyższym algorytmie.

Algorytm:
int main()
{
    int t, n, i, j, w, h, s, max = -1000000, in[n][n], tab[n+1][n+1];

    get(n);
    for(i = 0; i < n; ++i)
        for(j = 0; j < n; ++j)
            get(in[i][j]);
    for(i = 0; i < n; ++i)
        for(j = 0; j < n; ++j)
            for(h = 1; h <= n - i; ++h) // h to długość podmacierzy
                for(s = 0, w = 1; w <= n - j; ++w) // w to szerokość podmacierzy
                {
                    s += in[i+h-1][j+w-1];
                    tab[h][w] = s+tab[h-1][w];
                    if(tab[h][w] > max)
                        max = tab[h][w];
                }
    print(max);
}

Rozwiązanie najsprytniejsze:
Możemy skorzystać z sum prefiksowych. Obliczamy dla każdej pary $(i;j) \wedge i \le j$ obliczamy dla każdej kolumny sumę elementów od wiersza $i$ do wiersza $j$. Takich par będzie nawiasem mówiąc maksymalnie $\frac{n(n+1)}{2}$.
Np. dla $n=3$ i macierzy:
11 -2  3
-7  0 -6
 2  8  3
Będziemy mieli takie sumy:
$$S(0,0) = [11, -2, 3] \\ S(0,1) = [4, -2, -3] \\ S(0,2) = [6, 6, 0] \\ S(1,1) = [-7, 0, -6] \\ S(1,2) = [-5, 8, -3] \\ S(2,2) = [2, 8, 3]$$
Co z nimi teraz? Zauważmy, że jeśli weźmiemy sobie kolumny $k$ i $l$ i zliczymy wszystkie liczby od $k$ do $l$ w $S(i,j)$ to będziemy mieli sumę w prostokącie ograniczonym z góry przez $i$, z dołu przez $j$, z lewej przez $k$ i z prawej przez $l$. No tak, ale to byśmy musieli sprawdzać wszystkie pary $(k;l) \wedge k \le l$ dla każdego z $\frac{n(n+1)}{2}$ par $(i;j)$, ergo znów będziemy mieli złożoność $O(n^4)$. Tak, ale nie musimy sprawdzać każdej pary $(k;l)$. Możemy na każdym wierszu pary $(i;j)$ zapuścić omawiany tutaj algorytm liniowy dla wariantu jednowymiarowego. Czemu? Ano temu, że te wiersze już są sumami po wierszach macierzy wejściowej. Pozostanie nam tylko znaleźć dla każdego optymalne $k$ i $l$, a to możemy zrobić w czasie liniowym. Tak naprawdę, to nawet nas nie obchodzą $k$ i $l$, bo od razu dostaniemy optymalną sumę między nimi.

Złożoność: $O(n^3)$, gdyż tworzymy $\frac{n(n+1)}{2}$ wierszy liniowo (każdy element wiersza w czasie stałym za pomocą sum prefiksowych, czyli bazując na wartościach poprzednich elementów), a następnie dla każdego wiersza również liniowo znajdujemy największą sumę w nim. Jako że każdy wiersz jest reprezentantem ciągu wierszy macierzy z wejścia, to przeszukamy wszystkie możliwe podmacierze.
Złożoność pamięciowa wzrosła tu nam do $O(n^3)$.

Ostateczny algorytm:
int main()
{
    int n, i, j, k, s, s1, s2 = -1000000, tab[n+1][n+1][n+1], in[n+1][n+1];

    get(n);
    for(i = 1; i <= n; ++i)
        for(j = 1; j <= n; ++j)
            get(in[i][j]);
    for(i = 1; i <= n; ++i) // dla każdego i
    {
        for(j = i; j <= n; ++j) // i każdego j
            for(k = 1; k <= n; ++k)
                tab[i][j][k] = tab[i][j-1][k]+in[i][k]; // tworzymy sumę bazując na poprzednim wierszu
        for(j = i-1; j > 0; --j) // musimy uwzględnić j również w poprzednich sumach
            for(k = 1; k <= n; ++k)
                tab[j][i][k] = tab[j][i-1][k]+in[i][k];
    }
    for(i = 1; i <= n; ++i) // dla każdej pary (i;j)
        for(j = i; j <= n; ++j)
        {
            s = 0; s1 = -1000000; // szukamy w wierszu S(i,j) największej sumy
            for(k = 1; k <= n; ++k)
            {
                if(s > 0)
                    s += tab[i][j][k];
                else
                    s = tab[i][j][k];
                if(s > s1)
                    s1 = s;
            }
            if(s1 > s2)
                s2 = s1;
        }
    print(s2);
}

Optymalizacja:
Możemy zmniejszyć złożoność pamięciową do kwadratowej oraz przyspieszyć program czy zachowaniu takiej samej złożoności czasowej, lepiej wykorzystując sumy prefiksowe.
Mianowicie, jeśli mamy sobie wiersz: $S(4,5)$ to wystarczy, że obliczymy tylko (dla $n=6$) wiersze: $S(0,0), S(0,1), S(0,2), S(0,3), S(0,4), S(0,5)$. Wiersz $S(4,5)$ otrzymamy odejmując $S(0,3)$ od $S(0,5)$. Musimy więc obliczyć jedynie $n$ wierszy $S(0;i)$ w czasie liniowym. Pierwszy etap zajmie nam zatem jedynie $O(n^2)$ czasu zamiast $O(n^3)$. Zaoszczędzimy również na pamięci, gdyż ilość wierszy zmniejszyła nam się z kwadratowej na liniową, a musieliśmy wszystkie pamiętać.
W głównej części programu korzystamy po prostu zamiast $tab[i][j][k]$ z $tab[j][k]-tab[i-1][k]$ i tyle.

Możemy też optymalizować przypadki szczególne, jak macierz złożona z samych liczb ujemnych, z samych liczb nieujemnych lub z jedną liczbą nieujemną. Złożoność dla każdego z tych przypadków to $O(n^2)$, bo musimy tylko wczytać dane. Nie będę jednak ich uwzględniał w poniższym kodzie.

Ostateczny, ulepszony algorytm:
int main()
{
    int n, i, j, k, s, s1, s2 = -1000000, tab[n+1][n+1];

    get(n);
    for(i = 1; i <= n; ++i)
        for(j = 1; j <= n; ++j)
        {
            get(k); // możemy na bieżąco obliczać sumy prefiksowe; nie musimy zapamiętywać macierzy z wejścia, gdyż nie będzie nam potem już do niczego potrzebna
            tab[i][j] = k+tab[i-1][j];
        }
    for(i = 1; i <= n; ++i)
        for(j = i; j <= n; ++j)
        {
            s = 0; s1 = -1000000;
            for(k = 1; k <= n; ++k)
            {
                if(s > 0)
                    s += tab[j][k]-tab[i-1][k];
                else
                    s = tab[j][k]-tab[i-1][k];
                if(s > s1)
                    s1 = s;
            }
            if(s1 > s2)
                s2 = s1;
        }
    print(s2);
}

Brak komentarzy:

Prześlij komentarz