In [1]:
# --- Run configuration -------------------------------------------------------
# Tag of this run: selects both the input file (outdir + name + '.csv.gz')
# and the fine-tuned model directory (outdir + 'FinBERT_' + name + '/Best_model').
name = 'S9_gender'

# Scratch directory holding the prepared data and the trained models.
outdir = '/media/Scratch/MLSS/'

# NOTE(review): `filename` and `batch_size` are not referenced by any cell
# visible in this notebook -- confirm whether they are still needed.
filename = '../../Data/S8_pass_to_S9.csv'
batch_size = 4
In [2]:
import os # interact with file system
import sys # Needed to load settings file
import platform # identify OS type
import shutil
from glob import glob
from tqdm.autonotebook import tqdm

import csv # read and write csv files

import transformers
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import evaluate

from datasets import load_dataset, Features, Value, Dataset, load_metric
from sklearn.model_selection import train_test_split

import pandas as pd

import torch
import numpy as np

transformers.logging.set_verbosity_error()
/tmp/ipykernel_1863121/3703570026.py:6: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from tqdm.autonotebook import tqdm
In [3]:
# Shared tokenization setup used by all models in this series

tokenizer = BertTokenizer.from_pretrained('yiyanghkust/finbert-pretrain')

def tokenize_function(observations):
    """Tokenize the ``text`` field of ``observations`` with the shared tokenizer.

    Every sequence is truncated to, and padded up to, exactly 512 tokens.
    Returns whatever the tokenizer call returns for that text.
    """
    text = observations["text"]
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    return encoded
In [4]:
def model_inference(encoded_dataset, name):
    """Run the fine-tuned FinBERT classifier over every row and return its raw logits.

    The model is loaded from ``outdir + 'FinBERT_' + name + '/Best_model'`` and
    placed on the GPU when one is available, otherwise on the CPU.

    Parameters
    ----------
    encoded_dataset : indexable dataset
        ``encoded_dataset[i]`` must yield a mapping with ``input_ids``,
        ``token_type_ids`` and ``attention_mask`` for a single example
        (tensors or plain integer sequences). The values are moved to the
        model's device here, so the device chosen by the caller's
        ``set_format`` does not have to match.
    name : str
        Suffix identifying which fine-tuned model directory to load.

    Returns
    -------
    pandas.DataFrame
        One row per example and one column per logit, named ``log0``,
        ``log1``, ...  An empty dataset yields an empty frame with one
        column per label declared in the model config.
    """
    # import the pytorch model
    model = AutoModelForSequenceClassification.from_pretrained(outdir + 'FinBERT_' + name + '/Best_model')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()  # inference only: make sure dropout is disabled
    num_labels = model.config.num_labels  # kept for the empty-dataset case below

    input_keys = ('input_ids', 'token_type_ids', 'attention_mask')
    logits = []
    # no_grad: otherwise every forward pass records an autograd graph that is never used
    with torch.no_grad():
        for i in tqdm(range(len(encoded_dataset))):
            row = encoded_dataset[i]  # fetch/format the row once instead of once per field
            # unsqueeze(0) adds the leading batch dimension the model expects (1D -> 2D);
            # as_tensor also puts the data on the same device as the model.
            batch = {
                key: torch.as_tensor(row[key], device=device).unsqueeze(0)
                for key in input_keys
            }
            logits.append(model(**batch).logits[0].cpu().tolist())

    # free space: drop the model and hand PyTorch's cached GPU blocks back to the driver
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # output the data
    n_cols = len(logits[0]) if logits else num_labels
    cols = ['log' + str(i) for i in range(n_cols)]
    # reshape keeps the array 2D even when there are no rows
    return pd.DataFrame(np.array(logits, dtype=float).reshape(-1, n_cols), columns=cols)
In [5]:
# Read the gzip-compressed CSV for this run, then wrap it as a Hugging Face Dataset.
df = pd.read_csv(f"{outdir}{name}.csv.gz", encoding="utf-8")
dataset = Dataset.from_pandas(df)
In [6]:
# Tokenize every row; the tokenizer's output fields are added as new columns.
encoded_dataset = dataset.map(function=tokenize_function)
Map: 100%|█████████████████████████████████████████████████████| 1195/1195 [00:11<00:00, 105.79 examples/s]
In [7]:
# Display the encoded dataset summary (column names and row count).
encoded_dataset
Out[7]:
Dataset({
    features: ['text', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1195
})
In [8]:
# Drop the columns that are not passed to the model's forward call.
# Only columns that are still present are removed, so re-running this cell is
# a no-op instead of raising on the already-removed names.
encoded_dataset = encoded_dataset.remove_columns(
    [col for col in ('text', 'labels') if col in encoded_dataset.column_names]
)
In [9]:
# Push the encoded data to the GPU when one is available; otherwise keep it on
# the CPU so this cell also runs on CPU-only machines (same fallback rule that
# model_inference uses when placing the model).
encoded_dataset.set_format("torch", device="cuda" if torch.cuda.is_available() else "cpu")
In [10]:
# run every encoded row through the fine-tuned model and collect the logits
output = model_inference(encoded_dataset=encoded_dataset, name=name)
        
100%|██████████████████████████████████████████████████████████████████| 1195/1195 [00:20<00:00, 57.05it/s]
In [11]:
# Attach the model outputs to the original rows; both frames use the default
# positional index, so the join lines them up row by row.
# Only columns df does not already have are joined, so re-running this cell
# (when `output` already contains df's columns) no longer raises on overlapping
# column names.
output = df.join(output.loc[:, ~output.columns.isin(df.columns)])
# Inference is finished: switch the encoded data back to CPU tensors.
encoded_dataset.set_format("torch", device="cpu")
In [12]:
# Display the input rows together with their joined logit columns.
output
Out[12]:
text labels log0 log1
0 Who can blame an investor for taking to the bo... 1 0.794654 -0.469129
1 Businesses are dumping office space at the fas... 0 0.804681 0.202452
2 WASHINGTON -- President George W. Bush's plan ... 0 1.144736 -0.799532
3 FRANKFURT -- The European Central Bank, while ... 1 0.791363 -0.383891
4 PARIS -- President Nicolas Sarkozy introduced ... 0 1.267021 -0.506859
... ... ... ... ...
1190 Should John McCain eke out a victory on Tuesda... 0 1.544418 -1.026227
1191 In times of economic crisis, with thoughts of ... 1 0.832721 -0.465457
1192 J.P. Morgan Chase & Co. launched an ambitious ... 1 1.524031 -0.449714
1193 The Morning Brief, a look at the day's biggest... 0 1.613019 -0.998035
1194 In the legal battle over chemical company Hunt... 0 1.347540 -0.431128

1195 rows × 4 columns

In [13]:
# Name the logits: index 0 is treated as the "male" class, index 1 as "female".
output = output.rename(columns={'log0': 'logit_male', 'log1': 'logit_female'})
# i_female is 0 only when the male logit is strictly larger; ties (and rows where
# the comparison is False for any other reason, e.g. NaN) are flagged 1, exactly
# as the previous row-by-row rule did. The comparison is vectorized instead of
# calling a Python lambda once per row.
output['i_female'] = (~(output['logit_male'] > output['logit_female'])).astype(int)
In [14]:
# Show only the rows flagged as female.
output.loc[output['i_female'].eq(1)]
Out[14]:
text labels logit_male logit_female i_female
23 Merck & Co. Thursday said it halted developmen... 0 0.410868 0.453517 1
34 Amgen Inc. said a federal court granted the bi... 1 -0.144691 0.993399 1
36 Citing overwhelming economic pressures and war... 1 0.233885 0.375459 1
49 Options traders hopped on railroad companies T... 1 -0.181216 0.967081 1
54 The euro fell to a fresh one-year low against ... 1 0.145984 0.390424 1
... ... ... ... ... ...
1143 A federal appeals court has upheld a decision ... 0 0.115356 0.389280 1
1149 Walgreen Co. plans to cut costs, overhaul stor... 1 -0.073960 0.877992 1
1152 A combination of a treatment known as cognitiv... 0 0.296025 0.333333 1
1157 WASHINGTON -- The clouds around GlaxoSmithKlin... 1 -0.336796 0.735275 1
1158 Cigna Corp., citing competitive and economic p... 1 -0.064404 0.586696 1

78 rows × 5 columns

In [ ]:
 
In [ ]:
 
In [ ]: