ngrigg committed on
Commit
8270298
·
1 Parent(s): 6f15a2e

Update model loading to use AutoModelForCausalLM

Browse files
Files changed (2) hide show
  1. app.py +43 -2
  2. llama_models.py +2 -2
app.py CHANGED
@@ -1,4 +1,45 @@
1
  import streamlit as st
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import asyncio
4
+ from llama_models import process_text
5
+ from dotenv import load_dotenv
6
+ import os
7
 
8
+ # Load environment variables from .env file
9
+ load_dotenv()
10
+
11
async def process_csv(file, model_name="instruction-pretrain/finance-Llama3-8B", sample_size=5):
    """Read a header-less CSV of company descriptions and append model predictions.

    Args:
        file: File-like object or path accepted by ``pandas.read_csv``.
        model_name: Hugging Face model identifier forwarded to ``process_text``.
            Defaults to the finance Llama-3 checkpoint used by the app.
        sample_size: Maximum number of rows to process (keeps demo runs cheap).

    Returns:
        A DataFrame containing only the sampled rows, with a ``predictions``
        column holding the model output for each description.
    """
    df = pd.read_csv(file, header=None)  # no header row: descriptions live in column 0

    # Truncate the frame itself, not just the description list: the original
    # code assigned len(sample) predictions to the full-length frame, which
    # raises ValueError whenever the CSV has more than sample_size rows.
    n_rows = min(sample_size, len(df))
    df = df.head(n_rows).copy()
    descriptions = df[0].tolist()

    results = []
    for desc in descriptions:
        # Sequential awaits keep the request rate low; fine for a small sample.
        result = await process_text(model_name, desc)
        results.append(result)

    df['predictions'] = results
    return df
26
+
27
# --- Streamlit page: upload a CSV, run predictions, offer the results file ---
st.title("Finance Model Deployment")

st.write("""
### Upload a CSV file with company descriptions to extract key products, geographies, and important keywords:
""")

uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

# The button is only rendered once a file is present (short-circuit keeps the
# widget order identical to the nested-if form).
if uploaded_file is not None and st.button("Predict"):
    with st.spinner("Processing..."):
        predictions_df = asyncio.run(process_csv(uploaded_file))
        st.write(predictions_df)
        st.download_button(
            label="Download Predictions as CSV",
            data=predictions_df.to_csv(index=False).encode('utf-8'),
            file_name='predictions.csv',
            mime='text/csv'
        )
llama_models.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import aiohttp
4
 
5
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
6
 
7
  def load_model(model_name):
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
  return tokenizer, model
11
 
12
  async def process_text(model_name, text):
 
1
  import os
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import aiohttp
4
 
5
  HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
6
 
7
def load_model(model_name):
    """Fetch tokenizer and model weights for *model_name* from the HF hub.

    Returns a ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # Llama checkpoints are decoder-only, so the causal-LM head is the right fit.
    lm = AutoModelForCausalLM.from_pretrained(model_name)
    return tok, lm
11
 
12
  async def process_text(model_name, text):