package main import ( "encoding/json" "net/http" "net/http/httptest" "path/filepath" "strings" "testing" ) func newTestHandler(t *testing.T) *Handler { t.Helper() store, err := NewStore(filepath.Join(t.TempDir(), "test.db")) if err != nil { t.Fatalf("NewStore() failed: %v", err) } t.Cleanup(func() { store.Close() }) return NewHandler(store, "http://xieyaxin.top:8899", "test-token") } func decodeBody(t *testing.T, body string) map[string]string { t.Helper() var resp map[string]string if err := json.NewDecoder(strings.NewReader(body)).Decode(&resp); err != nil { t.Fatalf("failed to decode response: %v", err) } return resp } func TestHandleCreate_Valid(t *testing.T) { h := newTestHandler(t) body := `{"url":"https://example.com"}` w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(body)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) if w.Code != http.StatusOK { t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) } resp := decodeBody(t, w.Body.String()) if len(resp["code"]) != 7 { t.Fatalf("expected code length 7, got %d: %q", len(resp["code"]), resp["code"]) } if resp["short_url"] != "http://xieyaxin.top:8899/"+resp["code"] { t.Fatalf("expected short_url %q, got %q", "http://xieyaxin.top:8899/"+resp["code"], resp["short_url"]) } if resp["original_url"] != "https://example.com" { t.Fatalf("expected original_url %q, got %q", "https://example.com", resp["original_url"]) } } func TestHandleCreate_Consecutive(t *testing.T) { h := newTestHandler(t) body1 := `{"url":"https://example.com/first"}` w1 := httptest.NewRecorder() r1 := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(body1)) r1.Header.Set("Content-Type", "application/json") h.HandleCreate(w1, r1) body2 := `{"url":"https://example.com/second"}` w2 := httptest.NewRecorder() r2 := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(body2)) r2.Header.Set("Content-Type", "application/json") h.HandleCreate(w2, r2) resp1 := decodeBody(t, w1.Body.String()) resp2 := decodeBody(t, w2.Body.String()) if len(resp1["code"]) != 7 { t.Fatalf("expected code length 7, got %d: %q", len(resp1["code"]), resp1["code"]) } if len(resp2["code"]) != 7 { t.Fatalf("expected code length 7, got %d: %q", len(resp2["code"]), resp2["code"]) } if resp1["code"] == resp2["code"] { t.Fatal("consecutive short links should have different codes") } } func TestHandleCreate_MissingURL(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{}`)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400, got %d: %s", w.Code, w.Body.String()) } resp := decodeBody(t, w.Body.String()) if resp["error"] == "" { t.Fatal("expected error message in response") } } func TestHandleCreate_EmptyURL(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":""}`)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400, got %d", w.Code) } } func TestHandleCreate_InvalidScheme(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":"ftp://example.com"}`)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400 for unsupported scheme, got %d", w.Code) } } func TestHandleCreate_WrongMethod(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/api/shorten", nil) h.HandleCreate(w, r) // HandleCreate itself does not check method (mux handles routing) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400 for GET request (no body), got %d", w.Code) } } func TestHandleCreate_InvalidBody(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`not json`)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400 for invalid JSON, got %d", w.Code) } } func TestHandleRedirect_Valid(t *testing.T) { h := newTestHandler(t) // Create a short link first body := `{"url":"https://example.com"}` w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(body)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) resp := decodeBody(t, w.Body.String()) code := resp["code"] // Test redirect w = httptest.NewRecorder() r = httptest.NewRequest(http.MethodGet, "/"+code, nil) r.SetPathValue("code", code) h.HandleRedirect(w, r) if w.Code != http.StatusFound { t.Fatalf("expected status 302 for valid code, got %d", w.Code) } if loc := w.Header().Get("Location"); loc != "https://example.com" { t.Fatalf("expected Location header %q, got %q", "https://example.com", loc) } } func TestHandleRedirect_NotFound(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/nonexist", nil) r.SetPathValue("code", "nonexist") h.HandleRedirect(w, r) if w.Code != http.StatusNotFound { t.Fatalf("expected status 404 for non-existent code, got %d", w.Code) } } func TestHandleRedirect_EmptyCode(t *testing.T) { h := newTestHandler(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) r.SetPathValue("code", "") h.HandleRedirect(w, r) if w.Code != http.StatusBadRequest { t.Fatalf("expected status 400 for empty code, got %d", w.Code) } } func TestStore_IncrementVisit(t *testing.T) { h := newTestHandler(t) // Create a short link w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":"https://example.com"}`)) r.Header.Set("Content-Type", "application/json") h.HandleCreate(w, r) resp := decodeBody(t, w.Body.String()) code := resp["code"] rec, err := h.store.FindByCode(code) if err != nil || rec == nil { t.Fatal("record not found after create") } if err := h.store.IncrementVisit(code); err != nil { t.Fatalf("IncrementVisit() failed: %v", err) } if err := h.store.IncrementVisit(code); err != nil { t.Fatalf("IncrementVisit() failed: %v", err) } rec, err = h.store.FindByCode(code) if err != nil || rec == nil { t.Fatal("record not found after increment") } if rec.VisitCount != 2 { t.Fatalf("expected VisitCount 2, got %d", rec.VisitCount) } }