"""Retrieval-augmented chat over Design Research Collective publications.

Publications from the ``ccm/publications`` dataset are indexed with FAISS.
For each user query the most similar abstracts are retrieved and prepended
to the prompt of a small Qwen2 chat model. The model's answer is streamed to
a Gradio chat UI and followed by a Markdown list of Google Scholar links to
the retrieved publications.
"""

import json  # to work with JSON
import threading  # to run generation in the background while streaming
import time  # to pace the streamed output
from collections.abc import Iterator

import datasets  # to load the dataset
import faiss  # to create an index
import gradio  # for the interface
import numpy  # to work with vectors
import pandas  # to work with pandas
import sentence_transformers  # to load an embedding model
import spaces  # for GPU
import transformers  # to load an LLM

CHAT_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
_SCHOLAR_CITATION_URL = (
    "https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
)

# Load the dataset and convert to pandas
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()


def _has_abstract(bib) -> bool:
    """Return True when a publication's ``bib_dict`` carries a non-empty abstract."""
    return bool(bib) and bool(bib.get("abstract"))


# Keep only publications with a usable abstract. Checking the field directly
# (instead of searching a JSON dump for '"abstract": null') also drops entries
# whose abstract is missing or empty, which would otherwise raise in `search`.
data = full_data[full_data["bib_dict"].map(_has_abstract).astype(bool)].reset_index(
    drop=True
)

# Build an exact inner-product FAISS index. Because the vectors are
# L2-normalized first, inner product equals cosine similarity.
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype(numpy.float32)
faiss.normalize_L2(vectors)  # in place; FAISS requires contiguous float32
index = faiss.IndexFlatIP(vectors.shape[1])
index.add(vectors)

# Load the model for later use in embeddings
model = sentence_transformers.SentenceTransformer("allenai-specter")


def search(query: str, k: int) -> tuple[str, str]:
    """Retrieve the ``k`` publications whose embeddings are closest to ``query``.

    Args:
        query: Free-text user query, embedded with the SPECTER model.
        k: Number of publications to retrieve.

    Returns:
        ``(prompt_block, references)``. ``prompt_block`` is an instruction
        preamble containing the retrieved abstracts (one per line) and ends
        with an instruction to answer the following query. ``references`` is
        a Markdown numbered list linking each retrieved publication on
        Google Scholar.
    """
    query_vector = numpy.expand_dims(model.encode(query), axis=0).astype(numpy.float32)
    faiss.normalize_L2(query_vector)
    _, neighbor_ids = index.search(query_vector, k)
    # FAISS pads the result with -1 when the index holds fewer than k vectors.
    hits = data.iloc[[i for i in neighbor_ids[0] if i >= 0]]

    abstracts = "".join(bib["abstract"] + "\n" for bib in hits["bib_dict"])
    search_results = (
        "You are an AI assistant who delights in helping people "
        "learn about research from the Design Research Collective. Here are "
        "several really cool abstracts:\n\n"
        + abstracts
        + "\nSummarize the above abstracts as you respond to the following query:"
    )

    references = "\n\n## References\n\n" + "".join(
        f"{rank}. [{bib['title']}]({_SCHOLAR_CITATION_URL}{pub_id})\n"
        for rank, (bib, pub_id) in enumerate(
            zip(hits["bib_dict"], hits["author_pub_id"]), start=1
        )
    )
    return search_results, references


# Load the chat model and its tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    CHAT_MODEL_NAME, torch_dtype="auto", device_map="auto"
)


def preprocess(message: str) -> tuple[str, str]:
    """Apply a preprocessing step to the user's message before the LLM receives it.

    Returns the retrieval-augmented prompt for the LLM, plus the Markdown
    references that bypass the LLM and are appended to its answer by
    `postprocess`.
    """
    block_search_results, formatted_search_results = search(message, 5)
    # Separate the user's query from the trailing "...following query:" line.
    return block_search_results + "\n" + message, formatted_search_results


def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """Apply a postprocessing step to the LLM's response before the user receives it."""
    return response + bypass_from_preprocessing


@spaces.GPU
def predict(message: str, history: list) -> Iterator[str]:
    """Stream a retrieval-augmented reply to ``message``.

    Args:
        message: The user's latest chat message.
        history: Gradio chat history. Only its last entry is used, and it is
            treated as a ``[user, assistant]`` pair (assumes Gradio's
            "tuples" history format — NOTE(review): confirm if the Chatbot
            type changes).

    Yields:
        The accumulated response text as tokens arrive, and finally the full
        response with the reference list appended.
    """
    prompt, references = preprocess(message)

    # Keep only the most recent exchange as context (user, then assistant).
    previous_turn = history[-1] if history else []
    conversation = [
        {"role": role, "content": content}
        for role, content in zip(("user", "assistant"), previous_turn)
        if content is not None
    ]
    conversation.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(chatmodel.device)

    # A fresh streamer per request: a shared module-level streamer would mix
    # the tokens of concurrent requests into one queue.
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation = threading.Thread(
        target=chatmodel.generate,
        kwargs=dict(model_inputs, streamer=streamer, max_new_tokens=512),
    )
    generation.start()

    partial_message = ""
    for new_text in streamer:
        # Skip chunks that are exactly "<" (NOTE(review): presumably stray
        # tag fragments — confirm this filter is still needed).
        if new_text != "<":
            partial_message += new_text
            time.sleep(0.05)  # pace UI updates
            yield partial_message
    generation.join()

    yield postprocess(partial_message, references)


# Create and run the gradio interface
demo = gradio.ChatInterface(
    predict,
    examples=[
        "Tell me about new research at the intersection of additive manufacturing and machine learning",
        "What is a physics-informed neural network and what can it be used for?",
        "What can agent-based models do about climate change?",
    ],
    chatbot=gradio.Chatbot(show_label=False),
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
    theme="monochrome",
)
demo.launch(debug=True)