Daemontatox committed on
Commit 8359d12 · verified · 1 Parent(s): 3365a48

Update app.py

Files changed (1)
  1. app.py +28 -58
app.py CHANGED
@@ -1,48 +1,34 @@
-import os
-from dotenv import load_dotenv
 from langchain_community.vectorstores import Qdrant
+from langchain_groq import ChatGroq
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain.llms import HuggingFacePipeline
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import os
+from dotenv import load_dotenv
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema.runnable import RunnablePassthrough
 from langchain.schema.output_parser import StrOutputParser
 from qdrant_client import QdrantClient, models
 from langchain_qdrant import Qdrant
-import gradio as gr
-import torch
-import spaces
+from langchain_qdrant import QdrantVectorStore
+from langchain_huggingface import ChatHuggingFace
+
 
 # Load environment variables
 load_dotenv()
 
-# Verify environment variables
-qdrant_url = os.getenv("QDRANT_URL")
-qdrant_api_key = os.getenv("QDRANT_API_KEY")
-
-print(f"QDRANT_URL: {qdrant_url}")
-print(f"QDRANT_API_KEY: {qdrant_api_key}")
+os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API")
+HF_TOKEN = os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
 
 # HuggingFace Embeddings
 embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
 
 # Qdrant Client Setup
 client = QdrantClient(
-    url=qdrant_url,
-    api_key=qdrant_api_key,
-    #prefer_grpc=True
+    url=os.getenv("QDRANT_URL"),
+    api_key=os.getenv("QDRANT_API_KEY"),
+    prefer_grpc=True
 )
 
-collection_name="mawared"
-
-# Check if the connection is successful
-try:
-    client.get_collection(collection_name)
-    print(f"Successfully connected to Qdrant collection: {collection_name}")
-except Exception as e:
-    print(f"Failed to connect to Qdrant: {e}")
-    raise e
-
+collection_name = "mawared"
 
 # Try to create collection, handle if it already exists
 try:
@@ -52,6 +38,7 @@ try:
             size=768, # GTE-large embedding size
             distance=models.Distance.COSINE
         ),
+
     )
     print(f"Created new collection: {collection_name}")
 except Exception as e:
@@ -73,20 +60,8 @@ retriever = db.as_retriever(
     search_kwargs={"k": 5}
 )
 
-# Load Hugging Face Model
-model_name = "NousResearch/Hermes-3-Llama-3.2-3B" # Replace with your desired model
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
-
-# Ensure the model is on the GPU
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
 
-# Create Hugging Face Pipeline with the specified model and tokenizer
-hf_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-# LangChain LLM using Hugging Face Pipeline
-llm = HuggingFacePipeline(pipeline=hf_pipeline)
+llm = ChatOpenAI(base_url="https://api-inference.huggingface.co/v1/", temperature=0 , api_key=HF_TOKEN , model="meta-llama/Llama-3.3-70B-Instruct")
 
 # Create prompt template
 template = """
@@ -117,7 +92,7 @@ Answer
 
 prompt = ChatPromptTemplate.from_template(template)
 
-# Create the RAG chain
+# Create the RAG chain using LCEL with prompt printing and streaming output
 rag_chain = (
     {"context": retriever, "question": RunnablePassthrough()}
     | prompt
@@ -125,24 +100,19 @@ rag_chain = (
     | StrOutputParser()
 )
 
-# Define the Gradio function
-@spaces.GPU()
-def ask_question_gradio(question):
-    result = ""
+
+# Function to ask questions
+def ask_question(question):
+    print("Answer:\t", end=" ", flush=True)
     for chunk in rag_chain.stream(question):
-        result += chunk
-    return result
-
-# Create the Gradio interface
-interface = gr.Interface(
-    fn=ask_question_gradio,
-    inputs="text",
-    outputs="text",
-    title="Mawared Expert Assistant",
-    description="Ask questions about the Mawared HR System or any related topic using Chain-of-Thought (CoT) and RAG principles.",
-    theme="compact",
-)
+        print(chunk, end="", flush=True)
+    print("\n")
 
-# Launch Gradio app
+# Example usage
 if __name__ == "__main__":
-    interface.launch()
+    while True:
+        user_question = input("\n \n \n Ask a question (or type 'quit' to exit): ")
+        if user_question.lower() == 'quit':
+            break
+        answer = ask_question(user_question)
+        # print("\nFull answer received.\n")
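A few review notes on the new version.

The env handling can fail in an opaque way: `os.environ[...] = os.getenv(...)` raises a `TypeError` when the source variable is unset, since environ values must be strings (and the Groq key is read from `GROQ_API`, even though the imported `ChatGroq` is never used in the visible hunks). A minimal fail-fast sketch; `require_env` is a hypothetical helper, the variable names come from the diff:

import os
from dotenv import load_dotenv

load_dotenv()

def require_env(name: str) -> str:
    # Hypothetical helper: stop with a clear message instead of letting
    # os.environ[name] = None raise a bare TypeError at import time.
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value

HF_TOKEN = require_env("HF_TOKEN")
os.environ["GROQ_API_KEY"] = require_env("GROQ_API")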
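The collection is still created with `size=768` (the comment calls it the GTE-large embedding size), but the configured embedder is `BAAI/bge-large-en-v1.5`, which outputs 1024-dimensional vectors, so upserts into a fresh 768-dim collection would be rejected. A sketch that probes the model instead of hard-coding the size; it reuses the `embeddings`, `client`, and `collection_name` objects from app.py:

from qdrant_client import models

# Probe the embedding model once for its true output dimension
# (1024 for BAAI/bge-large-en-v1.5) rather than trusting a literal.
vector_size = len(embeddings.embed_query("dimension probe"))

client.create_collection(
    collection_name=collection_name,
    vectors_config=models.VectorParams(
        size=vector_size,
        distance=models.Distance.COSINE,
    ),
)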
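`from langchain_qdrant import QdrantVectorStore` is added next to the older `Qdrant` wrapper, but no visible hunk uses it. If the intent is to migrate off the deprecated class, the retriever setup (not shown in this diff) would look roughly like this; `client`, `collection_name`, and `embeddings` are the objects defined earlier:

from langchain_qdrant import QdrantVectorStore

# Sketch of the newer langchain-qdrant wrapper; note it takes
# `embedding=` (singular) rather than the old `embeddings=`.
db = QdrantVectorStore(
    client=client,
    collection_name=collection_name,
    embedding=embeddings,
)
retriever = db.as_retriever(search_kwargs={"k": 5})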
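As committed, the LLM line calls `ChatOpenAI`, but the new imports only bring in `ChatGroq`, `ChatHuggingFace`, and `QdrantVectorStore`, so app.py should raise a `NameError` on startup. The missing piece, assuming the `langchain-openai` package is installed:

from langchain_openai import ChatOpenAI  # the import the diff leaves out

llm = ChatOpenAI(
    base_url="https://api-inference.huggingface.co/v1/",  # OpenAI-compatible HF inference endpoint
    api_key=HF_TOKEN,
    model="meta-llama/Llama-3.3-70B-Instruct",
    temperature=0,
)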
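Finally, `ask_question` only prints the stream and implicitly returns `None`, yet the CLI loop still assigns `answer = ask_question(user_question)`. A variant that streams and also returns the accumulated text, so the commented-out "Full answer received" print could be restored:

def ask_question(question: str) -> str:
    # Stream chunks to stdout while accumulating them, so the caller's
    # `answer = ask_question(...)` receives the full answer as well.
    print("Answer:\t", end=" ", flush=True)
    parts = []
    for chunk in rag_chain.stream(question):
        print(chunk, end="", flush=True)
        parts.append(chunk)
    print("\n")
    return "".join(parts)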