You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

166 lines
4.3 KiB

package main
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/json"
	"log"
	"math/big"
	"net/http"
	"net/url"
	"strings"
)
const (
	// base62Chars is the short-code alphabet: ASCII digits, then lowercase
	// letters, then uppercase letters — 62 symbols in total.
	base62Chars = "0123456789" +
		"abcdefghijklmnopqrstuvwxyz" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ"

	// codeLength is the number of characters in every generated short code.
	codeLength = 7

	// maxRetries is how many freshly generated codes HandleCreate tries
	// before giving up on a request.
	maxRetries = 3
)
// Handler bundles the dependencies shared by the URL-shortener HTTP handlers.
type Handler struct {
	// store holds short codes and the URLs they point to; the handlers call
	// its Create, FindByCode and IncrementVisit methods.
	store *Store

	// baseURL, when non-empty, is the origin used to build short links.
	// When empty, the origin is derived from each incoming request.
	baseURL string

	// adminToken is the bearer token admin-only routes require. When empty,
	// those routes reject every request.
	adminToken string
}
// NewHandler returns a Handler that uses store for persistence.
//
// baseURL may be empty, in which case short links take their origin from the
// incoming request. An empty adminToken leaves every admin-protected route
// rejecting all requests.
func NewHandler(store *Store, baseURL, adminToken string) *Handler {
	h := new(Handler)
	h.store = store
	h.baseURL = baseURL
	h.adminToken = adminToken
	return h
}
// generateCode produces a 7-character base62 code using crypto/rand.
func generateCode() (string, error) {
buf := make([]byte, codeLength)
for i := range buf {
n, err := rand.Int(rand.Reader, big.NewInt(62))
if err != nil {
return "", err
}
buf[i] = base62Chars[n.Int64()]
}
return string(buf), nil
}
// requireAdmin wraps next with Bearer-token authentication.
//
// next runs only when the request carries "Authorization: Bearer <adminToken>".
// The scheme name is matched case-insensitively. Any other request, including
// one that sends the bare token without the "Bearer " scheme, receives a 401
// JSON error with a WWW-Authenticate challenge, and next is not called.
// If adminToken is empty, every request is rejected (fail-secure).
func (h *Handler) requireAdmin(next http.HandlerFunc) http.HandlerFunc {
	const scheme = "Bearer "
	return func(w http.ResponseWriter, r *http.Request) {
		auth := r.Header.Get("Authorization")
		var token string
		if len(auth) > len(scheme) && strings.EqualFold(auth[:len(scheme)], scheme) {
			token = auth[len(scheme):]
		}

		// The explicit empty-adminToken check is required: ConstantTimeCompare
		// reports two empty inputs as equal. The comparison runs in constant
		// time, so response timing does not reveal how much of a guessed
		// token was correct.
		if h.adminToken == "" ||
			subtle.ConstantTimeCompare([]byte(token), []byte(h.adminToken)) != 1 {
			w.Header().Set("WWW-Authenticate", "Bearer")
			jsonError(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		next(w, r)
	}
}
// --- helpers ---
// jsonError writes {"error": msg} as a JSON response with the given HTTP
// status. It sets headers and the status line, so it must be called before
// anything else has been written to w.
func jsonError(w http.ResponseWriter, msg string, status int) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	// Like http.Error: stop browsers from MIME-sniffing the error body.
	hdr.Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(status)
	// The status is already sent, so a failed write (typically a client that
	// went away) cannot be reported to the client. Log it rather than
	// dropping it silently.
	if err := json.NewEncoder(w).Encode(map[string]string{"error": msg}); err != nil {
		log.Printf("write json error response: %v", err)
	}
}
// --- handlers ---
// HandleCreate handles POST /api/shorten.
//
// The request body must be JSON of the form {"url": "<absolute URL>"}. After
// surrounding whitespace is trimmed, the URL must use the http or https scheme
// and name a host. On success the URL is stored under a newly generated code,
// and the response is a JSON object with "code", "short_url" and
// "original_url". Invalid input yields 400; failure to generate a code or to
// store the link yields 500.
func (h *Handler) HandleCreate(w http.ResponseWriter, r *http.Request) {
	// The body only has to carry one URL. Cap it so a client cannot make the
	// decoder buffer an arbitrarily large payload.
	const maxBodyBytes = 64 << 10
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	var req struct {
		URL string `json:"url"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		jsonError(w, "invalid request body", http.StatusBadRequest)
		return
	}
	req.URL = strings.TrimSpace(req.URL)
	if req.URL == "" {
		jsonError(w, "url is required", http.StatusBadRequest)
		return
	}

	// Parse instead of prefix-matching, so malformed input and host-less URLs
	// such as "http://" are rejected. url.Parse lowercases the scheme, so an
	// upper-case "HTTPS://" is accepted as well.
	target, err := url.Parse(req.URL)
	if err != nil || (target.Scheme != "http" && target.Scheme != "https") {
		jsonError(w, "url must start with http:// or https://", http.StatusBadRequest)
		return
	}
	if target.Host == "" {
		jsonError(w, "url must include a host", http.StatusBadRequest)
		return
	}

	// Every Create failure is treated as a possible code collision and retried
	// with a fresh code. NOTE(review): Store errors are not distinguished, so a
	// persistent storage failure is also retried maxRetries times.
	var code string
	for attempt := 1; attempt <= maxRetries; attempt++ {
		code, err = generateCode()
		if err != nil {
			log.Printf("generate code error: %v", err)
			jsonError(w, "failed to generate short code", http.StatusInternalServerError)
			return
		}
		if err = h.store.Create(code, req.URL); err == nil {
			break
		}
		log.Printf("collision on code %q (attempt %d): %v", code, attempt, err)
	}
	if err != nil {
		log.Printf("store create error after %d attempts: %v", maxRetries, err)
		jsonError(w, "failed to create short link", http.StatusInternalServerError)
		return
	}

	resp := map[string]string{
		"code":         code,
		"short_url":    h.resolveBaseURL(r) + "/" + code,
		"original_url": req.URL,
	}
	w.Header().Set("Content-Type", "application/json")
	// Encoding writes the 200 status and body. If that write fails (typically
	// a disconnected client), nothing more can be sent, so only log it.
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("encode create response: %v", err)
	}
}
// HandleRedirect serves GET /{code}. It looks the code up in the store and
// answers with a 302 redirect to the stored original URL.
//
// Error responses: 400 when the "code" path value is empty; 404 for
// "api/"-prefixed codes and for codes the store does not know; 500 when the
// store lookup itself fails.
func (h *Handler) HandleRedirect(w http.ResponseWriter, r *http.Request) {
	code := r.PathValue("code")
	switch {
	case code == "":
		jsonError(w, "missing short code", http.StatusBadRequest)
		return
	case strings.HasPrefix(code, "api/"):
		// Defensive: never treat an API-looking path that reached this
		// handler as a short code.
		jsonError(w, "not found", http.StatusNotFound)
		return
	}

	link, err := h.store.FindByCode(code)
	switch {
	case err != nil:
		log.Printf("store find error: %v", err)
		jsonError(w, "internal error", http.StatusInternalServerError)
		return
	case link == nil:
		jsonError(w, "short link not found", http.StatusNotFound)
		return
	}

	// Record the visit in the background so the redirect is not delayed by
	// the write; failures are only logged. NOTE(review): this goroutine is not
	// tracked anywhere, so a count still in flight at shutdown may be lost.
	go func() {
		if incErr := h.store.IncrementVisit(link.Code); incErr != nil {
			log.Printf("increment visit error: %v", incErr)
		}
	}()

	// 302 rather than 301: browsers don't cache it permanently, so the link's
	// target can still be changed later.
	http.Redirect(w, r, link.OriginalURL, http.StatusFound)
}
// resolveBaseURL returns the origin used to build short links.
//
// If a baseURL is configured, it is returned with any trailing slashes
// removed. Otherwise the origin is "http://<Host>" or "https://<Host>", taken
// from the request's Host, with https chosen only when the request arrived
// over TLS.
func (h *Handler) resolveBaseURL(r *http.Request) string {
	if configured := h.baseURL; configured != "" {
		return strings.TrimRight(configured, "/")
	}
	// NOTE(review): behind a TLS-terminating proxy r.TLS is presumably nil,
	// which yields "http"; configure baseURL in such deployments.
	if r.TLS == nil {
		return "http://" + r.Host
	}
	return "https://" + r.Host
}