Kims12 commited on
Commit
70bd2c6
·
verified ·
1 Parent(s): 2b7334e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -21
app.py CHANGED
@@ -4,6 +4,8 @@ import openai
4
  import anthropic
5
  import os
6
  from typing import Optional
 
 
7
 
8
  #############################
9
  # [기본코드] - 수정/삭제 불가
@@ -171,25 +173,27 @@ def respond_claude_qna(
171
  return f"예상치 못한 오류가 발생했습니다: {str(e)}"
172
 
173
  #############################
174
- # [추가코드] - Llama-3.3-70B-Instruct / Llama-3.2-3B-Instruct 적용
175
  #############################
176
 
177
  def get_llama_client(model_choice: str):
178
  """
179
- 선택된 Llama 모델에 맞춰 InferenceClient 생성.
180
- 토큰은 환경 변수에서 가져옴.
181
  """
182
- hf_token = os.getenv("HF_TOKEN")
183
- if not hf_token:
184
- raise ValueError("HuggingFace API 토큰이 필요합니다.")
185
-
186
  if model_choice == "Llama-3.3-70B-Instruct":
187
- model_id = "Llama-3.3-70B-Instruct"
188
  elif model_choice == "Llama-3.2-3B-Instruct":
189
- model_id = "Llama-3.2-3B-Instruct"
190
  else:
191
  raise ValueError("유효하지 않은 모델 선택입니다.")
192
- return InferenceClient(model_id, token=hf_token)
 
 
 
 
 
 
 
193
 
194
  def respond_llama_qna(
195
  question: str,
@@ -200,27 +204,27 @@ def respond_llama_qna(
200
  model_choice: str
201
  ):
202
  """
203
- 선택된 Llama 모델을 이용해 한 번의 질문(question)에 대한 답변을 반환하는 함수.
 
204
  """
205
  try:
206
- client = get_llama_client(model_choice)
207
  except ValueError as e:
208
  return f"오류: {str(e)}"
209
 
210
- messages = [
211
- {"role": "system", "content": system_message},
212
- {"role": "user", "content": question}
213
- ]
214
 
215
  try:
216
- response_full = client.chat_completion(
217
- messages,
218
- max_tokens=max_tokens,
219
  temperature=temperature,
220
  top_p=top_p,
221
  )
222
- assistant_message = response_full.choices[0].message.content
223
- return assistant_message
 
224
  except Exception as e:
225
  return f"오류가 발생했습니다: {str(e)}"
226
 
 
4
  import anthropic
5
  import os
6
  from typing import Optional
7
+ import transformers
8
+ import torch
9
 
10
  #############################
11
  # [기본코드] - 수정/삭제 불가
 
173
  return f"예상치 못한 오류가 발생했습니다: {str(e)}"
174
 
175
  #############################
176
+ # [추가코드] - Llama-3.3-70B-Instruct / Llama-3.2-3B-Instruct 적용 (transformers.pipeline 방식)
177
  #############################
178
 
179
  def get_llama_client(model_choice: str):
180
  """
181
+ 선택된 Llama 모델에 맞춰 transformers의 text-generation 파이프라인을 생성.
 
182
  """
 
 
 
 
183
  if model_choice == "Llama-3.3-70B-Instruct":
184
+ model_id = "meta-llama/Llama-3.3-70B-Instruct"
185
  elif model_choice == "Llama-3.2-3B-Instruct":
186
+ model_id = "meta-llama/Llama-3.2-3B-Instruct"
187
  else:
188
  raise ValueError("유효하지 않은 모델 선택입니다.")
189
+
190
+ pipeline_llama = transformers.pipeline(
191
+ "text-generation",
192
+ model=model_id,
193
+ model_kwargs={"torch_dtype": torch.bfloat16},
194
+ device_map="auto",
195
+ )
196
+ return pipeline_llama
197
 
198
  def respond_llama_qna(
199
  question: str,
 
204
  model_choice: str
205
  ):
206
  """
207
+ 선택된 Llama 모델을 이용해 한 번의 질문(question)에 대한 답변을 transformers 파이프라인으로 반환하는 함수.
208
+ system_message와 question을 하나의 프롬프트로 결합하여 생성합니다.
209
  """
210
  try:
211
+ pipeline_llama = get_llama_client(model_choice)
212
  except ValueError as e:
213
  return f"오류: {str(e)}"
214
 
215
+ # system_message와 question을 연결하여 프롬프트 생성
216
+ prompt = system_message.strip() + "\n" + question.strip()
 
 
217
 
218
  try:
219
+ outputs = pipeline_llama(
220
+ prompt,
221
+ max_new_tokens=max_tokens,
222
  temperature=temperature,
223
  top_p=top_p,
224
  )
225
+ # 생성된 텍스트를 추출 (전체 프롬프트 이후의 텍스트만 반환할 수도 있음)
226
+ generated_text = outputs[0]["generated_text"]
227
+ return generated_text
228
  except Exception as e:
229
  return f"오류가 발생했습니다: {str(e)}"
230