import streamlit as st
import time
# Your existing demos
from assist.chat import chat as embed_chat
from assist.bayes_chat import bayes_chat
from assist.transformer_demo import transformer_next

# DeepSeek imports
from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
# Retrieval imports
from sentence_transformers import SentenceTransformer
import torch
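
# Assumed dependencies (pip package names, versions not pinned):
#     pip install streamlit transformers sentence-transformers torch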

st.set_page_config(page_title="RepoSage All-in-One Demo", layout="wide")
st.title("🤖 RepoSage Unified Demo")

# Cache and load the DeepSeek model (DeepSeek-Coder 1.3B base)
@st.cache_resource
def load_deepseek():
    model_name = "deepseek-ai/DeepSeek-Coder-1.3B-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model     = AutoModelForCausalLM.from_pretrained(model_name)
    return TextGenerationPipeline(model=model, tokenizer=tokenizer)

deepseek_gen = load_deepseek()
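# The pipeline above runs on CPU by default. A minimal sketch of GPU
# placement, assuming a CUDA device is available, is to pass a device
# when constructing the pipeline inside load_deepseek():
#     TextGenerationPipeline(model=model, tokenizer=tokenizer, device=0)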

# Cache and load training corpus passages
@st.cache_data
def load_passages(path="RepoSage Training.txt"):
    # Read the corpus and split it into blank-line-separated paragraphs.
    with open(path, encoding="utf-8") as f:
        text = f.read()
    return [p.strip() for p in text.split("\n\n") if p.strip()]
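# load_passages() treats each blank-line-separated paragraph as one
# retrieval unit. For longer corpora, fixed-size chunks are a common
# alternative; a sketch only, not used by this demo:
#     def chunk(text, size=500):
#         return [text[i:i + size] for i in range(0, len(text), size)]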

# Cache and embed passages
@st.cache_resource
def embed_passages(passages):
    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = encoder.encode(passages, convert_to_tensor=True)
    return encoder, passages, embeddings
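# The RAG demo below scores passages with torch's cosine_similarity.
# The helper bundled with sentence-transformers computes the same scores,
# shown here as an equivalent sketch:
#     from sentence_transformers import util
#     sims = util.cos_sim(q_emb, passage_embs)[0]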

# Prepare RAG resources
_passages = load_passages()
_encoder, passages, passage_embs = embed_passages(_passages)

# User input
title = st.text_input("Enter your question or prompt below:")

# Define columns for five demos
col1, col2, col3, col4, col5 = st.columns(5)

# Math demo in col1
with col1:
    if st.button("DeepSeek-R1 Math Demo"):
        if not title.strip():
            st.warning("Please enter a prompt first.")
        else:
            prompt = f"You are an expert math tutor. Compute the derivative of f(x) = {title} step by step using the product rule. Solution:\n"
            with st.spinner("Working it out…"):
                out = deepseek_gen(prompt, max_new_tokens=80, do_sample=False)
            st.code(out[0]["generated_text"], language="text")

# RAG-augmented demo in col2
with col2:
    if st.button("DeepSeek-R1 RAG Demo"):
        if not title.strip():
            st.warning("Please enter a question first.")
        else:
            # 1) mark the start
            t0 = time.time()

            # retrieval
            q_emb = _encoder.encode(title, convert_to_tensor=True)
            sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), passage_embs)
            topk = torch.topk(sims, k=min(3, len(passages))).indices.tolist()
            context = "\n\n".join(passages[i] for i in topk)
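            # With k <= 3 short paragraphs the prompt stays well inside the
            # model's context window; a larger corpus would need a length
            # cap here (hypothetical guard, not wired in):
            #     context = context[:4000]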

            t1 = time.time()
            st.write(f"⏱ Retrieval done in {t1-t0:.1f}s; generation starting…")

            # 2) generation (reduce tokens for now)
            out = deepseek_gen(
                f"Use these notes to answer:\n\n{context}\n\nQ: {title}\nA:",
                max_new_tokens=10,
                do_sample=False
            )

            t2 = time.time()
            st.write(f"⏱ Generation took {t2-t1:.1f}s (total {t2-t0:.1f}s)")
            st.write(out[0]["generated_text"])

# Embedding Q&A in col3
with col3:
    if st.button("Embedding Q&A"):
        st.write(embed_chat(title))

# Bayesian Q&A in col4
with col4:
    if st.button("Bayesian Q&A"):
        st.write(bayes_chat(title))

# Transformer Demo in col5
with col5:
    if st.button("Transformer Demo"):
        st.write(transformer_next(title))

st.markdown("---")
st.caption("DeepSeek-R1 Math, RAG, Embedding, Bayesian & Transformer demos all in one place ✅")
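
# To launch the demo (assuming this file is saved as app.py):
#     streamlit run app.py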