File size: 1,676 Bytes
6f15a2e
c654f8e
6f15a2e
 
 
73e820c
 
6f15a2e
 
73e820c
 
 
 
c654f8e
73e820c
6f15a2e
 
 
 
 
 
 
73e820c
6f15a2e
 
 
73e820c
6f15a2e
73e820c
6f15a2e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from transformers import AutoTokenizer, AutoModelForCausalLM  # Ensure correct model class
import aiohttp

# Token for the Hugging Face Inference API; None if the env var is unset
# (requests will then be sent with "Bearer None" and rejected upstream).
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
# Process-wide lazy cache for the loaded model/tokenizer pair; populated
# on first call to load_model().
model = None
tokenizer = None

def load_model(model_name):
    """Lazily load and cache a Hugging Face tokenizer/model pair.

    The pair is cached in the module-level ``tokenizer``/``model`` globals.

    Args:
        model_name: Hub identifier or local path passed to
            ``from_pretrained``.

    Returns:
        Tuple of ``(tokenizer, model)``.
    """
    global tokenizer, model
    # Bug fix: the original only checked "is anything cached?", so after the
    # first load every subsequent call returned the cached pair even when a
    # *different* model_name was requested. Track which name is cached and
    # reload when it changes.
    cached_name = getattr(load_model, "_cached_name", None)
    if tokenizer is None or model is None or cached_name != model_name:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)  # Ensure correct model class
        load_model._cached_name = model_name
        print("Model and tokenizer loaded successfully.")
    return tokenizer, model

async def process_text(model_name, text):
    """Extract products, geographies, and keywords from a company description.

    Sends a prompt to the hosted Hugging Face Inference API for
    ``model_name`` and returns the generated text.

    Args:
        model_name: Hub identifier of the model to query on the Inference API.
        text: Company description to analyze.

    Returns:
        The generated text (stripped), or the stringified raw response when
        the API returns an unexpected payload shape (e.g. an error dict
        without ``generated_text``).
    """
    # Bug fix: the original called load_model(model_name) here and bound the
    # result to locals that were never used — downloading the full model to
    # disk (blocking, synchronous I/O inside an async function) even though
    # inference is actually delegated to the hosted API below.
    prompt = f"Given the following company description, extract key products, geographies, and important keywords:\n\n{text}\n\nProducts, geographies, and keywords:"

    async with aiohttp.ClientSession() as session:
        print(f"Sending request to model API for text: {text[:50]}...")
        async with session.post(f"https://api-inference.huggingface.co/models/{model_name}", 
                                headers={"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"},
                                json={"inputs": prompt}) as response:
            print(f"Received response with status code: {response.status}")
            result = await response.json()
            print(f"Raw API response: {result}")
            # The text-generation endpoint normally returns a list of dicts;
            # error payloads come back as a bare dict.
            if isinstance(result, list) and len(result) > 0:
                return result[0].get('generated_text', '').strip()
            elif isinstance(result, dict):
                return result.get('generated_text', '').strip()
            else:
                return str(result)