package main import ( "context" "fmt" "io" "log" "net/http" "net/url" "os" "path/filepath" "strings" got "github.com/joho/godotenv" "bytes" "encoding/base64" "encoding/json" "unicode/utf8" "github.com/gin-gonic/gin" "github.com/jackc/pgx/v5" "github.com/ledongthuc/pdf" "github.com/ollama/ollama/api" "github.com/pgvector/pgvector-go" ) // @title RAGTAG API // @version 1.0 // @description This is the API for the RAGTAG system. // @termsOfService http://swagger.io/terms/ // @contact.name James Campbell // @contact.email james@enigmalabs.com // @license.name Apache 2.0 // @license.url http://www.apache.org/licenses/LICENSE-2.0.html // @host localhost:8080 // @BasePath / type Session struct { Messages []api.Message TitleFilter string } var sessions = make(map[string]*Session) func generateEmbedding(input string) ([]float32, error) { ollamaHost := os.Getenv("OLLAMA_HOST") if ollamaHost == "" { ollamaHost = "localhost" // fallback to localhost if not set } ollamaURL, err := url.Parse(fmt.Sprintf("http://%s:11434", ollamaHost)) if err != nil { return nil, err } client := api.NewClient(ollamaURL, http.DefaultClient) // Create an embedding request req := &api.EmbedRequest{ Model: "llama3.1", // Ensure this is an embedding-capable model Input: input, } // Call the Embed function resp, err := client.Embed(context.Background(), req) if err != nil { return nil, err } return resp.Embeddings[0], nil } func insertItem(conn *pgx.Conn, title string, docText string, embedding []float32) error { // Combine title and docText for embedding combinedText := title + " " + docText _, err := conn.Exec(context.Background(), "INSERT INTO items (title, doc, embedding) VALUES ($1, $2, $3)", title, combinedText, pgvector.NewVector(embedding)) return err } func queryEmbeddings(conn *pgx.Conn, query string, session *Session, c *gin.Context) error { // Generate embedding for the query queryEmbedding, err := generateEmbedding(query) if err != nil { return err } // Prepare the SQL query sqlQuery := "SELECT doc, COALESCE(title, 'Untitled') FROM items" if session.TitleFilter != "" { sqlQuery += fmt.Sprintf(" WHERE title LIKE '%%%s%%'", session.TitleFilter) } sqlQuery += " ORDER BY embedding <-> $1 LIMIT 5" // Query the database for similar documents rows, err := conn.Query(context.Background(), sqlQuery, pgvector.NewVector(queryEmbedding)) if err != nil { return err } defer rows.Close() var docs []string var sources []string for rows.Next() { var doc, title string if err := rows.Scan(&doc, &title); err != nil { return err } docs = append(docs, doc) sources = append(sources, fmt.Sprintf("Source: %s", title)) } // Combine the retrieved documents contextText := strings.Join(docs, "\n\n") // Create a chat request ollamaHost := os.Getenv("OLLAMA_HOST") if ollamaHost == "" { ollamaHost = "localhost" // fallback to localhost if not set } ollamaURL, err := url.Parse(fmt.Sprintf("http://%s:11434", ollamaHost)) if err != nil { return err } client := api.NewClient(ollamaURL, http.DefaultClient) // Add the new query to the session session.Messages = append(session.Messages, api.Message{Role: "user", Content: query}) // Prepare the messages for the chat request messages := []api.Message{ {Role: "system", Content: "You are an assistant that answers questions based on the given context."}, {Role: "user", Content: "Here's the context:\n" + contextText}, } messages = append(messages, session.Messages...) req := &api.ChatRequest{ Model: "llama3.1", Messages: messages, Stream: new(bool), // Use new(bool) to create a pointer to a boolean } *req.Stream = true // Set the value to true // Call the Chat function with streaming err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error { // Send the raw content without any modifications if resp.Message.Content != "" { c.SSEvent("message", resp.Message.Content) c.Writer.Flush() // Ensure the content is sent immediately } return nil }) if err != nil { return err } // Add the AI response to the session session.Messages = append(session.Messages, api.Message{Role: "assistant", Content: "Response sent via streaming"}) return nil } func getDocuments(conn *pgx.Conn) ([]map[string]interface{}, error) { rows, err := conn.Query(context.Background(), "SELECT DISTINCT ON (SPLIT_PART(title, '_chunk_', 1)) SPLIT_PART(title, '_chunk_', 1) as title, COUNT(*) as count FROM items GROUP BY SPLIT_PART(title, '_chunk_', 1)") if err != nil { return nil, err } defer rows.Close() var documents []map[string]interface{} for rows.Next() { var title string var count int if err := rows.Scan(&title, &count); err != nil { return nil, err } documents = append(documents, map[string]interface{}{ "title": title, "count": count, }) } return documents, nil } func deleteDocument(conn *pgx.Conn, title string) error { _, err := conn.Exec(context.Background(), "DELETE FROM items WHERE title LIKE $1 || '%'", title) return err } func uploadDocument(c *gin.Context, conn *pgx.Conn) { title := c.PostForm("title") file, header, err := c.Request.FormFile("file") if err != nil { log.Printf("Error getting file: %v", err) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } defer file.Close() // Create uploads directory if it doesn't exist uploadsDir := "uploads" if err := os.MkdirAll(uploadsDir, 0755); err != nil { log.Printf("Error creating uploads directory: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create uploads directory"}) return } filename := filepath.Join(uploadsDir, header.Filename) out, err := os.Create(filename) if err != nil { log.Printf("Error creating file: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } defer out.Close() _, err = io.Copy(out, file) if err != nil { log.Printf("Error copying file: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // Debug: log file size and first 16 bytes stat, statErr := os.Stat(filename) if statErr == nil { log.Printf("Uploaded file size: %d bytes", stat.Size()) fcheck, ferr := os.Open(filename) if ferr == nil { buf := make([]byte, 16) n, _ := fcheck.Read(buf) log.Printf("First 16 bytes: % x", buf[:n]) fcheck.Close() } } var textContent string ext := strings.ToLower(filepath.Ext(filename)) if ext == ".jpg" || ext == ".jpeg" || ext == ".png" { // Generate image summary using the llava model summary, err := generateImageSummary(filename) if err != nil { log.Printf("Error generating image summary: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } textContent = summary } else if ext == ".pdf" { // Check PDF signature before parsing f, rErr := os.Open(filename) if rErr != nil { log.Printf("Error opening PDF: %v", rErr) c.JSON(http.StatusInternalServerError, gin.H{"error": rErr.Error()}) return } defer f.Close() buf := make([]byte, 5) _, err := f.Read(buf) if err != nil || string(buf) != "%PDF-" { log.Printf("Uploaded file is not a valid PDF (missing %PDF- header): %s", filename) c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF (missing %PDF- header)"}) return } stat, _ := f.Stat() // Loosen EOF check: search last 1KB for %%EOF eofCheckSize := int64(1024) if stat.Size() < eofCheckSize { eofCheckSize = stat.Size() } endBuf := make([]byte, eofCheckSize) _, err = f.ReadAt(endBuf, stat.Size()-eofCheckSize) if err != nil || !strings.Contains(string(endBuf), "%%EOF") { log.Printf("Uploaded file is not a valid PDF (missing %%EOF): %s", filename) c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF (missing %%EOF)"}) return } // Reset file pointer for pdf.NewReader f.Seek(0, 0) reader, pdfErr := pdf.NewReader(f, stat.Size()) if pdfErr != nil { log.Printf("Error reading PDF: %v", pdfErr) c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF or is corrupted"}) return } var sb strings.Builder for i := 1; i <= reader.NumPage(); i++ { page := reader.Page(i) if page.V.IsNull() { continue } content, err := page.GetPlainText(nil) if err != nil { log.Printf("Error extracting text from page %d: %v", i, err) continue } sb.WriteString(content) } textContent = sb.String() log.Printf("Extracted text length: %d", len(textContent)) } else { content, err := os.ReadFile(filename) if err != nil { log.Printf("Error reading file: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } textContent = string(content) } // Remove null bytes (Postgres TEXT cannot contain 0x00) textContent = strings.ReplaceAll(textContent, "\x00", "") // Validate UTF-8 if !utf8.ValidString(textContent) { log.Printf("Invalid UTF-8 detected in document: %s", filename) c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded document is not valid UTF-8"}) return } // Generate embedding for the text content using llama3.1 embedding, err := generateEmbedding(textContent) if err != nil { log.Printf("Error generating embedding: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // Insert the document into the database err = insertItem(conn, title, textContent, embedding) if err != nil { log.Printf("Error inserting item: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "Document uploaded and processed successfully"}) } func chunkText(text string, chunkSize int) []string { words := strings.Fields(text) var chunks []string for i := 0; i < len(words); i += chunkSize { end := i + chunkSize if end > len(words) { end = len(words) } chunks = append(chunks, strings.Join(words[i:end], " ")) } return chunks } func generateImageSummary(imagePath string) (string, error) { imageData, err := os.ReadFile(imagePath) if err != nil { return "", fmt.Errorf("failed to read image file: %w", err) } base64Image := base64.StdEncoding.EncodeToString(imageData) payload := map[string]interface{}{ "model": "llava", "prompt": "Describe this image in detail:", "images": []string{base64Image}, "stream": true, } jsonPayload, err := json.Marshal(payload) if err != nil { return "", fmt.Errorf("failed to marshal JSON payload: %w", err) } ollamaHost := os.Getenv("OLLAMA_HOST") if ollamaHost == "" { ollamaHost = "localhost" } url := fmt.Sprintf("http://%s:11434/api/generate", ollamaHost) resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonPayload)) if err != nil { return "", fmt.Errorf("failed to send POST request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("unexpected response status: %d, body: %s", resp.StatusCode, string(body)) } var summary strings.Builder decoder := json.NewDecoder(resp.Body) for { var result struct { Response string `json:"response"` Done bool `json:"done"` } if err := decoder.Decode(&result); err != nil { if err == io.EOF { break } return "", fmt.Errorf("failed to decode JSON response: %w", err) } summary.WriteString(result.Response) if result.Done { break } } if summary.Len() == 0 { return "", fmt.Errorf("empty response from llava model") } fmt.Println("The summary of the image is: ", summary.String()) return summary.String(), nil } func main() { // Set up the database connection // load env variables got.Load() conn, err := pgx.Connect(context.Background(), os.Getenv("DB_URL")) if err != nil { log.Fatal("Unable to connect to database:", err) } defer conn.Close(context.Background()) // Set up the Gin router r := gin.Default() // Define the /add_document endpoint r.POST("/add_document", func(c *gin.Context) { var request struct { Title string `json:"title"` DocText string `json:"doc_text"` } if err := c.BindJSON(&request); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // Generate the embedding embedding, err := generateEmbedding(request.DocText) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // Insert the document and its embedding into the items table err = insertItem(conn, request.Title, request.DocText, embedding) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "Document chunk embedded and stored successfully!"}) }) // Add the new /query endpoint r.POST("/query", func(c *gin.Context) { var request struct { Query string `json:"query"` SessionID string `json:"sessionId"` } if err := c.BindJSON(&request); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } session, ok := sessions[request.SessionID] if !ok { session = &Session{ Messages: []api.Message{ {Role: "system", Content: "You are an assistant that answers questions based on the given context."}, }, TitleFilter: "", } sessions[request.SessionID] = session } // Check for @title in the query if strings.Contains(request.Query, "@") { parts := strings.Split(request.Query, "@") if len(parts) > 1 { session.TitleFilter = strings.Split(parts[1], " ")[0] request.Query = strings.Replace(request.Query, "@"+session.TitleFilter, "", 1) } } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Allow-Headers", "Content-Type") c.Header("Access-Control-Allow-Methods", "POST") c.Header("encoding", "chunked") err := queryEmbeddings(conn, request.Query, session, c) if err != nil { c.SSEvent("error", err.Error()) } c.SSEvent("done", "") }) // Serve the index.html file r.GET("/", func(c *gin.Context) { c.File("index.html") }) // Serve the docmanager.html file r.GET("/docmanager", func(c *gin.Context) { c.File("docmanager.html") }) // Add a new endpoint to fetch documents r.GET("/documents", func(c *gin.Context) { documents, err := getDocuments(conn) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, documents) }) // Add a new endpoint to delete documents r.POST("/delete_document", func(c *gin.Context) { var request struct { Title string `json:"title"` } if err := c.BindJSON(&request); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } err := deleteDocument(conn, request.Title) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) }) // Add a new endpoint to upload documents r.POST("/upload_document", func(c *gin.Context) { uploadDocument(c, conn) }) // Add a new endpoint to clear the chat session r.POST("/clear_session", func(c *gin.Context) { var request struct { SessionID string `json:"sessionId"` } if err := c.BindJSON(&request); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } delete(sessions, request.SessionID) c.JSON(http.StatusOK, gin.H{"message": "Chat session cleared successfully"}) }) // Add a new endpoint to check if Twitter data exists r.GET("/check_data", func(c *gin.Context) { rows, err := conn.Query(context.Background(), "SELECT DISTINCT title FROM items WHERE title LIKE '%twitter%'") if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } defer rows.Close() var titles []string for rows.Next() { var title string if err := rows.Scan(&title); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } titles = append(titles, title) } c.JSON(http.StatusOK, gin.H{"twitter_titles": titles}) }) // Serve the describer.html file r.GET("/describer", func(c *gin.Context) { c.File("describer.html") }) // Handle image description r.POST("/describe_image", func(c *gin.Context) { file, _, err := c.Request.FormFile("file") if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } defer file.Close() // Create a temporary file to store the uploaded image tempFile, err := os.CreateTemp("", "uploaded-*.jpg") if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } defer os.Remove(tempFile.Name()) defer tempFile.Close() // Copy the uploaded file to the temporary file _, err = io.Copy(tempFile, file) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Allow-Headers", "Content-Type") c.Header("Access-Control-Allow-Methods", "POST") c.Header("encoding", "chunked") imageData, err := os.ReadFile(tempFile.Name()) if err != nil { c.SSEvent("error", err.Error()) return } base64Image := base64.StdEncoding.EncodeToString(imageData) payload := map[string]interface{}{ "model": "llava", "prompt": "Describe this image in detail:", "images": []string{base64Image}, "stream": true, } jsonPayload, err := json.Marshal(payload) if err != nil { c.SSEvent("error", err.Error()) return } ollamaHost := os.Getenv("OLLAMA_HOST") if ollamaHost == "" { ollamaHost = "localhost" } url := fmt.Sprintf("http://%s:11434/api/generate", ollamaHost) resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonPayload)) if err != nil { c.SSEvent("error", err.Error()) return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) c.SSEvent("error", fmt.Sprintf("Unexpected response status: %d, body: %s", resp.StatusCode, string(body))) return } decoder := json.NewDecoder(resp.Body) for { var result struct { Response string `json:"response"` Done bool `json:"done"` } if err := decoder.Decode(&result); err != nil { if err == io.EOF { break } c.SSEvent("error", err.Error()) return } if result.Response != "" { c.SSEvent("message", result.Response) c.Writer.Flush() // Ensure the content is sent immediately } if result.Done { break } } c.SSEvent("done", "") }) // Run the Gin server r.Run(":8080") }