acecalisto3 commited on
Commit
43154fe
·
verified ·
1 Parent(s): 1eda650

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -19,6 +19,16 @@ from langchain.memory import ConversationBufferMemory
19
  from langchain.chains.question_answering import load_qa_chain
20
  from langchain.document_loaders import TextLoader
21
  from langchain.text_splitter import CharacterTextSplitter
 
 
 
 
 
 
 
 
 
 
22
 
23
  # --- Constants ---
24
  MODEL_NAME = "bigscience/bloom-1b7"
@@ -27,10 +37,6 @@ TEMPERATURE = 0.7
27
  TOP_P = 0.95
28
  REPETITION_PENALTY = 1.2
29
 
30
- # --- Model & Tokenizer ---
31
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
32
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
33
-
34
  # --- Agents ---
35
  agents = {
36
  "WEB_DEV": {
 
19
  from langchain.chains.question_answering import load_qa_chain
20
  from langchain.document_loaders import TextLoader
21
  from langchain.text_splitter import CharacterTextSplitter
22
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqLMForCausalGeneration
23
+
24
+ def create_causal_lm(model_name: str):
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name).causal_decoder
27
+ return model, tokenizer
28
+
29
+ AutoModelForCausalLM = lambda model_name: create_causal_lm(model_name)[0]
30
+ AutoTokenizerForCausalLM = lambda model_name: create_causal_lm(model_name)[1]
31
+
32
 
33
  # --- Constants ---
34
  MODEL_NAME = "bigscience/bloom-1b7"
 
37
  TOP_P = 0.95
38
  REPETITION_PENALTY = 1.2
39
 
 
 
 
 
40
  # --- Agents ---
41
  agents = {
42
  "WEB_DEV": {