ccm committed
Commit 7e37bf7 (verified) · Parent(s): 2f0848d

Refactoring

Files changed (1): main.py (+5 -4)
main.py CHANGED
@@ -12,6 +12,8 @@ import time
 
 # Constants
 GREETING = "Hi there! I'm an AI agent that uses a [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the Design Research Collective. And the best part is that I always cite my sources! What can I tell you about today?"
+EMBEDDING_MODEL_NAME = "allenai-specter"
+LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
@@ -34,7 +36,7 @@ index.train(vectors)
 index.add(vectors)
 
 # Load the model for later use in embeddings
-model = sentence_transformers.SentenceTransformer("allenai-specter")
+model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
 # Define the search function
 def search(query: str, k: int) -> tuple[str]:
@@ -60,11 +62,10 @@ def search(query: str, k: int) -> tuple[str]:
 
 
 # Create an LLM pipeline that we can send queries to
-model_name = "Qwen/Qwen2-7B-Instruct"
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
 streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
-    model_name,
+    LLM_MODEL_NAME,
     torch_dtype="auto",
     device_map="auto"
 )
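The body of search is not shown in this diff. As a rough sketch (not part of the commit), a FAISS-backed retrieval step in a script like this usually embeds the query with the same model used to build the index, asks the index for the k nearest rows, and formats citations from the matching records; the "title" and "url" column names below are assumptions, not taken from main.py:

import numpy

def search(query: str, k: int) -> tuple[str]:
    # Embed the query with the same SPECTER model used for the index
    query_vector = model.encode([query])
    # Retrieve the k nearest publications from the FAISS index
    _, matches = index.search(numpy.asarray(query_vector, dtype="float32"), k)
    hits = full_data.iloc[matches[0]]
    # Format each hit as a Markdown citation (column names are hypothetical)
    return tuple(f"[{row['title']}]({row['url']})" for row in hits.to_dict("records"))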
 
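The diff also doesn't show how streamer is consumed. With transformers.TextIteratorStreamer, the usual pattern is to run generate() on a background thread and iterate the streamer for decoded text as it arrives; the prompt handling and max_new_tokens value here are illustrative assumptions, not code from this repository:

import threading

def reply(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(chatmodel.device)
    # generate() blocks until finished, so it runs on its own thread
    # while the main thread drains the streamer
    thread = threading.Thread(
        target=chatmodel.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512),
    )
    thread.start()
    chunks = []
    for text in streamer:  # yields decoded text pieces as tokens are generated
        chunks.append(text)
    thread.join()
    return "".join(chunks)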