Daemontatox committed on
Commit 5a6715f · verified · 1 Parent(s): e90736c

Update app.py

Files changed (1):
  1. app.py +33 -3
app.py CHANGED
@@ -1,3 +1,12 @@
+import subprocess
+
+subprocess.run(
+    'pip install flash-attn --no-build-isolation',
+    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+    shell=True
+)
+
+
 import os
 from dotenv import load_dotenv
 from langchain_community.vectorstores import Qdrant
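The new preamble installs flash-attn before the model imports run. One caveat: passing env= to subprocess.run replaces the child process's environment entirely rather than extending it, so PATH, HOME, and any virtualenv variables are dropped, and whether pip still resolves then depends on the shell's fallback PATH. A minimal sketch of the same step, with the flag merged into the inherited environment and a check=True added so a failed install is not silently ignored (both are departures from the committed hunk):

import os
import subprocess

# Same install step as the commit, sketched with the skip-build flag merged
# into the inherited environment instead of replacing it. check=True raises
# CalledProcessError if pip exits non-zero.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
    check=True,
)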
@@ -65,7 +74,7 @@ try:
     client = QdrantClient(
         url=os.getenv("QDRANT_URL"),
         api_key=os.getenv("QDRANT_API_KEY"),
-        prefer_grpc=False
+        prefer_grpc=True
     )
 except Exception as e:
     logger.error("Failed to connect to Qdrant. Ensure QDRANT_URL and QDRANT_API_KEY are correctly set.")
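The only change in this hunk flips prefer_grpc from False to True, so qdrant-client routes most operations over gRPC (port 6334 by default) instead of REST, which generally reduces latency for searches and bulk upserts; the gRPC port has to be reachable on the Qdrant endpoint for this to help. A short sketch of the changed connection, assuming the same QDRANT_URL and QDRANT_API_KEY environment variables as app.py; the get_collections() call is an added smoke test, not part of the commit:

import os
from qdrant_client import QdrantClient

# Connect preferring gRPC; url/api_key come from the environment as in app.py.
client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
    prefer_grpc=True,
)

# Cheap round-trip to confirm the gRPC transport actually works end to end.
print(client.get_collections())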
@@ -119,10 +128,31 @@ retriever = db.as_retriever(
 #     timeout=None
 
 # )
-model_id = "CohereForAI/c4ai-command-r7b-12-2024"
+
+quantization_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True
+)
+
+
+
+
+model_id = "unsloth/phi-4"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    device_map="cuda",
+    attn_implementation="flash_attention_2",
+    quantization_config=quantization_config
+
+)
+
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=8192)
+
 llm = HuggingFacePipeline(pipeline=pipe)
 
 
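The second half of the hunk swaps CohereForAI/c4ai-command-r7b-12-2024 for unsloth/phi-4 and loads it 4-bit quantized (NF4 with double quantization) under FlashAttention-2. Two details the diff leaves implicit: torch and BitsAndBytesConfig must already be imported elsewhere in app.py, since no import hunk is shown, and the committed code mixes torch_dtype=torch.float16 with bnb_4bit_compute_dtype=torch.bfloat16. A self-contained sketch of the same loading path, with the dtypes aligned on bfloat16 and assuming bitsandbytes, accelerate, flash-attn, and a CUDA GPU are available:

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

model_id = "unsloth/phi-4"

# 4-bit NF4 weight storage; matmuls are upcast to bfloat16, and double
# quantization also compresses the quantization constants themselves.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,              # aligned with the compute dtype
    device_map="cuda",
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=8192)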
 
 