AminFaraji committed
Commit 26319c8 · verified · 1 Parent(s): 33e1e9e

Update app.py

Files changed (1): app.py (+11 -23)
app.py CHANGED
@@ -91,19 +91,8 @@ db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
 
 
 
-MODEL_NAME = "gpt2"
-
-model = AutoModelForCausalLM.from_pretrained(
-    "gpt2",
-    device_map="auto",
-    low_cpu_mem_usage=True,
-    torch_dtype=torch.float16  # Use float16 to reduce memory usage
-)
-model = model.eval()
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-print(f"Model device: {model.device}")
-
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+model = AutoModelForCausalLM.from_pretrained("gpt2")
 
 generation_config = model.generation_config
 generation_config.temperature = 0
@@ -228,18 +217,17 @@ def get_llama_response(message):
     input_text = query_text
 
     # Tokenize the input text
-    inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
-
-    # Generate text
-    with torch.no_grad():
-        outputs = model.generate(inputs.input_ids, max_length=50)
-
-    # Decode the generated text
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return(generated_text)
+    inputs = tokenizer(input_text, return_tensors="pt")
+    outputs = model.generate(inputs.input_ids, max_length=50)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return(response)
 
 import gradio as gr
 
 #gr.ChatInterface(get_llama_response).launch()
 iface = gr.Interface(fn=get_llama_response, inputs="text", outputs="text")
-iface.launch()
+iface.launch()
+
+
+
+
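For context, here is a minimal, self-contained sketch of how the touched parts of app.py read after this commit. The explicit imports and the body of get_llama_response are assumptions for illustration: in the actual file, query_text is assembled from the Chroma retrieval step earlier in the function, which this diff does not show.

# Sketch of the post-commit code path (imports assumed; not shown in the diff).
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Plain CPU load; the float16/device_map loading was removed by this commit.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

generation_config = model.generation_config
generation_config.temperature = 0

def get_llama_response(message):
    # Hypothetical stand-in: the real app builds query_text from Chroma retrieval.
    input_text = message

    # Tokenize on CPU; the .to("cuda") call was dropped in this commit.
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(inputs.input_ids, max_length=50)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

iface = gr.Interface(fn=get_llama_response, inputs="text", outputs="text")
iface.launch()

One caveat with the call as committed: max_length=50 counts the prompt tokens as well, so a long retrieved context can leave little or no room for generated text; max_new_tokens is the usual alternative when the prompt length varies.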