MarinaPlius committed on
Commit
06e57fe
·
verified ·
1 Parent(s): 19537bf

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +30 -2
rag.py CHANGED
@@ -76,7 +76,7 @@ class RAG:
76
  print("Reranked documents")
77
  return documentos
78
 
79
- def predict(self, instruction, context, model_parameters):
80
 
81
  api_key = os.getenv("HF_TOKEN")
82
 
@@ -99,6 +99,34 @@ class RAG:
99
  response = requests.post(self.model_name, headers=headers, json=payload)
100
 
101
  return response.json()[0]["generated_text"].split("###")[-1][8:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def beautiful_context(self, docs):
104
 
@@ -122,7 +150,7 @@ class RAG:
122
 
123
  del model_parameters["NUM_CHUNKS"]
124
 
125
- response = self.predict(prompt, text_context, model_parameters)
126
 
127
  if not response:
128
  return self.NO_ANSWER_MESSAGE
 
76
  print("Reranked documents")
77
  return documentos
78
 
79
+ def predict_dolly(self, instruction, context, model_parameters):
80
 
81
  api_key = os.getenv("HF_TOKEN")
82
 
 
99
  response = requests.post(self.model_name, headers=headers, json=payload)
100
 
101
  return response.json()[0]["generated_text"].split("###")[-1][8:]
102
+
103
+ def predict_completion(self, instruction, context, model_parameters):
104
+
105
+ client = OpenAI(
106
+ base_url=ENDPOINT_URL,
107
+ api_key=os.getenv("HF_TOKEN")
108
+ )
109
+
110
+
111
+ chat_completion = client.chat.completions.create(
112
+ model="tgi",
113
+ messages=[
114
+ {"role": "user", "content": instruction}
115
+ ],
116
+ temperature=model_parameters["temperature"],
117
+ max_tokens=model_parameters["max_new_tokens"],
118
+ stream=False,
119
+ stop=["<|im_end|>"],
120
+ extra_body = {
121
+ "presence_penalty": model_parameters["repetition_penalty"] - 2,
122
+ "do_sample": False
123
+ }
124
+ )
125
+
126
+ response = chat_completion.choices[0].message.content
127
+
128
+ return response
129
+
130
 
131
  def beautiful_context(self, docs):
132
 
 
150
 
151
  del model_parameters["NUM_CHUNKS"]
152
 
153
+ response = self.predict_completion(prompt, text_context, model_parameters)
154
 
155
  if not response:
156
  return self.NO_ANSWER_MESSAGE