You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

239 lines
6.7 KiB

package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
)
// newTestHandler returns a Handler wired to a fresh store whose database file
// ("test.db") lives in the test's own temporary directory. The store is closed
// through t.Cleanup when the test ends. The handler is configured with the
// base URL "http://xieyaxin.top:8899"; the third argument is presumably an
// auth token (confirm against NewHandler).
func newTestHandler(t *testing.T) *Handler {
	t.Helper()

	dbPath := filepath.Join(t.TempDir(), "test.db")
	db, err := NewStore(dbPath)
	if err != nil {
		t.Fatalf("NewStore() failed: %v", err)
	}
	t.Cleanup(func() { db.Close() })

	return NewHandler(db, "http://xieyaxin.top:8899", "test-token")
}
// decodeBody parses the first JSON value in body as an object with string
// values and returns it. Any decode error fails the test immediately.
func decodeBody(t *testing.T, body string) map[string]string {
	t.Helper()

	var fields map[string]string
	dec := json.NewDecoder(strings.NewReader(body))
	err := dec.Decode(&fields)
	if err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	return fields
}
// TestHandleCreate_Valid posts a well-formed URL to HandleCreate and expects
// a 200 response carrying a 7-character code, a short URL formed from the
// handler's base address plus that code, and the original URL echoed back.
func TestHandleCreate_Valid(t *testing.T) {
	const (
		target = "https://example.com"
		prefix = "http://xieyaxin.top:8899/"
	)
	h := newTestHandler(t)

	req := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":"`+target+`"}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	h.HandleCreate(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", rec.Code, rec.Body.String())
	}

	got := decodeBody(t, rec.Body.String())
	code := got["code"]
	if n := len(code); n != 7 {
		t.Fatalf("expected code length 7, got %d: %q", n, code)
	}
	if want := prefix + code; got["short_url"] != want {
		t.Fatalf("expected short_url %q, got %q", want, got["short_url"])
	}
	if got["original_url"] != target {
		t.Fatalf("expected original_url %q, got %q", target, got["original_url"])
	}
}
// TestHandleCreate_Consecutive shortens two different URLs back to back on the
// same handler and expects each response to carry a 7-character code, with the
// two codes differing from each other.
func TestHandleCreate_Consecutive(t *testing.T) {
	h := newTestHandler(t)

	// shorten posts rawURL to HandleCreate and returns the decoded response.
	// It stops the test on a non-200 status so that a failing create is
	// reported with the handler's own response body rather than surfacing
	// later as a confusing "code length 0" failure.
	shorten := func(rawURL string) map[string]string {
		t.Helper()
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":"`+rawURL+`"}`))
		r.Header.Set("Content-Type", "application/json")
		h.HandleCreate(w, r)
		if w.Code != http.StatusOK {
			t.Fatalf("create %q: expected status 200, got %d: %s", rawURL, w.Code, w.Body.String())
		}
		return decodeBody(t, w.Body.String())
	}

	resp1 := shorten("https://example.com/first")
	resp2 := shorten("https://example.com/second")

	if len(resp1["code"]) != 7 {
		t.Fatalf("expected code length 7, got %d: %q", len(resp1["code"]), resp1["code"])
	}
	if len(resp2["code"]) != 7 {
		t.Fatalf("expected code length 7, got %d: %q", len(resp2["code"]), resp2["code"])
	}
	if resp1["code"] == resp2["code"] {
		t.Fatal("consecutive short links should have different codes")
	}
}
// TestHandleCreate_MissingURL sends an empty JSON object (no "url" field) and
// expects a 400 response whose body contains a non-empty "error" field.
func TestHandleCreate_MissingURL(t *testing.T) {
	h := newTestHandler(t)

	req := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	h.HandleCreate(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d: %s", got, rec.Body.String())
	}
	if payload := decodeBody(t, rec.Body.String()); payload["error"] == "" {
		t.Fatal("expected error message in response")
	}
}
// TestHandleCreate_EmptyURL sends a request whose "url" field is the empty
// string and expects HandleCreate to answer with 400.
func TestHandleCreate_EmptyURL(t *testing.T) {
	h := newTestHandler(t)

	payload := strings.NewReader(`{"url":""}`)
	req := httptest.NewRequest(http.MethodPost, "/api/shorten", payload)
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	h.HandleCreate(rec, req)

	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", status)
	}
}
// TestHandleCreate_InvalidScheme submits an ftp:// URL and expects
// HandleCreate to reject it with 400.
func TestHandleCreate_InvalidScheme(t *testing.T) {
	h := newTestHandler(t)

	payload := strings.NewReader(`{"url":"ftp://example.com"}`)
	req := httptest.NewRequest(http.MethodPost, "/api/shorten", payload)
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	h.HandleCreate(rec, req)

	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected status 400 for unsupported scheme, got %d", status)
	}
}
// TestHandleCreate_WrongMethod calls HandleCreate directly with a GET request
// that has no body and expects a 400 response.
func TestHandleCreate_WrongMethod(t *testing.T) {
	h := newTestHandler(t)

	req := httptest.NewRequest(http.MethodGet, "/api/shorten", nil)
	rec := httptest.NewRecorder()
	h.HandleCreate(rec, req)

	// HandleCreate itself does not check the method (the mux handles routing),
	// so the expected rejection here comes from the missing body.
	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected status 400 for GET request (no body), got %d", status)
	}
}
// TestHandleCreate_InvalidBody posts a body that is not valid JSON and expects
// HandleCreate to answer with 400.
func TestHandleCreate_InvalidBody(t *testing.T) {
	h := newTestHandler(t)

	payload := strings.NewReader(`not json`)
	req := httptest.NewRequest(http.MethodPost, "/api/shorten", payload)
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	h.HandleCreate(rec, req)

	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected status 400 for invalid JSON, got %d", status)
	}
}
// TestHandleRedirect_Valid creates a short link through HandleCreate and then
// requests it through HandleRedirect, expecting a 302 response whose Location
// header is the original URL.
func TestHandleRedirect_Valid(t *testing.T) {
	const target = "https://example.com"
	h := newTestHandler(t)

	// Create a short link first. Verify the setup step on its own so that a
	// failing create is not misreported as a redirect failure.
	w := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":"`+target+`"}`))
	r.Header.Set("Content-Type", "application/json")
	h.HandleCreate(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("setup: expected status 200 from create, got %d: %s", w.Code, w.Body.String())
	}
	code := decodeBody(t, w.Body.String())["code"]
	if code == "" {
		t.Fatal("setup: create response did not include a code")
	}

	// Test redirect. The handler is invoked directly (no mux), so the "code"
	// path value has to be set by hand.
	w = httptest.NewRecorder()
	r = httptest.NewRequest(http.MethodGet, "/"+code, nil)
	r.SetPathValue("code", code)
	h.HandleRedirect(w, r)

	if w.Code != http.StatusFound {
		t.Fatalf("expected status 302 for valid code, got %d", w.Code)
	}
	if loc := w.Header().Get("Location"); loc != target {
		t.Fatalf("expected Location header %q, got %q", target, loc)
	}
}
// TestHandleRedirect_NotFound requests a code that was never created and
// expects HandleRedirect to answer with 404.
func TestHandleRedirect_NotFound(t *testing.T) {
	const missing = "nonexist"
	h := newTestHandler(t)

	req := httptest.NewRequest(http.MethodGet, "/"+missing, nil)
	req.SetPathValue("code", missing)

	rec := httptest.NewRecorder()
	h.HandleRedirect(rec, req)

	if status := rec.Code; status != http.StatusNotFound {
		t.Fatalf("expected status 404 for non-existent code, got %d", status)
	}
}
// TestHandleRedirect_EmptyCode calls HandleRedirect with an empty "code" path
// value and expects a 400 response.
func TestHandleRedirect_EmptyCode(t *testing.T) {
	h := newTestHandler(t)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.SetPathValue("code", "")

	rec := httptest.NewRecorder()
	h.HandleRedirect(rec, req)

	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected status 400 for empty code, got %d", status)
	}
}
// TestStore_IncrementVisit creates a short link through HandleCreate, calls
// the store's IncrementVisit twice for its code, and expects the record's
// VisitCount to read 2 afterwards.
func TestStore_IncrementVisit(t *testing.T) {
	h := newTestHandler(t)

	// Create a short link. Check the setup step explicitly so a failing
	// create is reported as such instead of as a missing record.
	w := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodPost, "/api/shorten", strings.NewReader(`{"url":"https://example.com"}`))
	r.Header.Set("Content-Type", "application/json")
	h.HandleCreate(w, r)
	if w.Code != http.StatusOK {
		t.Fatalf("setup: expected status 200 from create, got %d: %s", w.Code, w.Body.String())
	}
	code := decodeBody(t, w.Body.String())["code"]

	// Report a lookup error separately from a nil record so the underlying
	// store error is not hidden behind a generic "not found" message.
	rec, err := h.store.FindByCode(code)
	if err != nil {
		t.Fatalf("FindByCode(%q) after create failed: %v", code, err)
	}
	if rec == nil {
		t.Fatal("record not found after create")
	}

	for i := 0; i < 2; i++ {
		if err := h.store.IncrementVisit(code); err != nil {
			t.Fatalf("IncrementVisit() failed: %v", err)
		}
	}

	rec, err = h.store.FindByCode(code)
	if err != nil {
		t.Fatalf("FindByCode(%q) after increment failed: %v", code, err)
	}
	if rec == nil {
		t.Fatal("record not found after increment")
	}
	if rec.VisitCount != 2 {
		t.Fatalf("expected VisitCount 2, got %d", rec.VisitCount)
	}
}