|
package monica |
|
|
|
import ( |
|
"bufio" |
|
"fmt" |
|
"io" |
|
"log" |
|
"sync" |
|
"time" |
|
|
|
"monica-proxy/internal/types" |
|
"monica-proxy/internal/utils" |
|
"net/http" |
|
"strings" |
|
|
|
"github.com/bytedance/sonic" |
|
"github.com/sashabaranov/go-openai" |
|
) |
|
|
|
// Protocol values and buffering knobs for the OpenAI-compatible SSE stream.
const (
	// bufferSize is the size in bytes of both the bufio reader and writer.
	bufferSize = 4 << 10

	// flushInterval is how often buffered output is pushed to a client whose
	// writer implements http.Flusher.
	flushInterval = time.Second / 10

	// sseObject is the "object" value stamped on every emitted chunk.
	sseObject = "chat.completion.chunk"

	// sseFinish is the terminal payload, emitted as "data: [DONE]".
	sseFinish = "[DONE]"
)
|
|
|
|
|
// SSEData is the JSON payload found after "data: " on each line of Monica's
// upstream SSE stream.
type SSEData struct {
	// Text is the content fragment carried by this event; it becomes the
	// delta content of the corresponding outgoing chunk.
	Text string `json:"text"`

	// Finished is true on the event that ends the stream.
	Finished bool `json:"finished"`
}
|
|
|
var sseDataPool = sync.Pool{ |
|
New: func() interface{} { |
|
return &SSEData{} |
|
}, |
|
} |
|
|
|
|
|
func StreamMonicaSSEToClient(model string, w io.Writer, r io.Reader) error { |
|
reader := bufio.NewReaderSize(r, bufferSize) |
|
writer := bufio.NewWriterSize(w, bufferSize) |
|
defer writer.Flush() |
|
|
|
chatId := utils.RandStringUsingMathRand(29) |
|
now := time.Now().Unix() |
|
fingerprint := utils.RandStringUsingMathRand(10) |
|
|
|
|
|
ticker := time.NewTicker(flushInterval) |
|
defer ticker.Stop() |
|
|
|
|
|
done := make(chan struct{}) |
|
defer close(done) |
|
|
|
|
|
go func() { |
|
for { |
|
select { |
|
case <-ticker.C: |
|
if f, ok := w.(http.Flusher); ok { |
|
writer.Flush() |
|
f.Flush() |
|
} |
|
case <-done: |
|
return |
|
} |
|
} |
|
}() |
|
|
|
for { |
|
line, err := reader.ReadString('\n') |
|
if err != nil { |
|
if err == io.EOF { |
|
return nil |
|
} |
|
return fmt.Errorf("read error: %w", err) |
|
} |
|
|
|
|
|
if !strings.HasPrefix(line, "data: ") { |
|
continue |
|
} |
|
|
|
jsonStr := strings.TrimPrefix(line, "data: ") |
|
if jsonStr == "" { |
|
continue |
|
} |
|
|
|
|
|
sseObj := sseDataPool.Get().(*SSEData) |
|
if err := sonic.UnmarshalString(jsonStr, sseObj); err != nil { |
|
sseDataPool.Put(sseObj) |
|
|
|
log.Printf("Error unmarshaling SSE data: %v", err) |
|
continue |
|
} |
|
|
|
|
|
var sseMsg types.ChatCompletionStreamResponse |
|
if sseObj.Finished { |
|
sseMsg = types.ChatCompletionStreamResponse{ |
|
ID: "chatcmpl-" + chatId, |
|
Object: sseObject, |
|
Created: now, |
|
Model: model, |
|
Choices: []types.ChatCompletionStreamChoice{ |
|
{ |
|
Index: 0, |
|
Delta: openai.ChatCompletionStreamChoiceDelta{ |
|
Role: openai.ChatMessageRoleAssistant, |
|
}, |
|
FinishReason: openai.FinishReasonStop, |
|
}, |
|
}, |
|
} |
|
} else { |
|
sseMsg = types.ChatCompletionStreamResponse{ |
|
ID: "chatcmpl-" + chatId, |
|
Object: sseObject, |
|
SystemFingerprint: fingerprint, |
|
Created: now, |
|
Model: model, |
|
Choices: []types.ChatCompletionStreamChoice{ |
|
{ |
|
Index: 0, |
|
Delta: openai.ChatCompletionStreamChoiceDelta{ |
|
Role: openai.ChatMessageRoleAssistant, |
|
Content: sseObj.Text, |
|
}, |
|
FinishReason: openai.FinishReasonNull, |
|
}, |
|
}, |
|
} |
|
} |
|
|
|
var sb strings.Builder |
|
sb.WriteString("data: ") |
|
sendLine, _ := sonic.MarshalString(sseMsg) |
|
sb.WriteString(sendLine) |
|
sb.WriteString("\n\n") |
|
|
|
|
|
if _, err := writer.WriteString(sb.String()); err != nil { |
|
sseDataPool.Put(sseObj) |
|
return fmt.Errorf("write error: %w", err) |
|
} |
|
|
|
|
|
if sseObj.Finished { |
|
writer.WriteString(fmt.Sprintf("data: %s\n\n", sseFinish)) |
|
writer.Flush() |
|
if f, ok := w.(http.Flusher); ok { |
|
f.Flush() |
|
} |
|
sseDataPool.Put(sseObj) |
|
return nil |
|
} |
|
|
|
|
|
sseDataPool.Put(sseObj) |
|
} |
|
} |
|
|