Spaces: Running on Zero
Refactoring
main.py CHANGED
@@ -12,6 +12,8 @@ import time
 
 # Constants
 GREETING = "Hi there! I'm an AI agent that uses a [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the Design Research Collective. And the best part is that I always cite my sources! What can I tell you about today?"
+EMBEDDING_MODEL_NAME = "allenai-specter"
+LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
@@ -34,7 +36,7 @@ index.train(vectors)
 index.add(vectors)
 
 # Load the model for later use in embeddings
-model = sentence_transformers.SentenceTransformer("allenai-specter")
+model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
 # Define the search function
 def search(query: str, k: int) -> tuple[str]:
@@ -60,11 +62,10 @@ def search(query: str, k: int) -> tuple[str]:
 
 
 # Create an LLM pipeline that we can send queries to
-model_name = "Qwen/Qwen2-7B-Instruct"
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
 streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
-    model_name,
+    LLM_MODEL_NAME,
     torch_dtype="auto",
     device_map="auto"
 )
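Net effect of the commit: both model names now live in named constants at the top of main.py instead of being hard-coded at their call sites. For context, these constants configure a standard retrieval pipeline, where the sentence-transformers model embeds text and the index built in the surrounding context lines (index.train / index.add) serves nearest-neighbor lookups. Below is a minimal, self-contained sketch of that retrieval path, not the actual main.py: the document list is illustrative, and a flat FAISS index stands in for whatever trained index the app really builds.

import faiss
import numpy
import sentence_transformers

EMBEDDING_MODEL_NAME = "allenai-specter"  # constant introduced by this commit

# Embedding model, loaded the same way as in main.py
model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Illustrative stand-ins for the ccm/publications abstracts
documents = [
    "A study of generative tools for engineering design.",
    "Human factors in collaborative design reviews.",
]

# Embed and index the documents; main.py trains its index first,
# but a flat L2 index needs no training and keeps the sketch short
vectors = model.encode(documents)
index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)

def search(query: str, k: int) -> tuple[str, ...]:
    """Embed the query and return the k nearest documents."""
    query_vector = model.encode([query])
    _, neighbors = index.search(numpy.asarray(query_vector), k)
    return tuple(documents[i] for i in neighbors[0])

print(search("generative design", k=1))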
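One note on the unchanged context lines: a TextIteratorStreamer only streams if generate() runs concurrently, because iterating the streamer blocks the calling thread. The usual pattern is sketched below under the assumption of a chat-style prompt; the app's actual query handling and generation settings are outside this diff.

import threading

import transformers

LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"  # constant introduced by this commit

tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)

# Illustrative prompt; the real app builds its prompt from retrieved passages
messages = [{"role": "user", "content": "Summarize retrieval-augmented generation."}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(chatmodel.device)

# generate() blocks, so it runs in a worker thread while we consume the streamer
thread = threading.Thread(
    target=chatmodel.generate,
    kwargs={"input_ids": inputs, "streamer": streamer, "max_new_tokens": 256},
)
thread.start()
for text_chunk in streamer:  # yields decoded text chunks as they are produced
    print(text_chunk, end="", flush=True)
thread.join()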