Spaces:
Runtime error
Runtime error
package main | |
import ( | |
"context" | |
"fmt" | |
"io" | |
"log" | |
"net/http" | |
"net/url" | |
"os" | |
"path/filepath" | |
"strings" | |
got "github.com/joho/godotenv" | |
"bytes" | |
"encoding/base64" | |
"encoding/json" | |
"unicode/utf8" | |
"github.com/gin-gonic/gin" | |
"github.com/jackc/pgx/v5" | |
"github.com/ledongthuc/pdf" | |
"github.com/ollama/ollama/api" | |
"github.com/pgvector/pgvector-go" | |
) | |
// @title RAGTAG API | |
// @version 1.0 | |
// @description This is the API for the RAGTAG system. | |
// @termsOfService http://swagger.io/terms/ | |
// @contact.name James Campbell | |
// @contact.email [email protected] | |
// @license.name Apache 2.0 | |
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html | |
// @host localhost:8080 | |
// @BasePath / | |
type Session struct { | |
Messages []api.Message | |
TitleFilter string | |
} | |
var sessions = make(map[string]*Session) | |
func generateEmbedding(input string) ([]float32, error) { | |
ollamaHost := os.Getenv("OLLAMA_HOST") | |
if ollamaHost == "" { | |
ollamaHost = "localhost" // fallback to localhost if not set | |
} | |
ollamaURL, err := url.Parse(fmt.Sprintf("http://%s:11434", ollamaHost)) | |
if err != nil { | |
return nil, err | |
} | |
client := api.NewClient(ollamaURL, http.DefaultClient) | |
// Create an embedding request | |
req := &api.EmbedRequest{ | |
Model: "llama3.1", // Ensure this is an embedding-capable model | |
Input: input, | |
} | |
// Call the Embed function | |
resp, err := client.Embed(context.Background(), req) | |
if err != nil { | |
return nil, err | |
} | |
return resp.Embeddings[0], nil | |
} | |
func insertItem(conn *pgx.Conn, title string, docText string, embedding []float32) error { | |
// Combine title and docText for embedding | |
combinedText := title + " " + docText | |
_, err := conn.Exec(context.Background(), | |
"INSERT INTO items (title, doc, embedding) VALUES ($1, $2, $3)", | |
title, combinedText, pgvector.NewVector(embedding)) | |
return err | |
} | |
func queryEmbeddings(conn *pgx.Conn, query string, session *Session, c *gin.Context) error { | |
// Generate embedding for the query | |
queryEmbedding, err := generateEmbedding(query) | |
if err != nil { | |
return err | |
} | |
// Prepare the SQL query | |
sqlQuery := "SELECT doc, COALESCE(title, 'Untitled') FROM items" | |
if session.TitleFilter != "" { | |
sqlQuery += fmt.Sprintf(" WHERE title LIKE '%%%s%%'", session.TitleFilter) | |
} | |
sqlQuery += " ORDER BY embedding <-> $1 LIMIT 5" | |
// Query the database for similar documents | |
rows, err := conn.Query(context.Background(), sqlQuery, pgvector.NewVector(queryEmbedding)) | |
if err != nil { | |
return err | |
} | |
defer rows.Close() | |
var docs []string | |
var sources []string | |
for rows.Next() { | |
var doc, title string | |
if err := rows.Scan(&doc, &title); err != nil { | |
return err | |
} | |
docs = append(docs, doc) | |
sources = append(sources, fmt.Sprintf("Source: %s", title)) | |
} | |
// Combine the retrieved documents | |
contextText := strings.Join(docs, "\n\n") | |
// Create a chat request | |
ollamaHost := os.Getenv("OLLAMA_HOST") | |
if ollamaHost == "" { | |
ollamaHost = "localhost" // fallback to localhost if not set | |
} | |
ollamaURL, err := url.Parse(fmt.Sprintf("http://%s:11434", ollamaHost)) | |
if err != nil { | |
return err | |
} | |
client := api.NewClient(ollamaURL, http.DefaultClient) | |
// Add the new query to the session | |
session.Messages = append(session.Messages, api.Message{Role: "user", Content: query}) | |
// Prepare the messages for the chat request | |
messages := []api.Message{ | |
{Role: "system", Content: "You are an assistant that answers questions based on the given context."}, | |
{Role: "user", Content: "Here's the context:\n" + contextText}, | |
} | |
messages = append(messages, session.Messages...) | |
req := &api.ChatRequest{ | |
Model: "llama3.1", | |
Messages: messages, | |
Stream: new(bool), // Use new(bool) to create a pointer to a boolean | |
} | |
*req.Stream = true // Set the value to true | |
// Call the Chat function with streaming | |
err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error { | |
// Send the raw content without any modifications | |
if resp.Message.Content != "" { | |
c.SSEvent("message", resp.Message.Content) | |
c.Writer.Flush() // Ensure the content is sent immediately | |
} | |
return nil | |
}) | |
if err != nil { | |
return err | |
} | |
// Add the AI response to the session | |
session.Messages = append(session.Messages, api.Message{Role: "assistant", Content: "Response sent via streaming"}) | |
return nil | |
} | |
func getDocuments(conn *pgx.Conn) ([]map[string]interface{}, error) { | |
rows, err := conn.Query(context.Background(), "SELECT DISTINCT ON (SPLIT_PART(title, '_chunk_', 1)) SPLIT_PART(title, '_chunk_', 1) as title, COUNT(*) as count FROM items GROUP BY SPLIT_PART(title, '_chunk_', 1)") | |
if err != nil { | |
return nil, err | |
} | |
defer rows.Close() | |
var documents []map[string]interface{} | |
for rows.Next() { | |
var title string | |
var count int | |
if err := rows.Scan(&title, &count); err != nil { | |
return nil, err | |
} | |
documents = append(documents, map[string]interface{}{ | |
"title": title, | |
"count": count, | |
}) | |
} | |
return documents, nil | |
} | |
func deleteDocument(conn *pgx.Conn, title string) error { | |
_, err := conn.Exec(context.Background(), "DELETE FROM items WHERE title LIKE $1 || '%'", title) | |
return err | |
} | |
func uploadDocument(c *gin.Context, conn *pgx.Conn) { | |
title := c.PostForm("title") | |
file, header, err := c.Request.FormFile("file") | |
if err != nil { | |
log.Printf("Error getting file: %v", err) | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
defer file.Close() | |
// Create uploads directory if it doesn't exist | |
uploadsDir := "uploads" | |
if err := os.MkdirAll(uploadsDir, 0755); err != nil { | |
log.Printf("Error creating uploads directory: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create uploads directory"}) | |
return | |
} | |
filename := filepath.Join(uploadsDir, header.Filename) | |
out, err := os.Create(filename) | |
if err != nil { | |
log.Printf("Error creating file: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
defer out.Close() | |
_, err = io.Copy(out, file) | |
if err != nil { | |
log.Printf("Error copying file: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
// Debug: log file size and first 16 bytes | |
stat, statErr := os.Stat(filename) | |
if statErr == nil { | |
log.Printf("Uploaded file size: %d bytes", stat.Size()) | |
fcheck, ferr := os.Open(filename) | |
if ferr == nil { | |
buf := make([]byte, 16) | |
n, _ := fcheck.Read(buf) | |
log.Printf("First 16 bytes: % x", buf[:n]) | |
fcheck.Close() | |
} | |
} | |
var textContent string | |
ext := strings.ToLower(filepath.Ext(filename)) | |
if ext == ".jpg" || ext == ".jpeg" || ext == ".png" { | |
// Generate image summary using the llava model | |
summary, err := generateImageSummary(filename) | |
if err != nil { | |
log.Printf("Error generating image summary: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
textContent = summary | |
} else if ext == ".pdf" { | |
// Check PDF signature before parsing | |
f, rErr := os.Open(filename) | |
if rErr != nil { | |
log.Printf("Error opening PDF: %v", rErr) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": rErr.Error()}) | |
return | |
} | |
defer f.Close() | |
buf := make([]byte, 5) | |
_, err := f.Read(buf) | |
if err != nil || string(buf) != "%PDF-" { | |
log.Printf("Uploaded file is not a valid PDF (missing %PDF- header): %s", filename) | |
c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF (missing %PDF- header)"}) | |
return | |
} | |
stat, _ := f.Stat() | |
// Loosen EOF check: search last 1KB for %%EOF | |
eofCheckSize := int64(1024) | |
if stat.Size() < eofCheckSize { | |
eofCheckSize = stat.Size() | |
} | |
endBuf := make([]byte, eofCheckSize) | |
_, err = f.ReadAt(endBuf, stat.Size()-eofCheckSize) | |
if err != nil || !strings.Contains(string(endBuf), "%%EOF") { | |
log.Printf("Uploaded file is not a valid PDF (missing %%EOF): %s", filename) | |
c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF (missing %%EOF)"}) | |
return | |
} | |
// Reset file pointer for pdf.NewReader | |
f.Seek(0, 0) | |
reader, pdfErr := pdf.NewReader(f, stat.Size()) | |
if pdfErr != nil { | |
log.Printf("Error reading PDF: %v", pdfErr) | |
c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF or is corrupted"}) | |
return | |
} | |
var sb strings.Builder | |
for i := 1; i <= reader.NumPage(); i++ { | |
page := reader.Page(i) | |
if page.V.IsNull() { | |
continue | |
} | |
content, err := page.GetPlainText(nil) | |
if err != nil { | |
log.Printf("Error extracting text from page %d: %v", i, err) | |
continue | |
} | |
sb.WriteString(content) | |
} | |
textContent = sb.String() | |
log.Printf("Extracted text length: %d", len(textContent)) | |
} else { | |
content, err := os.ReadFile(filename) | |
if err != nil { | |
log.Printf("Error reading file: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
textContent = string(content) | |
} | |
// Remove null bytes (Postgres TEXT cannot contain 0x00) | |
textContent = strings.ReplaceAll(textContent, "\x00", "") | |
// Validate UTF-8 | |
if !utf8.ValidString(textContent) { | |
log.Printf("Invalid UTF-8 detected in document: %s", filename) | |
c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded document is not valid UTF-8"}) | |
return | |
} | |
// Generate embedding for the text content using llama3.1 | |
embedding, err := generateEmbedding(textContent) | |
if err != nil { | |
log.Printf("Error generating embedding: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
// Insert the document into the database | |
err = insertItem(conn, title, textContent, embedding) | |
if err != nil { | |
log.Printf("Error inserting item: %v", err) | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
c.JSON(http.StatusOK, gin.H{"message": "Document uploaded and processed successfully"}) | |
} | |
func chunkText(text string, chunkSize int) []string { | |
words := strings.Fields(text) | |
var chunks []string | |
for i := 0; i < len(words); i += chunkSize { | |
end := i + chunkSize | |
if end > len(words) { | |
end = len(words) | |
} | |
chunks = append(chunks, strings.Join(words[i:end], " ")) | |
} | |
return chunks | |
} | |
func generateImageSummary(imagePath string) (string, error) { | |
imageData, err := os.ReadFile(imagePath) | |
if err != nil { | |
return "", fmt.Errorf("failed to read image file: %w", err) | |
} | |
base64Image := base64.StdEncoding.EncodeToString(imageData) | |
payload := map[string]interface{}{ | |
"model": "llava", | |
"prompt": "Describe this image in detail:", | |
"images": []string{base64Image}, | |
"stream": true, | |
} | |
jsonPayload, err := json.Marshal(payload) | |
if err != nil { | |
return "", fmt.Errorf("failed to marshal JSON payload: %w", err) | |
} | |
ollamaHost := os.Getenv("OLLAMA_HOST") | |
if ollamaHost == "" { | |
ollamaHost = "localhost" | |
} | |
url := fmt.Sprintf("http://%s:11434/api/generate", ollamaHost) | |
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonPayload)) | |
if err != nil { | |
return "", fmt.Errorf("failed to send POST request: %w", err) | |
} | |
defer resp.Body.Close() | |
if resp.StatusCode != http.StatusOK { | |
body, _ := io.ReadAll(resp.Body) | |
return "", fmt.Errorf("unexpected response status: %d, body: %s", resp.StatusCode, string(body)) | |
} | |
var summary strings.Builder | |
decoder := json.NewDecoder(resp.Body) | |
for { | |
var result struct { | |
Response string `json:"response"` | |
Done bool `json:"done"` | |
} | |
if err := decoder.Decode(&result); err != nil { | |
if err == io.EOF { | |
break | |
} | |
return "", fmt.Errorf("failed to decode JSON response: %w", err) | |
} | |
summary.WriteString(result.Response) | |
if result.Done { | |
break | |
} | |
} | |
if summary.Len() == 0 { | |
return "", fmt.Errorf("empty response from llava model") | |
} | |
fmt.Println("The summary of the image is: ", summary.String()) | |
return summary.String(), nil | |
} | |
func main() { | |
// Set up the database connection | |
// load env variables | |
got.Load() | |
conn, err := pgx.Connect(context.Background(), os.Getenv("DB_URL")) | |
if err != nil { | |
log.Fatal("Unable to connect to database:", err) | |
} | |
defer conn.Close(context.Background()) | |
// Set up the Gin router | |
r := gin.Default() | |
// Define the /add_document endpoint | |
r.POST("/add_document", func(c *gin.Context) { | |
var request struct { | |
Title string `json:"title"` | |
DocText string `json:"doc_text"` | |
} | |
if err := c.BindJSON(&request); err != nil { | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
// Generate the embedding | |
embedding, err := generateEmbedding(request.DocText) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
// Insert the document and its embedding into the items table | |
err = insertItem(conn, request.Title, request.DocText, embedding) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
c.JSON(http.StatusOK, gin.H{"message": "Document chunk embedded and stored successfully!"}) | |
}) | |
// Add the new /query endpoint | |
r.POST("/query", func(c *gin.Context) { | |
var request struct { | |
Query string `json:"query"` | |
SessionID string `json:"sessionId"` | |
} | |
if err := c.BindJSON(&request); err != nil { | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
session, ok := sessions[request.SessionID] | |
if !ok { | |
session = &Session{ | |
Messages: []api.Message{ | |
{Role: "system", Content: "You are an assistant that answers questions based on the given context."}, | |
}, | |
TitleFilter: "", | |
} | |
sessions[request.SessionID] = session | |
} | |
// Check for @title in the query | |
if strings.Contains(request.Query, "@") { | |
parts := strings.Split(request.Query, "@") | |
if len(parts) > 1 { | |
session.TitleFilter = strings.Split(parts[1], " ")[0] | |
request.Query = strings.Replace(request.Query, "@"+session.TitleFilter, "", 1) | |
} | |
} | |
c.Header("Content-Type", "text/event-stream") | |
c.Header("Cache-Control", "no-cache") | |
c.Header("Connection", "keep-alive") | |
c.Header("Access-Control-Allow-Origin", "*") | |
c.Header("Access-Control-Allow-Credentials", "true") | |
c.Header("Access-Control-Allow-Headers", "Content-Type") | |
c.Header("Access-Control-Allow-Methods", "POST") | |
c.Header("encoding", "chunked") | |
err := queryEmbeddings(conn, request.Query, session, c) | |
if err != nil { | |
c.SSEvent("error", err.Error()) | |
} | |
c.SSEvent("done", "") | |
}) | |
// Serve the index.html file | |
r.GET("/", func(c *gin.Context) { | |
c.File("index.html") | |
}) | |
// Serve the docmanager.html file | |
r.GET("/docmanager", func(c *gin.Context) { | |
c.File("docmanager.html") | |
}) | |
// Add a new endpoint to fetch documents | |
r.GET("/documents", func(c *gin.Context) { | |
documents, err := getDocuments(conn) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
c.JSON(http.StatusOK, documents) | |
}) | |
// Add a new endpoint to delete documents | |
r.POST("/delete_document", func(c *gin.Context) { | |
var request struct { | |
Title string `json:"title"` | |
} | |
if err := c.BindJSON(&request); err != nil { | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
err := deleteDocument(conn, request.Title) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) | |
}) | |
// Add a new endpoint to upload documents | |
r.POST("/upload_document", func(c *gin.Context) { | |
uploadDocument(c, conn) | |
}) | |
// Add a new endpoint to clear the chat session | |
r.POST("/clear_session", func(c *gin.Context) { | |
var request struct { | |
SessionID string `json:"sessionId"` | |
} | |
if err := c.BindJSON(&request); err != nil { | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
delete(sessions, request.SessionID) | |
c.JSON(http.StatusOK, gin.H{"message": "Chat session cleared successfully"}) | |
}) | |
// Add a new endpoint to check if Twitter data exists | |
r.GET("/check_data", func(c *gin.Context) { | |
rows, err := conn.Query(context.Background(), "SELECT DISTINCT title FROM items WHERE title LIKE '%twitter%'") | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
defer rows.Close() | |
var titles []string | |
for rows.Next() { | |
var title string | |
if err := rows.Scan(&title); err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
titles = append(titles, title) | |
} | |
c.JSON(http.StatusOK, gin.H{"twitter_titles": titles}) | |
}) | |
// Serve the describer.html file | |
r.GET("/describer", func(c *gin.Context) { | |
c.File("describer.html") | |
}) | |
// Handle image description | |
r.POST("/describe_image", func(c *gin.Context) { | |
file, _, err := c.Request.FormFile("file") | |
if err != nil { | |
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
return | |
} | |
defer file.Close() | |
// Create a temporary file to store the uploaded image | |
tempFile, err := os.CreateTemp("", "uploaded-*.jpg") | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
defer os.Remove(tempFile.Name()) | |
defer tempFile.Close() | |
// Copy the uploaded file to the temporary file | |
_, err = io.Copy(tempFile, file) | |
if err != nil { | |
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
return | |
} | |
c.Header("Content-Type", "text/event-stream") | |
c.Header("Cache-Control", "no-cache") | |
c.Header("Connection", "keep-alive") | |
c.Header("Access-Control-Allow-Origin", "*") | |
c.Header("Access-Control-Allow-Credentials", "true") | |
c.Header("Access-Control-Allow-Headers", "Content-Type") | |
c.Header("Access-Control-Allow-Methods", "POST") | |
c.Header("encoding", "chunked") | |
imageData, err := os.ReadFile(tempFile.Name()) | |
if err != nil { | |
c.SSEvent("error", err.Error()) | |
return | |
} | |
base64Image := base64.StdEncoding.EncodeToString(imageData) | |
payload := map[string]interface{}{ | |
"model": "llava", | |
"prompt": "Describe this image in detail:", | |
"images": []string{base64Image}, | |
"stream": true, | |
} | |
jsonPayload, err := json.Marshal(payload) | |
if err != nil { | |
c.SSEvent("error", err.Error()) | |
return | |
} | |
ollamaHost := os.Getenv("OLLAMA_HOST") | |
if ollamaHost == "" { | |
ollamaHost = "localhost" | |
} | |
url := fmt.Sprintf("http://%s:11434/api/generate", ollamaHost) | |
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonPayload)) | |
if err != nil { | |
c.SSEvent("error", err.Error()) | |
return | |
} | |
defer resp.Body.Close() | |
if resp.StatusCode != http.StatusOK { | |
body, _ := io.ReadAll(resp.Body) | |
c.SSEvent("error", fmt.Sprintf("Unexpected response status: %d, body: %s", resp.StatusCode, string(body))) | |
return | |
} | |
decoder := json.NewDecoder(resp.Body) | |
for { | |
var result struct { | |
Response string `json:"response"` | |
Done bool `json:"done"` | |
} | |
if err := decoder.Decode(&result); err != nil { | |
if err == io.EOF { | |
break | |
} | |
c.SSEvent("error", err.Error()) | |
return | |
} | |
if result.Response != "" { | |
c.SSEvent("message", result.Response) | |
c.Writer.Flush() // Ensure the content is sent immediately | |
} | |
if result.Done { | |
break | |
} | |
} | |
c.SSEvent("done", "") | |
}) | |
// Run the Gin server | |
r.Run(":8080") | |
} | |