acecalisto3 committed
Commit 0a3c7ba · verified · 1 Parent(s): b18fe27

Update app.py

Files changed (1)
  1. app.py +19 -7
app.py CHANGED
@@ -1,4 +1,3 @@
-import streamlit as st
 import subprocess
 import os
 from io import StringIO
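Review note: this hunk drops `import streamlit as st`, but the module still calls `st.session_state` and `st.write` (see the next hunk), so importing app.py will now fail with `NameError: name 'st' is not defined`. The minimal fix is simply to keep the import:

    import streamlit as st  # still required by the st.* calls below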
@@ -6,6 +5,19 @@ import sys
 import black
 from pylint import lint
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+# Initialize chat_history in the session state
+
+if 'chat_history' not in st.session_state:
+    st.session_state['chat_history'] = []
+
+# Access and update chat_history
+chat_history = st.session_state['chat_history']
+chat_history.append("New message")
+
+# Display chat history
+st.write("Chat History:")
+for message in chat_history:
+    st.write(message)
 
 # Global state to manage communication between Tool Box and Workspace Chat App
 if 'chat_history' not in st.session_state:
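Review note: the added block runs at module level, so `chat_history.append("New message")` fires on every Streamlit rerun and the persisted list fills with duplicate placeholder entries; the file also now initializes `chat_history` twice, here and in the pre-existing block just below. A rerun-safe sketch, assuming the append was placeholder code; the `st.chat_input` widget is a real Streamlit API but my substitution, not the commit's:

    import streamlit as st

    # Initialize once; st.session_state persists across reruns
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    # Append only in response to a real user event, not on every rerun
    user_message = st.chat_input("Say something")
    if user_message:
        st.session_state['chat_history'].append(user_message)

    # Display chat history
    st.write("Chat History:")
    for message in st.session_state['chat_history']:
        st.write(message)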
@@ -91,12 +103,12 @@ def chat_interface_with_agent(input_text, agent_name):
     combined_input = f"{agent_prompt}\n\nUser: {input_text}\nAgent:"
 
     # Truncate input text to avoid exceeding the model's maximum length
-    max_input_length = 900
+    max_input_length = max_input_length
     input_ids = tokenizer.encode(combined_input, return_tensors="pt")
     if input_ids.shape[1] > max_input_length:
         input_ids = input_ids[:, :max_input_length]
 
-    outputs = model.generate(input_ids, max_length=1024, do_sample=True)
+    outputs = model.generate(input_ids, max_length=max_input_length, do_sample=True)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
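Review note: `max_input_length = max_input_length` assigns a local name to itself before it has a value, so `chat_interface_with_agent` now raises `UnboundLocalError` on this line; the old hard-coded 900 was at least defined. Passing the same number as `generate`'s `max_length` is also self-defeating: `max_length` counts prompt tokens plus generated tokens, so a prompt truncated to exactly that length leaves no room for a reply. A sketch of the apparent intent, assuming the removed 900 limit; `max_new_tokens` is a standard transformers argument but my substitution, not what the commit used:

    max_input_length = 900  # assumption: the limit this commit removed
    input_ids = tokenizer.encode(combined_input, return_tensors="pt")
    if input_ids.shape[1] > max_input_length:
        input_ids = input_ids[:, :max_input_length]

    # max_new_tokens bounds only the generated continuation, so a
    # full-length prompt still yields a non-empty reply
    outputs = model.generate(input_ids, max_new_tokens=128, do_sample=True)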
@@ -127,12 +139,12 @@ def chat_interface(input_text):
 
 
     # Truncate input text to avoid exceeding the model's maximum length
-    max_input_length = 900
+    max_input_length = max_input_length
     input_ids = tokenizer.encode(input_text, return_tensors="pt")
     if input_ids.shape[1] > max_input_length:
         input_ids = input_ids[:, :max_input_length]
 
-    outputs = model.generate(input_ids, max_length=1024, do_sample=True)
+    outputs = model.generate(input_ids, max_length=max, do_sample=True)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
 
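Review note: `max_length=max` hands `generate` the Python builtin `max` function rather than an integer, which raises a `TypeError` as soon as generate compares it against the sequence length, and `max_input_length = max_input_length` repeats the self-assignment from the previous hunk. A sketch restoring the values this commit removed:

    max_input_length = 900  # assumption: reuse the removed limit
    outputs = model.generate(input_ids, max_length=1024, do_sample=True)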
@@ -264,7 +276,7 @@ def summarize_text(text):
         return f"Error loading model: {e}"
 
     # Truncate input text to avoid exceeding the model's maximum length
-    max_input_length = 1024
+    max_input_length = max_input_length
     inputs = text
     if len(text) > max_input_length:
         inputs = text[:max_input_length]
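Review note: the same self-assignment again; `summarize_text` will raise `UnboundLocalError` before it ever reaches the truncation. Note also that this function truncates by characters (`len(text)`), which only loosely approximates the model's token limit. A sketch assuming the removed 1024-character cap:

    max_input_length = 1024  # characters, not tokens; a rough proxy
    inputs = text if len(text) <= max_input_length else text[:max_input_length]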
@@ -346,7 +358,7 @@ def generate_code(idea):
     input_ids = tokenizer.encode(input_text, return_tensors="pt")
     output_sequences = model.generate(
         input_ids=input_ids,
-        max_length=1024,
+        max_length=max_length,
         num_return_sequences=1,
         no_repeat_ngram_size=2,
         early_stopping=True,
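Review note: unless `max_length` is bound earlier in `generate_code` (not visible in this hunk), `max_length=max_length` looks up an undefined name and raises `NameError` where the old literal 1024 worked. A sketch keeping the removed value; the variable name is my hypothetical choice, not the commit's:

    max_output_length = 1024  # hypothetical name for the removed literal
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_output_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )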